mirror of
https://github.com/esphome/esphome.git
synced 2026-05-31 17:06:40 +08:00
[api] Write protobuf encode output to pre-sized buffer directly (#14018)
This commit is contained in:
@@ -347,9 +347,7 @@ uint16_t APIConnection::encode_message_to_buffer(ProtoMessage &msg, uint8_t mess
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Calculate size
|
// Calculate size
|
||||||
ProtoSize size_calc;
|
uint32_t calculated_size = msg.calculated_size();
|
||||||
msg.calculate_size(size_calc);
|
|
||||||
uint32_t calculated_size = size_calc.get_size();
|
|
||||||
|
|
||||||
// Cache frame sizes to avoid repeated virtual calls
|
// Cache frame sizes to avoid repeated virtual calls
|
||||||
const uint8_t header_padding = conn->helper_->frame_header_padding();
|
const uint8_t header_padding = conn->helper_->frame_header_padding();
|
||||||
@@ -377,19 +375,14 @@ uint16_t APIConnection::encode_message_to_buffer(ProtoMessage &msg, uint8_t mess
|
|||||||
shared_buf.resize(current_size + footer_size + header_padding);
|
shared_buf.resize(current_size + footer_size + header_padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode directly into buffer
|
// Pre-resize buffer to include payload, then encode through raw pointer
|
||||||
size_t size_before_encode = shared_buf.size();
|
size_t write_start = shared_buf.size();
|
||||||
msg.encode({&shared_buf});
|
shared_buf.resize(write_start + calculated_size);
|
||||||
|
ProtoWriteBuffer buffer{&shared_buf, write_start};
|
||||||
|
msg.encode(buffer);
|
||||||
|
|
||||||
// Calculate actual encoded size (not including header that was already added)
|
// Return total size (header + payload + footer)
|
||||||
size_t actual_payload_size = shared_buf.size() - size_before_encode;
|
return static_cast<uint16_t>(header_padding + calculated_size + footer_size);
|
||||||
|
|
||||||
// Return actual total size (header + actual payload + footer)
|
|
||||||
size_t actual_total_size = header_padding + actual_payload_size + footer_size;
|
|
||||||
|
|
||||||
// Verify that calculate_size() returned the correct value
|
|
||||||
assert(calculated_size == actual_payload_size);
|
|
||||||
return static_cast<uint16_t>(actual_total_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_BINARY_SENSOR
|
#ifdef USE_BINARY_SENSOR
|
||||||
@@ -1854,12 +1847,14 @@ bool APIConnection::try_to_clear_buffer(bool log_out_of_space) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool APIConnection::send_message_impl(const ProtoMessage &msg, uint8_t message_type) {
|
bool APIConnection::send_message_impl(const ProtoMessage &msg, uint8_t message_type) {
|
||||||
ProtoSize size;
|
uint32_t payload_size = msg.calculated_size();
|
||||||
msg.calculate_size(size);
|
|
||||||
std::vector<uint8_t> &shared_buf = this->parent_->get_shared_buffer_ref();
|
std::vector<uint8_t> &shared_buf = this->parent_->get_shared_buffer_ref();
|
||||||
this->prepare_first_message_buffer(shared_buf, size.get_size());
|
this->prepare_first_message_buffer(shared_buf, payload_size);
|
||||||
msg.encode({&shared_buf});
|
size_t write_start = shared_buf.size();
|
||||||
return this->send_buffer({&shared_buf}, message_type);
|
shared_buf.resize(write_start + payload_size);
|
||||||
|
ProtoWriteBuffer buffer{&shared_buf, write_start};
|
||||||
|
msg.encode(buffer);
|
||||||
|
return this->send_buffer(ProtoWriteBuffer{&shared_buf}, message_type);
|
||||||
}
|
}
|
||||||
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) {
|
bool APIConnection::send_buffer(ProtoWriteBuffer buffer, uint8_t message_type) {
|
||||||
const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE);
|
const bool is_log_message = (message_type == SubscribeLogsResponse::MESSAGE_TYPE);
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -70,6 +70,21 @@ uint32_t ProtoDecodableMessage::count_repeated_field(const uint8_t *buffer, size
|
|||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef ESPHOME_DEBUG_API
|
||||||
|
void ProtoWriteBuffer::debug_check_bounds_(size_t bytes, const char *caller) {
|
||||||
|
if (this->pos_ + bytes > this->buffer_->data() + this->buffer_->size()) {
|
||||||
|
ESP_LOGE(TAG, "ProtoWriteBuffer bounds check failed in %s: bytes=%zu offset=%td buf_size=%zu", caller, bytes,
|
||||||
|
this->pos_ - this->buffer_->data(), this->buffer_->size());
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
void ProtoWriteBuffer::debug_check_encode_size_(uint32_t field_id, uint32_t expected, ptrdiff_t actual) {
|
||||||
|
ESP_LOGE(TAG, "encode_message: size mismatch for field %" PRIu32 ": calculated=%" PRIu32 " actual=%td", field_id,
|
||||||
|
expected, actual);
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
void ProtoDecodableMessage::decode(const uint8_t *buffer, size_t length) {
|
||||||
const uint8_t *ptr = buffer;
|
const uint8_t *ptr = buffer;
|
||||||
const uint8_t *end = buffer + length;
|
const uint8_t *end = buffer + length;
|
||||||
|
|||||||
@@ -217,21 +217,26 @@ class Proto32Bit {
|
|||||||
|
|
||||||
class ProtoWriteBuffer {
|
class ProtoWriteBuffer {
|
||||||
public:
|
public:
|
||||||
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {}
|
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer), pos_(buffer->data() + buffer->size()) {}
|
||||||
void write(uint8_t value) { this->buffer_->push_back(value); }
|
ProtoWriteBuffer(std::vector<uint8_t> *buffer, size_t write_pos)
|
||||||
|
: buffer_(buffer), pos_(buffer->data() + write_pos) {}
|
||||||
void encode_varint_raw(uint32_t value) {
|
void encode_varint_raw(uint32_t value) {
|
||||||
while (value > 0x7F) {
|
while (value > 0x7F) {
|
||||||
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80));
|
this->debug_check_bounds_(1);
|
||||||
|
*this->pos_++ = static_cast<uint8_t>(value | 0x80);
|
||||||
value >>= 7;
|
value >>= 7;
|
||||||
}
|
}
|
||||||
this->buffer_->push_back(static_cast<uint8_t>(value));
|
this->debug_check_bounds_(1);
|
||||||
|
*this->pos_++ = static_cast<uint8_t>(value);
|
||||||
}
|
}
|
||||||
void encode_varint_raw_64(uint64_t value) {
|
void encode_varint_raw_64(uint64_t value) {
|
||||||
while (value > 0x7F) {
|
while (value > 0x7F) {
|
||||||
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80));
|
this->debug_check_bounds_(1);
|
||||||
|
*this->pos_++ = static_cast<uint8_t>(value | 0x80);
|
||||||
value >>= 7;
|
value >>= 7;
|
||||||
}
|
}
|
||||||
this->buffer_->push_back(static_cast<uint8_t>(value));
|
this->debug_check_bounds_(1);
|
||||||
|
*this->pos_++ = static_cast<uint8_t>(value);
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Encode a field key (tag/wire type combination).
|
* Encode a field key (tag/wire type combination).
|
||||||
@@ -245,23 +250,18 @@ class ProtoWriteBuffer {
|
|||||||
*
|
*
|
||||||
* Following https://protobuf.dev/programming-guides/encoding/#structure
|
* Following https://protobuf.dev/programming-guides/encoding/#structure
|
||||||
*/
|
*/
|
||||||
void encode_field_raw(uint32_t field_id, uint32_t type) {
|
void encode_field_raw(uint32_t field_id, uint32_t type) { this->encode_varint_raw((field_id << 3) | type); }
|
||||||
uint32_t val = (field_id << 3) | (type & WIRE_TYPE_MASK);
|
|
||||||
this->encode_varint_raw(val);
|
|
||||||
}
|
|
||||||
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
|
void encode_string(uint32_t field_id, const char *string, size_t len, bool force = false) {
|
||||||
if (len == 0 && !force)
|
if (len == 0 && !force)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
this->encode_field_raw(field_id, 2); // type 2: Length-delimited string
|
this->encode_field_raw(field_id, 2); // type 2: Length-delimited string
|
||||||
this->encode_varint_raw(len);
|
this->encode_varint_raw(len);
|
||||||
|
// Direct memcpy into pre-sized buffer — avoids push_back() per-byte capacity checks
|
||||||
// Using resize + memcpy instead of insert provides significant performance improvement:
|
// and vector::insert() iterator overhead. ~10-11x faster for 16-32 byte strings.
|
||||||
// ~10-11x faster for 16-32 byte strings, ~3x faster for 64-byte strings
|
this->debug_check_bounds_(len);
|
||||||
// as it avoids iterator checks and potential element moves that insert performs
|
std::memcpy(this->pos_, string, len);
|
||||||
size_t old_size = this->buffer_->size();
|
this->pos_ += len;
|
||||||
this->buffer_->resize(old_size + len);
|
|
||||||
std::memcpy(this->buffer_->data() + old_size, string, len);
|
|
||||||
}
|
}
|
||||||
void encode_string(uint32_t field_id, const std::string &value, bool force = false) {
|
void encode_string(uint32_t field_id, const std::string &value, bool force = false) {
|
||||||
this->encode_string(field_id, value.data(), value.size(), force);
|
this->encode_string(field_id, value.data(), value.size(), force);
|
||||||
@@ -288,17 +288,26 @@ class ProtoWriteBuffer {
|
|||||||
if (!value && !force)
|
if (!value && !force)
|
||||||
return;
|
return;
|
||||||
this->encode_field_raw(field_id, 0); // type 0: Varint - bool
|
this->encode_field_raw(field_id, 0); // type 0: Varint - bool
|
||||||
this->buffer_->push_back(value ? 0x01 : 0x00);
|
this->debug_check_bounds_(1);
|
||||||
|
*this->pos_++ = value ? 0x01 : 0x00;
|
||||||
}
|
}
|
||||||
void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
|
// noinline: 51 call sites; inlining causes net code growth vs a single out-of-line copy
|
||||||
|
__attribute__((noinline)) void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
|
||||||
if (value == 0 && !force)
|
if (value == 0 && !force)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32
|
this->encode_field_raw(field_id, 5); // type 5: 32-bit fixed32
|
||||||
this->write((value >> 0) & 0xFF);
|
this->debug_check_bounds_(4);
|
||||||
this->write((value >> 8) & 0xFF);
|
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||||
this->write((value >> 16) & 0xFF);
|
// Protobuf fixed32 is little-endian, so direct copy works
|
||||||
this->write((value >> 24) & 0xFF);
|
std::memcpy(this->pos_, &value, 4);
|
||||||
|
this->pos_ += 4;
|
||||||
|
#else
|
||||||
|
*this->pos_++ = (value >> 0) & 0xFF;
|
||||||
|
*this->pos_++ = (value >> 8) & 0xFF;
|
||||||
|
*this->pos_++ = (value >> 16) & 0xFF;
|
||||||
|
*this->pos_++ = (value >> 24) & 0xFF;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
|
// NOTE: Wire type 1 (64-bit fixed: double, fixed64, sfixed64) is intentionally
|
||||||
// not supported to reduce overhead on embedded systems. All ESPHome devices are
|
// not supported to reduce overhead on embedded systems. All ESPHome devices are
|
||||||
@@ -334,11 +343,20 @@ class ProtoWriteBuffer {
|
|||||||
}
|
}
|
||||||
/// Encode a packed repeated sint32 field (zero-copy from vector)
|
/// Encode a packed repeated sint32 field (zero-copy from vector)
|
||||||
void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values);
|
void encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values);
|
||||||
void encode_message(uint32_t field_id, const ProtoMessage &value);
|
/// Encode a nested message field (force=true for repeated, false for singular)
|
||||||
|
void encode_message(uint32_t field_id, const ProtoMessage &value, bool force = true);
|
||||||
std::vector<uint8_t> *get_buffer() const { return buffer_; }
|
std::vector<uint8_t> *get_buffer() const { return buffer_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
#ifdef ESPHOME_DEBUG_API
|
||||||
|
void debug_check_bounds_(size_t bytes, const char *caller = __builtin_FUNCTION());
|
||||||
|
void debug_check_encode_size_(uint32_t field_id, uint32_t expected, ptrdiff_t actual);
|
||||||
|
#else
|
||||||
|
void debug_check_bounds_([[maybe_unused]] size_t bytes) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
std::vector<uint8_t> *buffer_;
|
std::vector<uint8_t> *buffer_;
|
||||||
|
uint8_t *pos_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
@@ -416,9 +434,11 @@ class ProtoMessage {
|
|||||||
public:
|
public:
|
||||||
virtual ~ProtoMessage() = default;
|
virtual ~ProtoMessage() = default;
|
||||||
// Default implementation for messages with no fields
|
// Default implementation for messages with no fields
|
||||||
virtual void encode(ProtoWriteBuffer buffer) const {}
|
virtual void encode(ProtoWriteBuffer &buffer) const {}
|
||||||
// Default implementation for messages with no fields
|
// Default implementation for messages with no fields
|
||||||
virtual void calculate_size(ProtoSize &size) const {}
|
virtual void calculate_size(ProtoSize &size) const {}
|
||||||
|
// Convenience: calculate and return size directly (defined after ProtoSize)
|
||||||
|
uint32_t calculated_size() const;
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
virtual const char *dump_to(DumpBuffer &out) const = 0;
|
virtual const char *dump_to(DumpBuffer &out) const = 0;
|
||||||
virtual const char *message_name() const { return "unknown"; }
|
virtual const char *message_name() const { return "unknown"; }
|
||||||
@@ -877,6 +897,14 @@ class ProtoSize {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Implementation of methods that depend on ProtoSize being fully defined
|
||||||
|
|
||||||
|
inline uint32_t ProtoMessage::calculated_size() const {
|
||||||
|
ProtoSize size;
|
||||||
|
this->calculate_size(size);
|
||||||
|
return size.get_size();
|
||||||
|
}
|
||||||
|
|
||||||
// Implementation of encode_packed_sint32 - must be after ProtoSize is defined
|
// Implementation of encode_packed_sint32 - must be after ProtoSize is defined
|
||||||
inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values) {
|
inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std::vector<int32_t> &values) {
|
||||||
if (values.empty())
|
if (values.empty())
|
||||||
@@ -897,30 +925,30 @@ inline void ProtoWriteBuffer::encode_packed_sint32(uint32_t field_id, const std:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Implementation of encode_message - must be after ProtoMessage is defined
|
// Implementation of encode_message - must be after ProtoMessage is defined
|
||||||
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value) {
|
inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessage &value, bool force) {
|
||||||
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
|
|
||||||
|
|
||||||
// Calculate the message size first
|
// Calculate the message size first
|
||||||
ProtoSize msg_size;
|
ProtoSize msg_size;
|
||||||
value.calculate_size(msg_size);
|
value.calculate_size(msg_size);
|
||||||
uint32_t msg_length_bytes = msg_size.get_size();
|
uint32_t msg_length_bytes = msg_size.get_size();
|
||||||
|
|
||||||
// Calculate how many bytes the length varint needs
|
// Skip empty singular messages (matches add_message_field which skips when nested_size == 0)
|
||||||
uint32_t varint_length_bytes = ProtoSize::varint(msg_length_bytes);
|
// Repeated messages (force=true) are always encoded since an empty item is meaningful
|
||||||
|
if (msg_length_bytes == 0 && !force)
|
||||||
|
return;
|
||||||
|
|
||||||
// Reserve exact space for the length varint
|
this->encode_field_raw(field_id, 2); // type 2: Length-delimited message
|
||||||
size_t begin = this->buffer_->size();
|
|
||||||
this->buffer_->resize(this->buffer_->size() + varint_length_bytes);
|
|
||||||
|
|
||||||
// Write the length varint directly
|
// Write the length varint directly through pos_
|
||||||
encode_varint_to_buffer(msg_length_bytes, this->buffer_->data() + begin);
|
this->encode_varint_raw(msg_length_bytes);
|
||||||
|
|
||||||
// Now encode the message content - it will append to the buffer
|
|
||||||
value.encode(*this);
|
|
||||||
|
|
||||||
|
// Encode nested message - pos_ advances directly through the reference
|
||||||
#ifdef ESPHOME_DEBUG_API
|
#ifdef ESPHOME_DEBUG_API
|
||||||
// Verify that the encoded size matches what we calculated
|
uint8_t *start = this->pos_;
|
||||||
assert(this->buffer_->size() == begin + varint_length_bytes + msg_length_bytes);
|
value.encode(*this);
|
||||||
|
if (static_cast<uint32_t>(this->pos_ - start) != msg_length_bytes)
|
||||||
|
this->debug_check_encode_size_(field_id, msg_length_bytes, this->pos_ - start);
|
||||||
|
#else
|
||||||
|
value.encode(*this);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -689,6 +689,14 @@ class MessageType(TypeInfo):
|
|||||||
def encode_func(self) -> str:
|
def encode_func(self) -> str:
|
||||||
return "encode_message"
|
return "encode_message"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encode_content(self) -> str:
|
||||||
|
# Singular message fields pass force=false (skip empty messages)
|
||||||
|
# The default for encode_nested_message is force=true (for repeated fields)
|
||||||
|
return (
|
||||||
|
f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, false);"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def decode_length(self) -> str:
|
def decode_length(self) -> str:
|
||||||
# Override to return None for message types because we can't use template-based
|
# Override to return None for message types because we can't use template-based
|
||||||
@@ -2186,7 +2194,7 @@ def build_message_type(
|
|||||||
|
|
||||||
# Only generate encode method if this message needs encoding and has fields
|
# Only generate encode method if this message needs encoding and has fields
|
||||||
if needs_encode and encode:
|
if needs_encode and encode:
|
||||||
o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{"
|
o = f"void {desc.name}::encode(ProtoWriteBuffer &buffer) const {{"
|
||||||
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
|
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
|
||||||
o += f" {encode[0]} }}\n"
|
o += f" {encode[0]} }}\n"
|
||||||
else:
|
else:
|
||||||
@@ -2194,7 +2202,7 @@ def build_message_type(
|
|||||||
o += indent("\n".join(encode)) + "\n"
|
o += indent("\n".join(encode)) + "\n"
|
||||||
o += "}\n"
|
o += "}\n"
|
||||||
cpp += o
|
cpp += o
|
||||||
prot = "void encode(ProtoWriteBuffer buffer) const override;"
|
prot = "void encode(ProtoWriteBuffer &buffer) const override;"
|
||||||
public_content.append(prot)
|
public_content.append(prot)
|
||||||
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user