[api] Write protobuf encode output to pre-sized buffer directly (#14018)

This commit is contained in:
J. Nick Koston
2026-02-20 21:39:18 -06:00
committed by GitHub
parent f8f98bf428
commit f77da803c9
6 changed files with 286 additions and 240 deletions
+15 -20
View File
@@ -347,9 +347,7 @@ uint16_t APIConnection::encode_message_to_buffer(ProtoMessage &msg, uint8_t mess
#endif #endif
// Calculate size // Calculate size
ProtoSize size_calc; uint32_t calculated_size = msg.calculated_size();
msg.calculate_size(size_calc);
uint32_t calculated_size = size_calc.get_size();
// Cache frame sizes to avoid repeated virtual calls // Cache frame sizes to avoid repeated virtual calls
const uint8_t header_padding = conn->helper_->frame_header_padding(); const uint8_t header_padding = conn->helper_->frame_header_padding();
@@ -377,19 +375,14 @@ uint16_t APIConnection::encode_message_to_buffer(ProtoMessage &msg, uint8_t mess
shared_buf.resize(current_size + footer_size + header_padding); shared_buf.resize(current_size + footer_size + header_padding);
} }
// Encode directly into buffer // Pre-resize buffer to include payload, then encode through raw pointer
size_t size_before_encode = shared_buf.size(); size_t write_start = shared_buf.size();
msg.encode({&shared_buf}); shared_buf.resize(write_start + calculated_size);
ProtoWriteBuffer buffer{&shared_buf, write_start};
msg.encode(buffer);
// Calculate actual encoded size (not including header that was already added) // Return total size (header + payload + footer)
size_t actual_payload_size = shared_buf.size() - size_before_encode; return static_cast<uint16_t>(header_padding + calculated_size + footer_size);
// Return actual total size (header + actual payload + footer)
size_t actual_total_size = header_padding + actual_payload_size + footer_size;
// Verify that calculate_size() returned the correct value
assert(calculated_size == actual_payload_size);
return static_cast<uint16_t>(actual_total_size);
} }
#ifdef USE_BINARY_SENSOR #ifdef USE_BINARY_SENSOR
@@ -1854,12 +1847,14 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) {
return false; return false;
} }
bool APIConnection::send_message_impl(const ProtoMessage &msg, uint8_t message_type) { bool APIConnection::send_message_impl(const ProtoMessage &msg, uint8_t message_type) {
ProtoSize size; uint32_t payload_size = msg.calculated_size();
msg.calculate_size(size);
std::vector<uint8_t> &shared_buf = this->parent_->get_shared_buffer_ref(); std::vector<uint8_t> &shared_buf = this->parent_->get_shared_buffer_ref();
this->prepare_first_message_buffer(shared_buf, size.get_size()); this->prepare_first_message_buffer(shared_buf, payload_size);
msg.encode({&shared_buf}); size_t write_start = shared_buf.size();
return this->send_buffer({&shared_buf}, message_type); shared_buf.resize(write_start + payload_size);
ProtoWriteBuffer buffer{&shared_buf, write_start};
msg.encode(buffer);
return this->send_buffer(ProtoWriteBuffer{&shared_buf}, message_type);
} }
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) { bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) {
const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE); const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+15
View File
@@ -70,6 +70,21 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size
return count; return count;
} }
#ifdef ESPHOME_DEBUG_API
void ProtoWriteBuffer::debug_check_bounds_(size_t bytes, const char *caller) {
if (this->pos_ + bytes > this->buffer_->data() + this->buffer_->size()) {
ESP_LOGE(TAG, "ProtoWriteBuffer bounds check failed in %s: bytes=%zu offset=%td buf_size=%zu", caller, bytes,
this->pos_ - this->buffer_->data(), this->buffer_->size());
abort();
}
}
void ProtoWriteBuffer::debug_check_encode_size_(uint32_t field_id, uint32_t expected, ptrdiff_t actual) {
ESP_LOGE(TAG, "encode_message: size mismatch for field %" PRIu32 ": calculated=%" PRIu32 " actual=%td", field_id,
expected, actual);
abort();
}
#endif
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) { void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
const uint8_t *ptr = buffer; const uint8_t *ptr = buffer;
const uint8_t *end = buffer + length; const uint8_t *end = buffer + length;
+68 -40
View File
@@ -217,21 +217,26 @@ class Proto32Bit {
class ProtoWriteBuffer { class ProtoWriteBuffer {
public: public:
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {} ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer), pos_(buffer->data() + buffer->size()) {}
void write(uint8_t value) { this->buffer_->push_back(value); } ProtoWriteBuffer(std::vector<uint8_t> *buffer, size_t write_pos)
: buffer_(buffer), pos_(buffer->data() + write_pos) {}
void encode_varint_raw(uint32_t value) { void encode_varint_raw(uint32_t value) {
while (value > 0x7F) { while (value > 0x7F) {
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80)); this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value | 0x80);
value >>= 7; value >>= 7;
} }
this->buffer_->push_back(static_cast<uint8_t>(value)); this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value);
} }
void encode_varint_raw_64(uint64_t value) { void encode_varint_raw_64(uint64_t value) {
while (value > 0x7F) { while (value > 0x7F) {
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80)); this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value | 0x80);
value >>= 7; value >>= 7;
} }
this->buffer_->push_back(static_cast<uint8_t>(value)); this->debug_check_bounds_(1);
*this->pos_++ = static_cast<uint8_t>(value);
} }
/** /**
* Encode a field key (tag/wire type combination). * Encode a field key (tag/wire type combination).
@@ -245,23 +250,18 @@ class ProtoWriteBuffer {
* *
* Following https://protobuf.dev/programming-guides/encoding/#structure * Following https://protobuf.dev/programming-guides/encoding/#structure
*/ */
void encode_field_raw(uint32_t field_id, uint32_t type) { void encode_field_raw(uint32_t field_id, uint32_t type) { this->encode_varint_raw((field_id << 3) | type); }
uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK);
this->encode_varint_raw(val);
}
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) { void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
if (len == 0 && !force) if (len == 0 && !force)
return; return;
this->encode_field_raw(field_id, 2); // type 2: Length-delimited string this->encode_field_raw(field_id, 2); // type 2: Length-delimited string
this->encode_varint_raw(len); this->encode_varint_raw(len);
// Direct memcpy into pre-sized buffer — avoids push_back() per-byte capacity checks
// Using resize + memcpy instead of insert provides significant performance improvement: // and vector::insert() iterator overhead. ~10-11x faster for 16-32 byte strings.
// ~10-11x faster for 16-32 byte strings, ~3x faster for 64-byte strings this->debug_check_bounds_(len);
// as it avoids iterator checks and potential element moves that insert performs std::memcpy(this->pos_, string, len);
size_t old_size = this->buffer_->size(); this->pos_ += len;
this->buffer_->resize(old_size + len);
std::memcpy(this->buffer_->data() + old_size, string, len);
} }
void encode_string(uint32_t field_id, const std::string &value, bool force = false) { void encode_string(uint32_t field_id, const std::string &value, bool force = false) {
this->encode_string(field_id, value.data(), value.size(), force); this->encode_string(field_id, value.data(), value.size(), force);
@@ -288,17 +288,26 @@ class ProtoWriteBuffer {
if (!value && !force) if (!value && !force)
return; return;
this->encode_field_raw(field_id, 0); // type 0: Varint - bool this->encode_field_raw(field_id, 0); // type 0: Varint - bool
this->buffer_->push_back(value ? 0x01 : 0x00); this->debug_check_bounds_(1);
*this->pos_++ = value ? 0x01 : 0x00;
} }
void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) { // noinline: 51 call sites; inlining causes net code growth vs a single out-of-line copy
__attribute__((noinline)) void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
if (value == 0 && !force) if (value == 0 && !force)
return; return;
this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32 this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32
this->write((value >> 0) & 0xFF); this->debug_check_bounds_(4);
this->write((value >> 8) & 0xFF); #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
this->write((value >> 16) & 0xFF); // Protobuf fixed32 is little-endian, so direct copy works
this->write((value >> 24) & 0xFF); std::memcpy(this->pos_, &value, 4);
this->pos_ += 4;
#else
*this->pos_++ = (value >> 0) & 0xFF;
*this->pos_++ = (value >> 8) & 0xFF;
*this->pos_++ = (value >> 16) & 0xFF;
*this->pos_++ = (value >> 24) & 0xFF;
#endif
} }
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally // NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
// not supported to reduce overhead on embedded systems. All ESPHome devices are // not supported to reduce overhead on embedded systems. All ESPHome devices are
@@ -334,11 +343,20 @@ class ProtoWriteBuffer {
} }
/// Encode a packed repeated sint32 field (zero-copy from vector) /// Encode a packed repeated sint32 field (zero-copy from vector)
void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values); void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values);
void encode_message(uint32_t field_id, const ProtoMessage &value); /// Encode a nested message field (force=true for repeated, false for singular)
void encode_message(uint32_t field_id, const ProtoMessage &value, bool force = true);
std::vector<uint8_t> *get_buffer() const { return buffer_; } std::vector<uint8_t> *get_buffer() const { return buffer_; }
protected: protected:
#ifdef ESPHOME_DEBUG_API
void debug_check_bounds_(size_t bytes, const char *caller = __builtin_FUNCTION());
void debug_check_encode_size_(uint32_t field_id, uint32_t expected, ptrdiff_t actual);
#else
void debug_check_bounds_([[maybe_unused]] size_t bytes) {}
#endif
std::vector<uint8_t> *buffer_; std::vector<uint8_t> *buffer_;
uint8_t *pos_;
}; };
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
@@ -416,9 +434,11 @@ class ProtoMessage {
public: public:
virtual ~ProtoMessage() = default; virtual ~ProtoMessage() = default;
// Default implementation for messages with no fields // Default implementation for messages with no fields
virtual void encode(ProtoWriteBuffer buffer) const {} virtual void encode(ProtoWriteBuffer &buffer) const {}
// Default implementation for messages with no fields // Default implementation for messages with no fields
virtual void calculate_size(ProtoSize &size) const {} virtual void calculate_size(ProtoSize &size) const {}
// Convenience: calculate and return size directly (defined after ProtoSize)
uint32_t calculated_size() const;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
virtual const char *dump_to(DumpBuffer &out) const = 0; virtual const char *dump_to(DumpBuffer &out) const = 0;
virtual const char *message_name() const { return "unknown"; } virtual const char *message_name() const { return "unknown"; }
@@ -877,6 +897,14 @@ class ProtoSize {
} }
}; };
// Implementation of methods that depend on ProtoSize being fully defined
inline uint32_t ProtoMessage::calculated_size() const {
ProtoSize size;
this->calculate_size(size);
return size.get_size();
}
// Implementation of encode_packed_sint32 - must be after ProtoSize is defined // Implementation of encode_packed_sint32 - must be after ProtoSize is defined
inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values) { inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values) {
if (values.empty()) if (values.empty())
@@ -897,30 +925,30 @@ inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std:
} }
// Implementation of encode_message - must be after ProtoMessage is defined // Implementation of encode_message - must be after ProtoMessage is defined
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value) { inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value, bool force) {
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
// Calculate the message size first // Calculate the message size first
ProtoSize msg_size; ProtoSize msg_size;
value.calculate_size(msg_size); value.calculate_size(msg_size);
uint32_t msg_length_bytes = msg_size.get_size(); uint32_t msg_length_bytes = msg_size.get_size();
// Calculate how many bytes the length varint needs // Skip empty singular messages (matches add_message_field which skips when nested_size == 0)
uint32_t varint_length_bytes = ProtoSize::varint(msg_length_bytes); // Repeated messages (force=true) are always encoded since an empty item is meaningful
if (msg_length_bytes == 0 && !force)
return;
// Reserve exact space for the length varint this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
size_t begin = this->buffer_->size();
this->buffer_->resize(this->buffer_->size() + varint_length_bytes);
// Write the length varint directly // Write the length varint directly through pos_
encode_varint_to_buffer(msg_length_bytes, this->buffer_->data() + begin); this->encode_varint_raw(msg_length_bytes);
// Now encode the message content - it will append to the buffer
value.encode(*this);
// Encode nested message - pos_ advances directly through the reference
#ifdef ESPHOME_DEBUG_API #ifdef ESPHOME_DEBUG_API
// Verify that the encoded size matches what we calculated uint8_t *start = this->pos_;
assert(this->buffer_->size() == begin + varint_length_bytes + msg_length_bytes); value.encode(*this);
if (static_cast<uint32_t>(this->pos_ - start) != msg_length_bytes)
this->debug_check_encode_size_(field_id, msg_length_bytes, this->pos_ - start);
#else
value.encode(*this);
#endif #endif
} }
+10 -2
View File
@@ -689,6 +689,14 @@ class MessageType(TypeInfo):
def encode_func(self) -> str: def encode_func(self) -> str:
return "encode_message" return "encode_message"
@property
def encode_content(self) -> str:
# Singular message fields pass force=false (skip empty messages)
# The default for encode_nested_message is force=true (for repeated fields)
return (
f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, false);"
)
@property @property
def decode_length(self) -> str: def decode_length(self) -> str:
# Override to return None for message types because we can't use template-based # Override to return None for message types because we can't use template-based
@@ -2186,7 +2194,7 @@ def build_message_type(
# Only generate encode method if this message needs encoding and has fields # Only generate encode method if this message needs encoding and has fields
if needs_encode and encode: if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{" o = f"void {desc.name}::encode(ProtoWriteBuffer &buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120: if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} }}\n" o += f" {encode[0]} }}\n"
else: else:
@@ -2194,7 +2202,7 @@ def build_message_type(
o += indent("\n".join(encode)) + "\n" o += indent("\n".join(encode)) + "\n"
o += "}\n" o += "}\n"
cpp += o cpp += o
prot = "void encode(ProtoWriteBuffer buffer) const override;" prot = "void encode(ProtoWriteBuffer &buffer) const override;"
public_content.append(prot) public_content.append(prot)
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used # If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used