[api] Fix debug asserts in production code, encode_bool bug, and reduce flash overhead (#13936)

Co-authored-by: Jonathan Swoboda <154711427+swoboda1337@users.noreply.github.com>
This commit is contained in:
J. Nick Koston
2026-02-11 13:57:08 -06:00
committed by GitHub
parent c9c125aa8d
commit 483b7693e1
4 changed files with 40 additions and 62 deletions
@@ -295,9 +295,8 @@ APIError APIPlaintextFrameHelper::write_protobuf_messages(ProtoWriteBuffer buffe
buf_start[header_offset] = 0x00; // indicator buf_start[header_offset] = 0x00; // indicator
// Encode varints directly into buffer // Encode varints directly into buffer
ProtoVarInt(msg.payload_size).encode_to_buffer_unchecked(buf_start + header_offset + 1, size_varint_len); encode_varint_to_buffer(msg.payload_size, buf_start + header_offset + 1);
ProtoVarInt(msg.message_type) encode_varint_to_buffer(msg.message_type, buf_start + header_offset + 1 + size_varint_len);
.encode_to_buffer_unchecked(buf_start + header_offset + 1 + size_varint_len, type_varint_len);
// Add iovec for this message (header + payload) // Add iovec for this message (header + payload)
size_t msg_len = static_cast<size_t>(total_header_len + msg.payload_size); size_t msg_len = static_cast<size_t>(total_header_len + msg.payload_size);
+36 -59
View File
@@ -57,6 +57,16 @@ inline uint16_t count_packed_varints(const uint8_t *data, size_t len) {
return count; return count;
} }
/// Encode a varint directly into a pre-allocated buffer.
/// Caller must ensure buffer has space (use ProtoSize::varint() to calculate).
inline void encode_varint_to_buffer(uint32_t val, uint8_t *buffer) {
while (val > 0x7F) {
*buffer++ = static_cast<uint8_t>(val | 0x80);
val >>= 7;
}
*buffer = static_cast<uint8_t>(val);
}
/* /*
* StringRef Ownership Model for API Protocol Messages * StringRef Ownership Model for API Protocol Messages
* =================================================== * ===================================================
@@ -93,17 +103,17 @@ class ProtoVarInt {
ProtoVarInt() : value_(0) {} ProtoVarInt() : value_(0) {}
explicit ProtoVarInt(uint64_t value) : value_(value) {} explicit ProtoVarInt(uint64_t value) : value_(value) {}
/// Parse a varint from buffer. consumed must be a valid pointer (not null).
static optional<ProtoVarInt> parse(const uint8_t *buffer, uint32_t len, uint32_t *consumed) { static optional<ProtoVarInt> parse(const uint8_t *buffer, uint32_t len, uint32_t *consumed) {
if (len == 0) { #ifdef ESPHOME_DEBUG_API
if (consumed != nullptr) assert(consumed != nullptr);
*consumed = 0; #endif
if (len == 0)
return {}; return {};
}
// Most common case: single-byte varint (values 0-127) // Most common case: single-byte varint (values 0-127)
if ((buffer[0] & 0x80) == 0) { if ((buffer[0] & 0x80) == 0) {
if (consumed != nullptr) *consumed = 1;
*consumed = 1;
return ProtoVarInt(buffer[0]); return ProtoVarInt(buffer[0]);
} }
@@ -122,14 +132,11 @@ class ProtoVarInt {
result |= uint64_t(val & 0x7F) << uint64_t(bitpos); result |= uint64_t(val & 0x7F) << uint64_t(bitpos);
bitpos += 7; bitpos += 7;
if ((val & 0x80) == 0) { if ((val & 0x80) == 0) {
if (consumed != nullptr) *consumed = i + 1;
*consumed = i + 1;
return ProtoVarInt(result); return ProtoVarInt(result);
} }
} }
if (consumed != nullptr)
*consumed = 0;
return {}; // Incomplete or invalid varint return {}; // Incomplete or invalid varint
} }
@@ -153,50 +160,6 @@ class ProtoVarInt {
// with ZigZag encoding // with ZigZag encoding
return decode_zigzag64(this->value_); return decode_zigzag64(this->value_);
} }
/**
* Encode the varint value to a pre-allocated buffer without bounds checking.
*
* @param buffer The pre-allocated buffer to write the encoded varint to
* @param len The size of the buffer in bytes
*
* @note The caller is responsible for ensuring the buffer is large enough
* to hold the encoded value. Use ProtoSize::varint() to calculate
* the exact size needed before calling this method.
* @note No bounds checking is performed for performance reasons.
*/
void encode_to_buffer_unchecked(uint8_t *buffer, size_t len) {
uint64_t val = this->value_;
if (val <= 0x7F) {
buffer[0] = val;
return;
}
size_t i = 0;
while (val && i < len) {
uint8_t temp = val & 0x7F;
val >>= 7;
if (val) {
buffer[i++] = temp | 0x80;
} else {
buffer[i++] = temp;
}
}
}
void encode(std::vector<uint8_t> &out) {
uint64_t val = this->value_;
if (val <= 0x7F) {
out.push_back(val);
return;
}
while (val) {
uint8_t temp = val & 0x7F;
val >>= 7;
if (val) {
out.push_back(temp | 0x80);
} else {
out.push_back(temp);
}
}
}
protected: protected:
uint64_t value_; uint64_t value_;
@@ -256,8 +219,20 @@ class ProtoWriteBuffer {
public: public:
ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {} ProtoWriteBuffer(std::vector<uint8_t> *buffer) : buffer_(buffer) {}
void write(uint8_t value) { this->buffer_->push_back(value); } void write(uint8_t value) { this->buffer_->push_back(value); }
void encode_varint_raw(ProtoVarInt value) { value.encode(*this->buffer_); } void encode_varint_raw(uint32_t value) {
void encode_varint_raw(uint32_t value) { this->encode_varint_raw(ProtoVarInt(value)); } while (value > 0x7F) {
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80));
value >>= 7;
}
this->buffer_->push_back(static_cast<uint8_t>(value));
}
void encode_varint_raw_64(uint64_t value) {
while (value > 0x7F) {
this->buffer_->push_back(static_cast<uint8_t>(value | 0x80));
value >>= 7;
}
this->buffer_->push_back(static_cast<uint8_t>(value));
}
/** /**
* Encode a field key (tag/wire type combination). * Encode a field key (tag/wire type combination).
* *
@@ -307,13 +282,13 @@ class ProtoWriteBuffer {
if (value == 0 && !force) if (value == 0 && !force)
return; return;
this->encode_field_raw(field_id, 0); // type 0: Varint - uint64 this->encode_field_raw(field_id, 0); // type 0: Varint - uint64
this->encode_varint_raw(ProtoVarInt(value)); this->encode_varint_raw_64(value);
} }
void encode_bool(uint32_t field_id, bool value, bool force = false) { void encode_bool(uint32_t field_id, bool value, bool force = false) {
if (!value && !force) if (!value && !force)
return; return;
this->encode_field_raw(field_id, 0); // type 0: Varint - bool this->encode_field_raw(field_id, 0); // type 0: Varint - bool
this->write(0x01); this->buffer_->push_back(value ? 0x01 : 0x00);
} }
void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) { void encode_fixed32(uint32_t field_id, uint32_t value, bool force = false) {
if (value == 0 && !force) if (value == 0 && !force)
@@ -938,13 +913,15 @@ inline void ProtoWriteBuffer::encode_message(uint32_t field_id, const ProtoMessa
this->buffer_->resize(this->buffer_->size() + varint_length_bytes); this->buffer_->resize(this->buffer_->size() + varint_length_bytes);
// Write the length varint directly // Write the length varint directly
ProtoVarInt(msg_length_bytes).encode_to_buffer_unchecked(this->buffer_->data() + begin, varint_length_bytes); encode_varint_to_buffer(msg_length_bytes, this->buffer_->data() + begin);
// Now encode the message content - it will append to the buffer // Now encode the message content - it will append to the buffer
value.encode(*this); value.encode(*this);
#ifdef ESPHOME_DEBUG_API
// Verify that the encoded size matches what we calculated // Verify that the encoded size matches what we calculated
assert(this->buffer_->size() == begin + varint_length_bytes + msg_length_bytes); assert(this->buffer_->size() == begin + varint_length_bytes + msg_length_bytes);
#endif
} }
// Implementation of decode_to_message - must be after ProtoDecodableMessage is defined // Implementation of decode_to_message - must be after ProtoDecodableMessage is defined
+1
View File
@@ -14,6 +14,7 @@
#define ESPHOME_PROJECT_VERSION_30 "v2" #define ESPHOME_PROJECT_VERSION_30 "v2"
#define ESPHOME_VARIANT "ESP32" #define ESPHOME_VARIANT "ESP32"
#define ESPHOME_DEBUG_SCHEDULER #define ESPHOME_DEBUG_SCHEDULER
#define ESPHOME_DEBUG_API
// Default threading model for static analysis (ESP32 is multi-threaded with atomics) // Default threading model for static analysis (ESP32 is multi-threaded with atomics)
#define ESPHOME_THREAD_MULTI_ATOMICS #define ESPHOME_THREAD_MULTI_ATOMICS
+1
View File
@@ -197,6 +197,7 @@ async def yaml_config(request: pytest.FixtureRequest, unused_tcp_port: int) -> s
" platformio_options:\n" " platformio_options:\n"
" build_flags:\n" " build_flags:\n"
' - "-DDEBUG" # Enable assert() statements\n' ' - "-DDEBUG" # Enable assert() statements\n'
' - "-DESPHOME_DEBUG_API" # Enable API protocol asserts\n'
' - "-g" # Add debug symbols', ' - "-g" # Add debug symbols',
) )