[api] Speed up protobuf encode 17-20% with register-optimized write path (#15290)

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
J. Nick Koston
2026-04-06 12:42:18 -10:00
committed by GitHub
parent 5a14d6a4ad
commit 2b5ee69eb2
7 changed files with 1257 additions and 939 deletions
+2 -2
View File
@@ -1993,7 +1993,7 @@ bool APIConnection::send_message_(uint32_t payload_size, uint8_t message_type, M
size_t write_start = shared_buf.size();
shared_buf.resize(write_start + payload_size);
ProtoWriteBuffer buffer{&shared_buf, write_start};
encode_fn(msg, buffer);
encode_fn(msg, buffer PROTO_ENCODE_DEBUG_INIT(&shared_buf));
return this->send_buffer(ProtoWriteBuffer{&shared_buf}, message_type);
}
// Encodes a message to the buffer and returns the total number of bytes used,
@@ -2034,7 +2034,7 @@ uint16_t APIConnection::encode_to_buffer(uint32_t calculated_size, MessageEncode
shared_buf.resize(shared_buf.size() + to_add);
ProtoWriteBuffer buffer{&shared_buf, shared_buf.size() - calculated_size};
encode_fn(msg, buffer);
encode_fn(msg, buffer PROTO_ENCODE_DEBUG_INIT(&shared_buf));
// Return total size (header + payload + footer)
return static_cast<uint16_t>(total_calculated_size);
+4 -2
View File
@@ -324,7 +324,7 @@ class APIConnection final : public APIServerConnectionBase {
void on_no_setup_connection();
// Function pointer type for type-erased message encoding
using MessageEncodeFn = void (*)(const void *, ProtoWriteBuffer &);
using MessageEncodeFn = uint8_t *(*) (const void *, ProtoWriteBuffer &PROTO_ENCODE_DEBUG_PARAM);
// Function pointer type for type-erased size calculation
using CalculateSizeFn = uint32_t (*)(const void *);
@@ -403,7 +403,9 @@ class APIConnection final : public APIServerConnectionBase {
}
// Shared no-op encode thunk for empty messages (ESTIMATED_SIZE == 0)
static void encode_msg_noop(const void *, ProtoWriteBuffer &) {}
static uint8_t *encode_msg_noop(const void *, ProtoWriteBuffer &buf PROTO_ENCODE_DEBUG_PARAM) {
return buf.get_pos();
}
// Non-template buffer management for send_message
bool send_message_(uint32_t payload_size, uint8_t message_type, MessageEncodeFn encode_fn, const void *msg);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+12 -5
View File
@@ -145,14 +145,15 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size
// [tag][v1][v2][body ..... body]
// ^-- pos_ = element end, within buffer
void ProtoWriteBuffer::encode_sub_message(uint32_t field_id, const void *value,
void (*encode_fn)(const void *, ProtoWriteBuffer &)) {
uint8_t *(*encode_fn)(const void *,
ProtoWriteBuffer &PROTO_ENCODE_DEBUG_PARAM)) {
this->encode_field_raw(field_id, 2);
// Reserve 1 byte for length varint (optimistic: submessage < 128 bytes)
uint8_t *len_pos = this->pos_;
this->debug_check_bounds_(1);
this->pos_++;
uint8_t *body_start = this->pos_;
encode_fn(value, *this);
this->pos_ = encode_fn(value, *this PROTO_ENCODE_DEBUG_INIT(this->buffer_));
uint32_t body_size = static_cast<uint32_t>(this->pos_ - body_start);
if (body_size < 128) [[likely]] {
// Common case: 1-byte varint, just backpatch
@@ -173,22 +174,27 @@ void ProtoWriteBuffer::encode_sub_message(uint32_t field_id, const void *value,
// Non-template core for encode_optional_sub_message.
void ProtoWriteBuffer::encode_optional_sub_message(uint32_t field_id, uint32_t nested_size, const void *value,
void (*encode_fn)(const void *, ProtoWriteBuffer &)) {
uint8_t *(*encode_fn)(const void *,
ProtoWriteBuffer &PROTO_ENCODE_DEBUG_PARAM)) {
if (nested_size == 0)
return;
this->encode_field_raw(field_id, 2);
this->encode_varint_raw(nested_size);
#ifdef ESPHOME_DEBUG_API
uint8_t *start = this->pos_;
encode_fn(value, *this);
this->pos_ = encode_fn(value, *this PROTO_ENCODE_DEBUG_INIT(this->buffer_));
if (static_cast<uint32_t>(this->pos_ - start) != nested_size)
this->debug_check_encode_size_(field_id, nested_size, this->pos_ - start);
#else
encode_fn(value, *this);
this->pos_ = encode_fn(value, *this PROTO_ENCODE_DEBUG_INIT(this->buffer_));
#endif
}
#ifdef ESPHOME_DEBUG_API
void proto_check_bounds_failed(const uint8_t *pos, size_t bytes, const uint8_t *end, const char *caller) {
ESP_LOGE(TAG, "Proto encode bounds check failed in %s: need %zu bytes, %td available", caller, bytes, end - pos);
abort();
}
void ProtoWriteBuffer::debug_check_bounds_(size_t bytes, const char *caller) {
if (this->pos_ + bytes > this->buffer_->data() + this->buffer_->size()) {
ESP_LOGE(TAG, "ProtoWriteBuffer bounds check failed in %s: bytes=%zu offset=%td buf_size=%zu", caller, bytes,
@@ -201,6 +207,7 @@ void ProtoWriteBuffer::debug_check_encode_size_(uint32_t field_id, uint32_t expe
expected, actual);
abort();
}
#endif
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
+252 -155
View File
@@ -195,6 +195,26 @@ class Proto32Bit {
// NOTE: Proto64Bit class removed - wire type 1 (64-bit fixed) not supported
// Debug bounds checking for proto encode functions.
// In debug mode (ESPHOME_DEBUG_API), an extra end-of-buffer pointer is threaded
// through the entire encode chain. In production, these expand to nothing.
#ifdef ESPHOME_DEBUG_API
#define PROTO_ENCODE_DEBUG_PARAM , uint8_t *proto_debug_end_
#define PROTO_ENCODE_DEBUG_ARG , proto_debug_end_
#define PROTO_ENCODE_DEBUG_INIT(buf) , (buf)->data() + (buf)->size()
#define PROTO_ENCODE_CHECK_BOUNDS(pos, n) \
do { \
if ((pos) + (n) > proto_debug_end_) \
proto_check_bounds_failed(pos, n, proto_debug_end_, __builtin_FUNCTION()); \
} while (0)
void proto_check_bounds_failed(const uint8_t *pos, size_t bytes, const uint8_t *end, const char *caller);
#else
#define PROTO_ENCODE_DEBUG_PARAM
#define PROTO_ENCODE_DEBUG_ARG
#define PROTO_ENCODE_DEBUG_INIT(buf)
#define PROTO_ENCODE_CHECK_BOUNDS(pos, n)
#endif
class ProtoWriteBuffer {
public:
ProtoWriteBuffer(APIBuffer *buffer) : buffer_(buffer), pos_(buffer->data() + buffer->size()) {}
@@ -207,15 +227,6 @@ class ProtoWriteBuffer {
}
this->encode_varint_raw_slow_(value);
}
void encode_varint_raw_64(uint64_t value) {
while (value > 0x7F) {
this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value | 0x80);
value >>= 7;
}
this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value);
}
/**
* Encode a field key (tag/wire type combination).
*
@@ -229,123 +240,6 @@ class ProtoWriteBuffer {
* Following https://protobuf.dev/programming-guides/encoding/#structure
*/
void encode_field_raw(uint32_t field_id, uint32_t type) { this->encode_varint_raw((field_id << 3) | type); }
/// Write a single precomputed tag byte. Tag must be < 128.
inline void write_raw_byte(uint8_t b) ESPHOME_ALWAYS_INLINE {
this->debug_check_bounds_(1);
*this->pos_++ = b;
}
/// Write raw bytes to the buffer (no tag, no length prefix).
inline void encode_raw(const void *data, size_t len) ESPHOME_ALWAYS_INLINE {
this->debug_check_bounds_(len);
std::memcpy(this->pos_, data, len);
this->pos_ += len;
}
/// Write a precomputed tag byte + 32-bit value in one operation.
/// Tag must be a single-byte varint (< 128). No zero check.
inline void write_tag_and_fixed32(uint8_t tag, uint32_t value) ESPHOME_ALWAYS_INLINE {
this->debug_check_bounds_(5);
this->pos_[0] = tag;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(this->pos_ + 1, &value, 4);
#else
this->pos_[1] = static_cast<uint8_t>(value & 0xFF);
this->pos_[2] = static_cast<uint8_t>((value >> 8) & 0xFF);
this->pos_[3] = static_cast<uint8_t>((value >> 16) & 0xFF);
this->pos_[4] = static_cast<uint8_t>((value >> 24) & 0xFF);
#endif
this->pos_ += 5;
}
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
if (len == 0 && !force)
return;
this->encode_field_raw(field_id, 2); // type 2: Length-delimited string
this->encode_varint_raw(len);
// Direct memcpy into pre-sized buffer — avoids push_back() per-byte capacity checks
// and vector::insert() iterator overhead. ~10-11x faster for 16-32 byte strings.
this->debug_check_bounds_(len);
std::memcpy(this->pos_, string, len);
this->pos_ += len;
}
void encode_string(uint32_t field_id, const std::string &value, bool force = false) {
this->encode_string(field_id, value.data(), value.size(), force);
}
void encode_string(uint32_t field_id, const StringRef &ref, bool force = false) {
this->encode_string(field_id, ref.c_str(), ref.size(), force);
}
void encode_bytes(uint32_t field_id, const uint8_t *data, size_t len, bool force = false) {
this->encode_string(field_id, reinterpret_cast<const char *>(data), len, force);
}
void encode_uint32(uint32_t field_id, uint32_t value, bool force = false) {
if (value == 0 && !force)
return;
this->encode_field_raw(field_id, 0); // type 0: Varint - uint32
this->encode_varint_raw(value);
}
void encode_uint64(uint32_t field_id, uint64_t value, bool force = false) {
if (value == 0 && !force)
return;
this->encode_field_raw(field_id, 0); // type 0: Varint - uint64
this->encode_varint_raw_64(value);
}
void encode_bool(uint32_t field_id, bool value, bool force = false) {
if (!value && !force)
return;
this->encode_field_raw(field_id, 0); // type 0: Varint - bool
this->debug_check_bounds_(1);
*this->pos_++ = value ? 0x01 : 0x00;
}
void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
if (value == 0 && !force)
return;
this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32
this->debug_check_bounds_(4);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
// Protobuf fixed32 is little-endian, so direct copy works
std::memcpy(this->pos_, &value, 4);
this->pos_ += 4;
#else
*this->pos_++ = (value >> 0) & 0xFF;
*this->pos_++ = (value >> 8) & 0xFF;
*this->pos_++ = (value >> 16) & 0xFF;
*this->pos_++ = (value >> 24) & 0xFF;
#endif
}
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
// not supported to reduce overhead on embedded systems. All ESPHome devices are
// 32-bit microcontrollers where 64-bit operations are expensive. If 64-bit support
// is needed in the future, the necessary encoding/decoding functions must be added.
void encode_float(uint32_t field_id, float value, bool force = false) {
if (value == 0.0f && !force)
return;
union {
float value;
uint32_t raw;
} val{};
val.value = value;
this->encode_fixed32(field_id, val.raw);
}
void encode_int32(uint32_t field_id, int32_t value, bool force = false) {
if (value < 0) {
// negative int32 is always 10 byte long
this->encode_int64(field_id, value, force);
return;
}
this->encode_uint32(field_id, static_cast<uint32_t>(value), force);
}
void encode_int64(uint32_t field_id, int64_t value, bool force = false) {
this->encode_uint64(field_id, static_cast<uint64_t>(value), force);
}
void encode_sint32(uint32_t field_id, int32_t value, bool force = false) {
this->encode_uint32(field_id, encode_zigzag32(value), force);
}
void encode_sint64(uint32_t field_id, int64_t value, bool force = false) {
this->encode_uint64(field_id, encode_zigzag64(value), force);
}
/// Encode a packed repeated sint32 field (zero-copy from vector)
void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values);
/// Single-pass encode for repeated submessage elements.
/// Thin template wrapper; all buffer work is in the non-template core.
template<typename T> void encode_sub_message(uint32_t field_id, const T &value);
@@ -353,12 +247,17 @@ class ProtoWriteBuffer {
/// Thin template wrapper; all buffer work is in the non-template core.
template<typename T> void encode_optional_sub_message(uint32_t field_id, const T &value);
// NOLINTBEGIN(readability-identifier-naming)
// Non-template core for encode_sub_message — backpatch approach.
void encode_sub_message(uint32_t field_id, const void *value, void (*encode_fn)(const void *, ProtoWriteBuffer &));
void encode_sub_message(uint32_t field_id, const void *value,
uint8_t *(*encode_fn)(const void *, ProtoWriteBuffer &PROTO_ENCODE_DEBUG_PARAM));
// Non-template core for encode_optional_sub_message.
void encode_optional_sub_message(uint32_t field_id, uint32_t nested_size, const void *value,
void (*encode_fn)(const void *, ProtoWriteBuffer &));
uint8_t *(*encode_fn)(const void *, ProtoWriteBuffer &PROTO_ENCODE_DEBUG_PARAM));
// NOLINTEND(readability-identifier-naming)
APIBuffer *get_buffer() const { return buffer_; }
uint8_t *get_pos() const { return pos_; }
void set_pos(uint8_t *pos) { pos_ = pos; }
protected:
// Slow path for encode_varint_raw values >= 128, outlined to keep fast path small
@@ -375,6 +274,211 @@ class ProtoWriteBuffer {
uint8_t *pos_;
};
// Varint encoding thresholds — used by both proto_encode_* free functions and ProtoSize.
constexpr uint32_t VARINT_MAX_1_BYTE = 1 << 7; // 128
constexpr uint32_t VARINT_MAX_2_BYTE = 1 << 14; // 16384
/// Static encode helpers for generated encode() functions.
/// Generated code hoists buffer.pos_ into a local uint8_t *__restrict__ pos,
/// then calls these methods which take pos by reference. No struct, no overhead.
/// For sub-messages, pos is synced back to buffer before the call and reloaded after.
class ProtoEncode {
public:
/// Write a multi-byte varint directly through a pos pointer.
template<typename T>
static inline void encode_varint_raw_loop(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, T value) {
do {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = static_cast<uint8_t>(value | 0x80);
value >>= 7;
} while (value > 0x7F);
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = static_cast<uint8_t>(value);
}
static inline void ESPHOME_ALWAYS_INLINE encode_varint_raw(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint32_t value) {
if (value < VARINT_MAX_1_BYTE) [[likely]] {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = static_cast<uint8_t>(value);
return;
}
encode_varint_raw_loop(pos PROTO_ENCODE_DEBUG_ARG, value);
}
/// Encode a varint that is expected to be 1-2 bytes (e.g. zigzag RSSI, small lengths).
static inline void ESPHOME_ALWAYS_INLINE encode_varint_raw_short(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint32_t value) {
if (value < VARINT_MAX_1_BYTE) [[likely]] {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = static_cast<uint8_t>(value);
return;
}
if (value < VARINT_MAX_2_BYTE) [[likely]] {
PROTO_ENCODE_CHECK_BOUNDS(pos, 2);
*pos++ = static_cast<uint8_t>(value | 0x80);
*pos++ = static_cast<uint8_t>(value >> 7);
return;
}
encode_varint_raw_loop(pos PROTO_ENCODE_DEBUG_ARG, value);
}
static inline void ESPHOME_ALWAYS_INLINE encode_varint_raw_64(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint64_t value) {
if (value < VARINT_MAX_1_BYTE) [[likely]] {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = static_cast<uint8_t>(value);
return;
}
encode_varint_raw_loop(pos PROTO_ENCODE_DEBUG_ARG, value);
}
static inline void ESPHOME_ALWAYS_INLINE encode_field_raw(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint32_t field_id, uint32_t type) {
encode_varint_raw(pos PROTO_ENCODE_DEBUG_ARG, (field_id << 3) | type);
}
/// Write a single precomputed tag byte. Tag must be < 128.
static inline void ESPHOME_ALWAYS_INLINE write_raw_byte(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint8_t b) {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = b;
}
/// Write raw bytes to the buffer (no tag, no length prefix).
static inline void ESPHOME_ALWAYS_INLINE encode_raw(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
const void *data, size_t len) {
PROTO_ENCODE_CHECK_BOUNDS(pos, len);
std::memcpy(pos, data, len);
pos += len;
}
/// Write a precomputed tag byte + 32-bit value in one operation.
static inline void ESPHOME_ALWAYS_INLINE write_tag_and_fixed32(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
uint8_t tag, uint32_t value) {
PROTO_ENCODE_CHECK_BOUNDS(pos, 5);
pos[0] = tag;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(pos + 1, &value, 4);
#else
pos[1] = static_cast<uint8_t>(value & 0xFF);
pos[2] = static_cast<uint8_t>((value >> 8) & 0xFF);
pos[3] = static_cast<uint8_t>((value >> 16) & 0xFF);
pos[4] = static_cast<uint8_t>((value >> 24) & 0xFF);
#endif
pos += 5;
}
static inline void encode_string(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
const char *string, size_t len, bool force = false) {
if (len == 0 && !force)
return;
encode_field_raw(pos PROTO_ENCODE_DEBUG_ARG, field_id, 2); // type 2: Length-delimited string
if (len < VARINT_MAX_1_BYTE) [[likely]] {
PROTO_ENCODE_CHECK_BOUNDS(pos, 1 + len);
*pos++ = static_cast<uint8_t>(len);
} else {
encode_varint_raw_loop(pos PROTO_ENCODE_DEBUG_ARG, len);
PROTO_ENCODE_CHECK_BOUNDS(pos, len);
}
std::memcpy(pos, string, len);
pos += len;
}
static inline void encode_string(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
const std::string &value, bool force = false) {
encode_string(pos PROTO_ENCODE_DEBUG_ARG, field_id, value.data(), value.size(), force);
}
static inline void encode_string(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
const StringRef &ref, bool force = false) {
encode_string(pos PROTO_ENCODE_DEBUG_ARG, field_id, ref.c_str(), ref.size(), force);
}
static inline void encode_bytes(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
const uint8_t *data, size_t len, bool force = false) {
encode_string(pos PROTO_ENCODE_DEBUG_ARG, field_id, reinterpret_cast<const char *>(data), len, force);
}
static inline void encode_uint32(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
uint32_t value, bool force = false) {
if (value == 0 && !force)
return;
encode_field_raw(pos PROTO_ENCODE_DEBUG_ARG, field_id, 0);
encode_varint_raw(pos PROTO_ENCODE_DEBUG_ARG, value);
}
static inline void encode_uint64(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
uint64_t value, bool force = false) {
if (value == 0 && !force)
return;
encode_field_raw(pos PROTO_ENCODE_DEBUG_ARG, field_id, 0);
encode_varint_raw_64(pos PROTO_ENCODE_DEBUG_ARG, value);
}
static inline void encode_bool(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id, bool value,
bool force = false) {
if (!value && !force)
return;
encode_field_raw(pos PROTO_ENCODE_DEBUG_ARG, field_id, 0);
PROTO_ENCODE_CHECK_BOUNDS(pos, 1);
*pos++ = value ? 0x01 : 0x00;
}
static inline void encode_fixed32(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
uint32_t value, bool force = false) {
if (value == 0 && !force)
return;
encode_field_raw(pos PROTO_ENCODE_DEBUG_ARG, field_id, 5);
PROTO_ENCODE_CHECK_BOUNDS(pos, 4);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
std::memcpy(pos, &value, 4);
pos += 4;
#else
*pos++ = (value >> 0) & 0xFF;
*pos++ = (value >> 8) & 0xFF;
*pos++ = (value >> 16) & 0xFF;
*pos++ = (value >> 24) & 0xFF;
#endif
}
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
// not supported to reduce overhead on embedded systems. All ESPHome devices are
// 32-bit microcontrollers where 64-bit operations are expensive. If 64-bit support
// is needed in the future, the necessary encoding/decoding functions must be added.
static inline void encode_float(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id, float value,
bool force = false) {
if (value == 0.0f && !force)
return;
union {
float value;
uint32_t raw;
} val{};
val.value = value;
encode_fixed32(pos PROTO_ENCODE_DEBUG_ARG, field_id, val.raw);
}
static inline void encode_int32(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id, int32_t value,
bool force = false) {
if (value < 0) {
// negative int32 is always 10 byte long
encode_uint64(pos PROTO_ENCODE_DEBUG_ARG, field_id, static_cast<uint64_t>(value), force);
return;
}
encode_uint32(pos PROTO_ENCODE_DEBUG_ARG, field_id, static_cast<uint32_t>(value), force);
}
static inline void encode_int64(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id, int64_t value,
bool force = false) {
encode_uint64(pos PROTO_ENCODE_DEBUG_ARG, field_id, static_cast<uint64_t>(value), force);
}
static inline void encode_sint32(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
int32_t value, bool force = false) {
encode_uint32(pos PROTO_ENCODE_DEBUG_ARG, field_id, encode_zigzag32(value), force);
}
static inline void encode_sint64(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, uint32_t field_id,
int64_t value, bool force = false) {
encode_uint64(pos PROTO_ENCODE_DEBUG_ARG, field_id, encode_zigzag64(value), force);
}
/// Sub-message encoding: sync pos to buffer, delegate, get pos from return value.
template<typename T>
static inline void encode_sub_message(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM, ProtoWriteBuffer &buffer,
uint32_t field_id, const T &value) {
buffer.set_pos(pos);
buffer.encode_sub_message(field_id, value);
pos = buffer.get_pos();
}
template<typename T>
static inline void encode_optional_sub_message(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
ProtoWriteBuffer &buffer, uint32_t field_id, const T &value) {
buffer.set_pos(pos);
buffer.encode_optional_sub_message(field_id, value);
pos = buffer.get_pos();
}
};
#ifdef HAS_PROTO_MESSAGE_DUMP
/**
* Fixed-size buffer for message dumps - avoids heap allocation.
@@ -470,7 +574,7 @@ class ProtoMessage {
// All call sites use templates to preserve the concrete type, so virtual
// dispatch is not needed. This eliminates per-message vtable entries for
// encode/calculate_size, saving ~1.3 KB of flash across all message types.
void encode(ProtoWriteBuffer &buffer) const {}
uint8_t *encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const { return buffer.get_pos(); }
uint32_t calculate_size() const { return 0; }
#ifdef HAS_PROTO_MESSAGE_DUMP
virtual const char *dump_to(DumpBuffer &out) const = 0;
@@ -512,9 +616,10 @@ class ProtoDecodableMessage : public ProtoMessage {
class ProtoSize {
public:
// Varint encoding thresholds: values below each threshold fit in N bytes
static constexpr uint32_t VARINT_THRESHOLD_1_BYTE = 1 << 7; // 128
static constexpr uint32_t VARINT_THRESHOLD_2_BYTE = 1 << 14; // 16384
// Varint encoding thresholds — use namespace-level constants for 1/2 byte,
// class-level for 3/4 byte (only used within ProtoSize).
static constexpr uint32_t VARINT_THRESHOLD_1_BYTE = VARINT_MAX_1_BYTE;
static constexpr uint32_t VARINT_THRESHOLD_2_BYTE = VARINT_MAX_2_BYTE;
static constexpr uint32_t VARINT_THRESHOLD_3_BYTE = 1 << 21; // 2097152
static constexpr uint32_t VARINT_THRESHOLD_4_BYTE = 1 << 28; // 268435456
@@ -531,6 +636,17 @@ class ProtoSize {
return varint_wide(value);
return varint_slow(value);
}
/// Size of a varint expected to be 1-2 bytes (e.g. zigzag RSSI, small lengths).
/// Inlines both checks; falls back to slow path for 3+ bytes.
static constexpr inline uint32_t ESPHOME_ALWAYS_INLINE varint_short(uint32_t value) {
if (value < VARINT_THRESHOLD_1_BYTE) [[likely]]
return 1;
if (value < VARINT_THRESHOLD_2_BYTE) [[likely]]
return 2;
if (__builtin_is_constant_evaluated())
return varint_wide(value);
return varint_slow(value);
}
private:
// Slow path for varint >= 128, outlined to keep fast path small
@@ -645,10 +761,10 @@ class ProtoSize {
return value ? field_id_size + 4 : 0;
}
static constexpr uint32_t calc_sint32(uint32_t field_id_size, int32_t value) {
return value ? field_id_size + varint(encode_zigzag32(value)) : 0;
return value ? field_id_size + varint_short(encode_zigzag32(value)) : 0;
}
static constexpr inline uint32_t ESPHOME_ALWAYS_INLINE calc_sint32_force(uint32_t field_id_size, int32_t value) {
return field_id_size + varint(encode_zigzag32(value));
return field_id_size + varint_short(encode_zigzag32(value));
}
static constexpr uint32_t calc_int64(uint32_t field_id_size, int64_t value) {
return value ? field_id_size + varint(value) : 0;
@@ -691,28 +807,9 @@ class ProtoSize {
// Implementation of methods that depend on ProtoSize being fully defined
// Implementation of encode_packed_sint32 - must be after ProtoSize is defined
inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values) {
if (values.empty())
return;
// Calculate packed size
size_t packed_size = 0;
for (int value : values) {
packed_size += ProtoSize::varint(encode_zigzag32(value));
}
// Write tag (LENGTH_DELIMITED) + length + all zigzag-encoded values
this->encode_field_raw(field_id, WIRE_TYPE_LENGTH_DELIMITED);
this->encode_varint_raw(packed_size);
for (int value : values) {
this->encode_varint_raw(encode_zigzag32(value));
}
}
// Encode thunk — converts void* back to concrete type for direct encode() call
template<typename T> void proto_encode_msg(const void *msg, ProtoWriteBuffer &buf) {
static_cast<const T *>(msg)->encode(buf);
template<typename T> uint8_t *proto_encode_msg(const void *msg, ProtoWriteBuffer &buf PROTO_ENCODE_DEBUG_PARAM) {
return static_cast<const T *>(msg)->encode(buf PROTO_ENCODE_DEBUG_ARG);
}
// Thin template wrapper; delegates to non-template core in proto.cpp.
+66 -49
View File
@@ -232,29 +232,31 @@ class TypeInfo(ABC):
# eliminating the zero-check branch and encode_field_raw indirection.
# {value} is replaced with the actual field expression.
RAW_ENCODE_MAP: dict[str, str] = {
"encode_uint32": "buffer.encode_varint_raw({value});",
"encode_uint64": "buffer.encode_varint_raw_64({value});",
"encode_sint32": "buffer.encode_varint_raw(encode_zigzag32({value}));",
"encode_sint64": "buffer.encode_varint_raw_64(encode_zigzag64({value}));",
"encode_int64": "buffer.encode_varint_raw_64(static_cast<uint64_t>({value}));",
"encode_bool": "buffer.write_raw_byte({value} ? 0x01 : 0x00);",
"encode_uint32": "ProtoEncode::encode_varint_raw(pos, {value});",
"encode_uint64": "ProtoEncode::encode_varint_raw_64(pos, {value});",
"encode_sint32": "ProtoEncode::encode_varint_raw_short(pos, encode_zigzag32({value}));",
"encode_sint64": "ProtoEncode::encode_varint_raw_64(pos, encode_zigzag64({value}));",
"encode_int64": "ProtoEncode::encode_varint_raw_64(pos, static_cast<uint64_t>({value}));",
"encode_bool": "ProtoEncode::write_raw_byte(pos, {value} ? 0x01 : 0x00);",
}
# When max_value < 128, the varint is always 1 byte — use a direct byte write
RAW_ENCODE_SMALL_MAP: dict[str, str] = {
"encode_uint32": "buffer.write_raw_byte(static_cast<uint8_t>({value}));",
"encode_uint64": "buffer.write_raw_byte(static_cast<uint8_t>({value}));",
"encode_uint32": "ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({value}));",
"encode_uint64": "ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({value}));",
}
def _encode_with_precomputed_tag(self, value_expr: str) -> str | None:
"""Try to emit a precomputed-tag encode for a forced field.
"""Try to emit a precomputed-tag encode for a field.
For forced fields: emits raw tag + value unconditionally.
For non-forced fields with single-byte tag: emits inline zero-check
+ raw tag + value, avoiding an outlined function call.
Returns the raw encode string if the tag is a single byte and the
encode_func has a known raw equivalent, or None otherwise.
When max_value < 128, uses direct byte write instead of varint encoding.
"""
if not self.force:
return None
tag = self.calculate_tag()
if tag >= 128:
return None
@@ -263,10 +265,17 @@ class TypeInfo(ABC):
if max_val is not None and max_val < 128:
raw_expr = self.RAW_ENCODE_SMALL_MAP.get(self.encode_func)
if raw_expr is None:
# Only use RAW_ENCODE_MAP for forced fields or fields with max_value
if not self.force and max_val is None:
return None
raw_expr = self.RAW_ENCODE_MAP.get(self.encode_func)
if raw_expr is None:
return None
return f"buffer.write_raw_byte({tag});\n{raw_expr.format(value=value_expr)}"
body = f"ProtoEncode::write_raw_byte(pos, {tag});\n{raw_expr.format(value=value_expr)}"
if self.force:
return body
# Non-forced with max_value: inline zero-check + raw encode
return f"if ({value_expr}) {{\n {body}\n}}"
def _encode_bytes_with_precomputed_tag(
self, data_expr: str, len_expr: str, max_len: int | None = None
@@ -283,14 +292,14 @@ class TypeInfo(ABC):
return None
# When max_len < 128, length varint is always 1 byte
len_encode = (
f"buffer.write_raw_byte(static_cast<uint8_t>({len_expr}));"
f"ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({len_expr}));"
if max_len is not None and max_len < 128
else f"buffer.encode_varint_raw({len_expr});"
else f"ProtoEncode::encode_varint_raw(pos, {len_expr});"
)
return (
f"buffer.write_raw_byte({tag});\n"
f"ProtoEncode::write_raw_byte(pos, {tag});\n"
f"{len_encode}\n"
f"buffer.encode_raw({data_expr}, {len_expr});"
f"ProtoEncode::encode_raw(pos, {data_expr}, {len_expr});"
)
@property
@@ -298,8 +307,8 @@ class TypeInfo(ABC):
if result := self._encode_with_precomputed_tag(f"this->{self.field_name}"):
return result
if self.force:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, true);"
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name}, true);"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name});"
encode_func = None
@@ -657,10 +666,10 @@ class Fixed32Type(TypeInfo):
tag = self.calculate_tag()
if self.force and tag < 128:
# Emit combined tag+value write: precomputed tag + direct memcpy
return f"buffer.write_tag_and_fixed32({tag}, this->{self.field_name});"
return f"ProtoEncode::write_tag_and_fixed32(pos, {tag}, this->{self.field_name});"
if self.force:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, true);"
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name}, true);"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name});"
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
@@ -734,8 +743,8 @@ class StringType(TypeInfo):
):
return result
if self.force:
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_, true);"
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_);"
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}_ref_, true);"
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}_ref_);"
def dump(self, name):
# If name is 'it', this is a repeated field element - always use string
@@ -822,8 +831,8 @@ class MessageType(TypeInfo):
@property
def encode_content(self) -> str:
# encode_sub_message always encodes (uses backpatch), no force needed
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
# Sub-message encoding needs buffer for backpatch/sync
return f"ProtoEncode::{self.encode_func}(pos, buffer, {self.number}, this->{self.field_name});"
@property
def decode_length(self) -> str:
@@ -904,8 +913,8 @@ class BytesType(TypeInfo):
):
return result
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_, true);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
def dump(self, name: str) -> str:
ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)"
@@ -1015,8 +1024,8 @@ class PointerToBytesBufferType(PointerToBufferTypeBase):
):
return result
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len);"
@property
def decode_length_content(self) -> str | None:
@@ -1068,10 +1077,10 @@ class PointerToStringBufferType(PointerToBufferTypeBase):
):
return result
if self.force:
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}, true);"
return (
f"buffer.encode_string({self.number}, this->{self.field_name}, true);"
f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name});"
)
return f"buffer.encode_string({self.number}, this->{self.field_name});"
@property
def decode_length_content(self) -> str | None:
@@ -1240,8 +1249,8 @@ class FixedArrayBytesType(TypeInfo):
):
return result
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len);"
def dump(self, name: str) -> str:
return f"out.append(format_hex_pretty({name}, {name}_len));"
@@ -1323,8 +1332,8 @@ class EnumType(TypeInfo):
):
return result
if self.force:
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}), true);"
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}));"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, static_cast<uint32_t>(this->{self.field_name}), true);"
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, static_cast<uint32_t>(this->{self.field_name}));"
def dump(self, name: str) -> str:
return f"out.append_p(proto_enum_to_string<{self.cpp_type}>({name}));"
@@ -1487,11 +1496,13 @@ class FixedArrayRepeatedType(TypeInfo):
def _encode_element(self, element: str) -> str:
"""Helper to generate encode statement for a single element."""
if isinstance(self._ti, EnumType):
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
return f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, static_cast<uint32_t>({element}), true);"
# Repeated message elements use encode_sub_message (force=true is default)
if isinstance(self._ti, MessageType):
return f"buffer.encode_sub_message({self.number}, {element});"
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
return f"ProtoEncode::encode_sub_message(pos, buffer, {self.number}, {element});"
return (
f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, {element}, true);"
)
@property
def cpp_type(self) -> str:
@@ -1815,11 +1826,13 @@ class RepeatedTypeInfo(TypeInfo):
def _encode_element_call(self, element: str) -> str:
"""Helper to generate encode call for a single element."""
if isinstance(self._ti, EnumType):
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
return f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, static_cast<uint32_t>({element}), true);"
# Repeated message elements use encode_sub_message (force=true is default)
if isinstance(self._ti, MessageType):
return f"buffer.encode_sub_message({self.number}, {element});"
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
return f"ProtoEncode::encode_sub_message(pos, buffer, {self.number}, {element});"
return (
f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, {element}, true);"
)
@property
def encode_content(self) -> str:
@@ -1828,7 +1841,7 @@ class RepeatedTypeInfo(TypeInfo):
# Special handling for const char* elements (when container_no_template contains "const char")
if "const char" in self._container_no_template:
o = f"for (const char *it : *this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, strlen(it), true);\n"
o += f" ProtoEncode::{self._ti.encode_func}(pos, {self.number}, it, strlen(it), true);\n"
else:
o = f"for (const auto &it : *this->{self.field_name}) {{\n"
o += f" {self._encode_element_call('it')}\n"
@@ -2403,15 +2416,19 @@ def build_message_type(
# Only generate encode method if this message needs encoding and has fields
if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer &buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} }}\n"
else:
o += "\n"
o += indent("\n".join(encode)) + "\n"
# Add PROTO_ENCODE_DEBUG_ARG after pos in all proto_* calls
encode_debug = [
line.replace("(pos,", "(pos PROTO_ENCODE_DEBUG_ARG,") for line in encode
]
o = f"uint8_t *{desc.name}::encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const {{\n"
o += " uint8_t *__restrict__ pos = buffer.get_pos();\n"
o += indent("\n".join(encode_debug)) + "\n"
o += " return pos;\n"
o += "}\n"
cpp += o
prot = "void encode(ProtoWriteBuffer &buffer) const;"
prot = (
"uint8_t *encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const;"
)
public_content.append(prot)
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used