Reserve buffer space to avoid frequent realloc when generating protobuf messages (#8707)

This commit is contained in:
J. Nick Koston
2025-05-07 21:56:54 -05:00
committed by GitHub
parent d60e1f02c0
commit 54ead9a6b4
7 changed files with 1705 additions and 7 deletions
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -1,5 +1,5 @@
// This file was automatically generated with a tool. // This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py // See script/api_protobuf/api_protobuf.py
#include "api_pb2_service.h" #include "api_pb2_service.h"
#include "esphome/core/log.h" #include "esphome/core/log.h"
+1 -1
View File
@@ -1,5 +1,5 @@
// This file was automatically generated with a tool. // This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py // See script/api_protobuf/api_protobuf.py
#pragma once #pragma once
#include "api_pb2.h" #include "api_pb2.h"
+361
View File
@@ -0,0 +1,361 @@
#pragma once
#include "proto.h"
#include <cstdint>
#include <string>
namespace esphome {
namespace api {
class ProtoSize {
public:
/**
* @brief ProtoSize class for Protocol Buffer serialization size calculation
*
* This class provides static methods to calculate the exact byte counts needed
* for encoding various Protocol Buffer field types. All methods are designed to be
* efficient for the common case where many fields have default values.
*
* Implements Protocol Buffer encoding size calculation according to:
* https://protobuf.dev/programming-guides/encoding/
*
* Key features:
* - Early-return optimization for zero/default values
* - Direct total_size updates to avoid unnecessary additions
* - Specialized handling for different field types according to protobuf spec
* - Templated helpers for repeated fields and messages
*/
/**
* @brief Calculates the size in bytes needed to encode a uint32_t value as a varint
*
* @param value The uint32_t value to calculate size for
* @return The number of bytes needed to encode the value
*/
static inline uint32_t varint(uint32_t value) {
// Optimized varint size calculation using leading zeros
// Each 7 bits requires one byte in the varint encoding
if (value < 128)
return 1; // 7 bits, common case for small values
// For larger values, count bytes needed based on the position of the highest bit set
if (value < 16384) {
return 2; // 14 bits
} else if (value < 2097152) {
return 3; // 21 bits
} else if (value < 268435456) {
return 4; // 28 bits
} else {
return 5; // 32 bits (maximum for uint32_t)
}
}
/**
* @brief Calculates the size in bytes needed to encode a uint64_t value as a varint
*
* @param value The uint64_t value to calculate size for
* @return The number of bytes needed to encode the value
*/
static inline uint32_t varint(uint64_t value) {
// Handle common case of values fitting in uint32_t (vast majority of use cases)
if (value <= UINT32_MAX) {
return varint(static_cast<uint32_t>(value));
}
// For larger values, determine size based on highest bit position
if (value < (1ULL << 35)) {
return 5; // 35 bits
} else if (value < (1ULL << 42)) {
return 6; // 42 bits
} else if (value < (1ULL << 49)) {
return 7; // 49 bits
} else if (value < (1ULL << 56)) {
return 8; // 56 bits
} else if (value < (1ULL << 63)) {
return 9; // 63 bits
} else {
return 10; // 64 bits (maximum for uint64_t)
}
}
/**
* @brief Calculates the size in bytes needed to encode an int32_t value as a varint
*
* Special handling is needed for negative values, which are sign-extended to 64 bits
* in Protocol Buffers, resulting in a 10-byte varint.
*
* @param value The int32_t value to calculate size for
* @return The number of bytes needed to encode the value
*/
static inline uint32_t varint(int32_t value) {
// Negative values are sign-extended to 64 bits in protocol buffers,
// which always results in a 10-byte varint for negative int32
if (value < 0) {
return 10; // Negative int32 is always 10 bytes long
}
// For non-negative values, use the uint32_t implementation
return varint(static_cast<uint32_t>(value));
}
/**
* @brief Calculates the size in bytes needed to encode an int64_t value as a varint
*
* @param value The int64_t value to calculate size for
* @return The number of bytes needed to encode the value
*/
static inline uint32_t varint(int64_t value) {
// For int64_t, we convert to uint64_t and calculate the size
// This works because the bit pattern determines the encoding size,
// and we've handled negative int32 values as a special case above
return varint(static_cast<uint64_t>(value));
}
/**
* @brief Calculates the size in bytes needed to encode a field ID and wire type
*
* @param field_id The field identifier
* @param type The wire type value (from the WireType enum in the protobuf spec)
* @return The number of bytes needed to encode the field ID and wire type
*/
static inline uint32_t field(uint32_t field_id, uint32_t type) {
uint32_t tag = (field_id << 3) | (type & 0b111);
return varint(tag);
}
/**
* @brief Common parameters for all add_*_field methods
*
* All add_*_field methods follow these common patterns:
*
* @param total_size Reference to the total message size to update
* @param field_id_size Pre-calculated size of the field ID in bytes
* @param value The value to calculate size for (type varies)
* @param force Whether to calculate size even if the value is default/zero/empty
*
* Each method follows this implementation pattern:
* 1. Skip calculation if value is default (0, false, empty) and not forced
* 2. Calculate the size based on the field's encoding rules
* 3. Add the field_id_size + calculated value size to total_size
*/
/**
* @brief Calculates and adds the size of an int32 field to the total message size
*/
static inline void add_int32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
if (value < 0) {
// Negative values are encoded as 10-byte varints in protobuf
total_size += field_id_size + 10;
} else {
// For non-negative values, use the standard varint size
total_size += field_id_size + varint(static_cast<uint32_t>(value));
}
}
/**
* @brief Calculates and adds the size of a uint32 field to the total message size
*/
static inline void add_uint32_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value,
bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
total_size += field_id_size + varint(value);
}
/**
* @brief Calculates and adds the size of a boolean field to the total message size
*/
static inline void add_bool_field(uint32_t &total_size, uint32_t field_id_size, bool value, bool force = false) {
// Skip calculation if value is false and not forced
if (!value && !force) {
return; // No need to update total_size
}
// Boolean fields always use 1 byte when true
total_size += field_id_size + 1;
}
/**
* @brief Calculates and adds the size of a fixed field to the total message size
*
* Fixed fields always take exactly N bytes (4 for fixed32/float, 8 for fixed64/double).
*
* @tparam NumBytes The number of bytes for this fixed field (4 or 8)
* @param is_nonzero Whether the value is non-zero
*/
template<uint32_t NumBytes>
static inline void add_fixed_field(uint32_t &total_size, uint32_t field_id_size, bool is_nonzero,
bool force = false) {
// Skip calculation if value is zero and not forced
if (!is_nonzero && !force) {
return; // No need to update total_size
}
// Fixed fields always take exactly NumBytes
total_size += field_id_size + NumBytes;
}
/**
* @brief Calculates and adds the size of an enum field to the total message size
*
* Enum fields are encoded as uint32 varints.
*/
static inline void add_enum_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value, bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// Enums are encoded as uint32
total_size += field_id_size + varint(value);
}
/**
* @brief Calculates and adds the size of a sint32 field to the total message size
*
* Sint32 fields use ZigZag encoding, which is more efficient for negative values.
*/
static inline void add_sint32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// ZigZag encoding for sint32: (n << 1) ^ (n >> 31)
uint32_t zigzag = (static_cast<uint32_t>(value) << 1) ^ (static_cast<uint32_t>(value >> 31));
total_size += field_id_size + varint(zigzag);
}
/**
* @brief Calculates and adds the size of an int64 field to the total message size
*/
static inline void add_int64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
total_size += field_id_size + varint(value);
}
/**
* @brief Calculates and adds the size of a uint64 field to the total message size
*/
static inline void add_uint64_field(uint32_t &total_size, uint32_t field_id_size, uint64_t value,
bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
total_size += field_id_size + varint(value);
}
/**
* @brief Calculates and adds the size of a sint64 field to the total message size
*
* Sint64 fields use ZigZag encoding, which is more efficient for negative values.
*/
static inline void add_sint64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) {
// Skip calculation if value is zero and not forced
if (value == 0 && !force) {
return; // No need to update total_size
}
// ZigZag encoding for sint64: (n << 1) ^ (n >> 63)
uint64_t zigzag = (static_cast<uint64_t>(value) << 1) ^ (static_cast<uint64_t>(value >> 63));
total_size += field_id_size + varint(zigzag);
}
/**
* @brief Calculates and adds the size of a string/bytes field to the total message size
*/
static inline void add_string_field(uint32_t &total_size, uint32_t field_id_size, const std::string &str,
bool force = false) {
// Skip calculation if string is empty and not forced
if (str.empty() && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
const uint32_t str_size = static_cast<uint32_t>(str.size());
total_size += field_id_size + varint(str_size) + str_size;
}
/**
* @brief Calculates and adds the size of a nested message field to the total message size
*
* This helper function directly updates the total_size reference if the nested size
* is greater than zero or force is true.
*
* @param nested_size The pre-calculated size of the nested message
*/
static inline void add_message_field(uint32_t &total_size, uint32_t field_id_size, uint32_t nested_size,
bool force = false) {
// Skip calculation if nested message is empty and not forced
if (nested_size == 0 && !force) {
return; // No need to update total_size
}
// Calculate and directly add to total_size
// Field ID + length varint + nested message content
total_size += field_id_size + varint(nested_size) + nested_size;
}
/**
* @brief Calculates and adds the size of a nested message field to the total message size
*
* This templated version directly takes a message object, calculates its size internally,
* and updates the total_size reference. This eliminates the need for a temporary variable
* at the call site.
*
* @tparam MessageType The type of the nested message (inferred from parameter)
* @param message The nested message object
*/
template<typename MessageType>
static inline void add_message_object(uint32_t &total_size, uint32_t field_id_size, const MessageType &message,
bool force = false) {
uint32_t nested_size = 0;
message.calculate_size(nested_size);
// Use the base implementation with the calculated nested_size
add_message_field(total_size, field_id_size, nested_size, force);
}
/**
* @brief Calculates and adds the sizes of all messages in a repeated field to the total message size
*
* This helper processes a vector of message objects, calculating the size for each message
* and adding it to the total size.
*
* @tparam MessageType The type of the nested messages in the vector
* @param messages Vector of message objects
*/
template<typename MessageType>
static inline void add_repeated_message(uint32_t &total_size, uint32_t field_id_size,
const std::vector<MessageType> &messages) {
// Skip if the vector is empty
if (messages.empty()) {
return;
}
// For repeated fields, always use force=true
for (const auto &message : messages) {
add_message_object(total_size, field_id_size, message, true);
}
}
};
} // namespace api
} // namespace esphome
+11
View File
@@ -276,6 +276,7 @@ class ProtoMessage {
virtual ~ProtoMessage() = default; virtual ~ProtoMessage() = default;
virtual void encode(ProtoWriteBuffer buffer) const = 0; virtual void encode(ProtoWriteBuffer buffer) const = 0;
void decode(const uint8_t *buffer, size_t length); void decode(const uint8_t *buffer, size_t length);
virtual void calculate_size(uint32_t &total_size) const = 0;
#ifdef HAS_PROTO_MESSAGE_DUMP #ifdef HAS_PROTO_MESSAGE_DUMP
std::string dump() const; std::string dump() const;
virtual void dump_to(std::string &out) const = 0; virtual void dump_to(std::string &out) const = 0;
@@ -302,9 +303,19 @@ class ProtoService {
virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0; virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0;
virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0; virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0;
// Optimized method that pre-allocates buffer based on message size
template<class C> bool send_message_(const C &msg, uint32_t message_type) { template<class C> bool send_message_(const C &msg, uint32_t message_type) {
uint32_t msg_size = 0;
msg.calculate_size(msg_size);
// Create a pre-sized buffer
auto buffer = this->create_buffer(); auto buffer = this->create_buffer();
buffer.get_buffer()->reserve(msg_size);
// Encode message into the buffer
msg.encode(buffer); msg.encode(buffer);
// Send the buffer
return this->send_buffer(buffer, message_type); return this->send_buffer(buffer, message_type);
} }
}; };
+210 -3
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import IntEnum
import os import os
from pathlib import Path from pathlib import Path
import re import re
@@ -10,11 +11,29 @@ import sys
from textwrap import dedent from textwrap import dedent
from typing import Any from typing import Any
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor import google.protobuf.descriptor_pb2 as descriptor
class WireType(IntEnum):
"""Protocol Buffer wire types as defined in the protobuf spec.
As specified in the Protocol Buffers encoding guide:
https://protobuf.dev/programming-guides/encoding/#structure
"""
VARINT = 0 # int32, int64, uint32, uint64, sint32, sint64, bool, enum
FIXED64 = 1 # fixed64, sfixed64, double
LENGTH_DELIMITED = 2 # string, bytes, embedded messages, packed repeated fields
START_GROUP = 3 # groups (deprecated)
END_GROUP = 4 # groups (deprecated)
FIXED32 = 5 # fixed32, sfixed32, float
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
"""Python 3 script to automatically generate C++ classes for ESPHome's native API. """Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works. It's pretty crappy spaghetti code, but it works.
@@ -35,7 +54,7 @@ will be generated, they still need to be formatted
FILE_HEADER = """// This file was automatically generated with a tool. FILE_HEADER = """// This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py // See script/api_protobuf/api_protobuf.py
""" """
@@ -63,6 +82,11 @@ def camel_to_snake(name: str) -> str:
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def force_str(force: bool) -> str:
"""Convert a boolean force value to string format for C++ code."""
return str(force).lower()
class TypeInfo(ABC): class TypeInfo(ABC):
"""Base class for all type information.""" """Base class for all type information."""
@@ -99,6 +123,11 @@ class TypeInfo(ABC):
"""Check if the field is repeated.""" """Check if the field is repeated."""
return self._field.label == 3 return self._field.label == 3
@property
def wire_type(self) -> WireType:
"""Get the wire type for the field."""
raise NotImplementedError
@property @property
def cpp_type(self) -> str: def cpp_type(self) -> str:
raise NotImplementedError raise NotImplementedError
@@ -200,6 +229,35 @@ class TypeInfo(ABC):
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
"""Dump the value to the output.""" """Dump the value to the output."""
def calculate_field_id_size(self) -> int:
"""Calculates the size of a field ID in bytes.
Returns:
The number of bytes needed to encode the field ID
"""
# Calculate the tag by combining field_id and wire_type
tag = (self.number << 3) | (self.wire_type & 0b111)
# Calculate the varint size
if tag < 128:
return 1 # 7 bits
if tag < 16384:
return 2 # 14 bits
if tag < 2097152:
return 3 # 21 bits
if tag < 268435456:
return 4 # 28 bits
return 5 # 32 bits (maximum for uint32_t)
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
"""Calculate the size needed for encoding this field.
Args:
name: The name of the field
force: Whether to force encoding the field even if it has a default value
"""
TYPE_INFO: dict[int, TypeInfo] = {} TYPE_INFO: dict[int, TypeInfo] = {}
@@ -221,12 +279,18 @@ class DoubleType(TypeInfo):
default_value = "0.0" default_value = "0.0"
decode_64bit = "value.as_double()" decode_64bit = "value.as_double()"
encode_func = "encode_double" encode_func = "encode_double"
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0.0, {force_str(force)});"
return o
@register_type(2) @register_type(2)
class FloatType(TypeInfo): class FloatType(TypeInfo):
@@ -234,12 +298,18 @@ class FloatType(TypeInfo):
default_value = "0.0f" default_value = "0.0f"
decode_32bit = "value.as_float()" decode_32bit = "value.as_float()"
encode_func = "encode_float" encode_func = "encode_float"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%g", {name});\n' o = f'sprintf(buffer, "%g", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0.0f, {force_str(force)});"
return o
@register_type(3) @register_type(3)
class Int64Type(TypeInfo): class Int64Type(TypeInfo):
@@ -247,12 +317,18 @@ class Int64Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_int64()" decode_varint = "value.as_int64()"
encode_func = "encode_int64" encode_func = "encode_int64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_int64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(4) @register_type(4)
class UInt64Type(TypeInfo): class UInt64Type(TypeInfo):
@@ -260,12 +336,18 @@ class UInt64Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_uint64()" decode_varint = "value.as_uint64()"
encode_func = "encode_uint64" encode_func = "encode_uint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_uint64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(5) @register_type(5)
class Int32Type(TypeInfo): class Int32Type(TypeInfo):
@@ -273,12 +355,18 @@ class Int32Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_int32()" decode_varint = "value.as_int32()"
encode_func = "encode_int32" encode_func = "encode_int32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_int32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(6) @register_type(6)
class Fixed64Type(TypeInfo): class Fixed64Type(TypeInfo):
@@ -286,12 +374,18 @@ class Fixed64Type(TypeInfo):
default_value = "0" default_value = "0"
decode_64bit = "value.as_fixed64()" decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64" encode_func = "encode_fixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%llu", {name});\n' o = f'sprintf(buffer, "%llu", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
@register_type(7) @register_type(7)
class Fixed32Type(TypeInfo): class Fixed32Type(TypeInfo):
@@ -299,12 +393,18 @@ class Fixed32Type(TypeInfo):
default_value = "0" default_value = "0"
decode_32bit = "value.as_fixed32()" decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32" encode_func = "encode_fixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
@register_type(8) @register_type(8)
class BoolType(TypeInfo): class BoolType(TypeInfo):
@@ -312,11 +412,17 @@ class BoolType(TypeInfo):
default_value = "false" default_value = "false"
decode_varint = "value.as_bool()" decode_varint = "value.as_bool()"
encode_func = "encode_bool" encode_func = "encode_bool"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f"out.append(YESNO({name}));" o = f"out.append(YESNO({name}));"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_bool_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(9) @register_type(9)
class StringType(TypeInfo): class StringType(TypeInfo):
@@ -326,11 +432,17 @@ class StringType(TypeInfo):
const_reference_type = "const std::string &" const_reference_type = "const std::string &"
decode_length = "value.as_string()" decode_length = "value.as_string()"
encode_func = "encode_string" encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
def dump(self, name): def dump(self, name):
o = f'out.append("\'").append({name}).append("\'");' o = f'out.append("\'").append({name}).append("\'");'
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(11) @register_type(11)
class MessageType(TypeInfo): class MessageType(TypeInfo):
@@ -339,6 +451,7 @@ class MessageType(TypeInfo):
return self._field.type_name[1:] return self._field.type_name[1:]
default_value = "" default_value = ""
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property @property
def reference_type(self) -> str: def reference_type(self) -> str:
@@ -360,6 +473,11 @@ class MessageType(TypeInfo):
o = f"{name}.dump_to(out);" o = f"{name}.dump_to(out);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_message_object(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(12) @register_type(12)
class BytesType(TypeInfo): class BytesType(TypeInfo):
@@ -369,11 +487,17 @@ class BytesType(TypeInfo):
const_reference_type = "const std::string &" const_reference_type = "const std::string &"
decode_length = "value.as_string()" decode_length = "value.as_string()"
encode_func = "encode_string" encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'out.append("\'").append({name}).append("\'");' o = f'out.append("\'").append({name}).append("\'");'
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(13) @register_type(13)
class UInt32Type(TypeInfo): class UInt32Type(TypeInfo):
@@ -381,12 +505,18 @@ class UInt32Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_uint32()" decode_varint = "value.as_uint32()"
encode_func = "encode_uint32" encode_func = "encode_uint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRIu32, {name});\n' o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_uint32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(14) @register_type(14)
class EnumType(TypeInfo): class EnumType(TypeInfo):
@@ -399,6 +529,7 @@ class EnumType(TypeInfo):
return f"value.as_enum<{self.cpp_type}>()" return f"value.as_enum<{self.cpp_type}>()"
default_value = "" default_value = ""
wire_type = WireType.VARINT # Uses wire type 0
@property @property
def encode_func(self) -> str: def encode_func(self) -> str:
@@ -408,6 +539,11 @@ class EnumType(TypeInfo):
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));" o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_enum_field(total_size, {field_id_size}, static_cast<uint32_t>({name}), {force_str(force)});"
return o
@register_type(15) @register_type(15)
class SFixed32Type(TypeInfo): class SFixed32Type(TypeInfo):
@@ -415,12 +551,18 @@ class SFixed32Type(TypeInfo):
default_value = "0" default_value = "0"
decode_32bit = "value.as_sfixed32()" decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32" encode_func = "encode_sfixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
@register_type(16) @register_type(16)
class SFixed64Type(TypeInfo): class SFixed64Type(TypeInfo):
@@ -428,12 +570,18 @@ class SFixed64Type(TypeInfo):
default_value = "0" default_value = "0"
decode_64bit = "value.as_sfixed64()" decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64" encode_func = "encode_sfixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
return o
@register_type(17) @register_type(17)
class SInt32Type(TypeInfo): class SInt32Type(TypeInfo):
@@ -441,12 +589,18 @@ class SInt32Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_sint32()" decode_varint = "value.as_sint32()"
encode_func = "encode_sint32" encode_func = "encode_sint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%" PRId32, {name});\n' o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_sint32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
@register_type(18) @register_type(18)
class SInt64Type(TypeInfo): class SInt64Type(TypeInfo):
@@ -454,12 +608,18 @@ class SInt64Type(TypeInfo):
default_value = "0" default_value = "0"
decode_varint = "value.as_sint64()" decode_varint = "value.as_sint64()"
encode_func = "encode_sint64" encode_func = "encode_sint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str: def dump(self, name: str) -> str:
o = f'sprintf(buffer, "%lld", {name});\n' o = f'sprintf(buffer, "%lld", {name});\n'
o += "out.append(buffer);" o += "out.append(buffer);"
return o return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
o = f"ProtoSize::add_sint64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
return o
class RepeatedTypeInfo(TypeInfo): class RepeatedTypeInfo(TypeInfo):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None: def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
@@ -478,6 +638,14 @@ class RepeatedTypeInfo(TypeInfo):
def const_reference_type(self) -> str: def const_reference_type(self) -> str:
return f"const {self.cpp_type} &" return f"const {self.cpp_type} &"
@property
def wire_type(self) -> WireType:
"""Get the wire type for this repeated field.
For repeated fields, we use the same wire type as the underlying field.
"""
return self._ti.wire_type
@property @property
def decode_varint_content(self) -> str: def decode_varint_content(self) -> str:
content = self._ti.decode_varint content = self._ti.decode_varint
@@ -554,6 +722,22 @@ class RepeatedTypeInfo(TypeInfo):
def dump(self, _: str): def dump(self, _: str):
pass pass
def get_size_calculation(self, name: str, force: bool = False) -> str:
# For repeated fields, we always need to pass force=True to the underlying type's calculation
# This is because the encode method always sets force=true for repeated fields
if isinstance(self._ti, MessageType):
# For repeated messages, use the dedicated helper that handles iteration internally
field_id_size = self._ti.calculate_field_id_size()
o = f"ProtoSize::add_repeated_message(total_size, {field_id_size}, {name});"
return o
# For other repeated types, use the underlying type's size calculation with force=True
o = f"if (!{name}.empty()) {{\n"
o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += " }\n"
o += "}"
return o
def build_enum_type(desc) -> tuple[str, str]: def build_enum_type(desc) -> tuple[str, str]:
"""Builds the enum type.""" """Builds the enum type."""
@@ -587,6 +771,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
decode_64bit: list[str] = [] decode_64bit: list[str] = []
encode: list[str] = [] encode: list[str] = []
dump: list[str] = [] dump: list[str] = []
size_calc: list[str] = []
for field in desc.field: for field in desc.field:
if field.label == 3: if field.label == 3:
@@ -596,6 +781,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
protected_content.extend(ti.protected_content) protected_content.extend(ti.protected_content)
public_content.extend(ti.public_content) public_content.extend(ti.public_content)
encode.append(ti.encode_content) encode.append(ti.encode_content)
size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}"))
if ti.decode_varint_content: if ti.decode_varint_content:
decode_varint.append(ti.decode_varint_content) decode_varint.append(ti.decode_varint_content)
@@ -662,6 +848,25 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
prot = "void encode(ProtoWriteBuffer buffer) const override;" prot = "void encode(ProtoWriteBuffer buffer) const override;"
public_content.append(prot) public_content.append(prot)
# Add calculate_size method
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
# Add a check for empty/default objects to short-circuit the calculation
# Only add this optimization if we have fields to check
if size_calc:
# For a single field, just inline it for simplicity
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
o += f" {size_calc[0]} "
else:
# For multiple fields, add a short-circuit check
o += "\n"
# Performance optimization: add all the size calculations
o += indent("\n".join(size_calc)) + "\n"
o += "}\n"
cpp += o
prot = "void calculate_size(uint32_t &total_size) const override;"
public_content.append(prot)
o = f"void {desc.name}::dump_to(std::string &out) const {{" o = f"void {desc.name}::dump_to(std::string &out) const {{"
if dump: if dump:
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120: if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
@@ -796,6 +1001,7 @@ def main() -> None:
#pragma once #pragma once
#include "proto.h" #include "proto.h"
#include "api_pb2_size.h"
namespace esphome { namespace esphome {
namespace api { namespace api {
@@ -805,6 +1011,7 @@ def main() -> None:
cpp = FILE_HEADER cpp = FILE_HEADER
cpp += """\ cpp += """\
#include "api_pb2.h" #include "api_pb2.h"
#include "api_pb2_size.h"
#include "esphome/core/log.h" #include "esphome/core/log.h"
#include <cinttypes> #include <cinttypes>