mirror of
https://github.com/esphome/esphome.git
synced 2026-05-23 11:16:52 +08:00
[api] Devirtualize protobuf encode/calculate_size (#14449)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -270,18 +270,21 @@ class TypeInfo(ABC):
|
||||
def _get_simple_size_calculation(
|
||||
self, name: str, force: bool, base_method: str, value_expr: str = None
|
||||
) -> str:
|
||||
"""Helper for simple size calculations.
|
||||
"""Helper for simple size calculations using static ProtoSize methods.
|
||||
|
||||
Args:
|
||||
name: Field name
|
||||
force: Whether this is for a repeated field
|
||||
base_method: Base method name (e.g., "add_int32")
|
||||
base_method: Base method name (e.g., "int32")
|
||||
value_expr: Optional value expression (defaults to name)
|
||||
"""
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
method = f"{base_method}_force" if force else base_method
|
||||
method = f"calc_{base_method}_force" if force else f"calc_{base_method}"
|
||||
# calc_bool_force only takes field_id_size (no value needed - bool is always 1 byte)
|
||||
if base_method == "bool" and force:
|
||||
return f"size += ProtoSize::{method}({field_id_size});"
|
||||
value = value_expr or name
|
||||
return f"size.{method}({field_id_size}, {value});"
|
||||
return f"size += ProtoSize::{method}({field_id_size}, {value});"
|
||||
|
||||
@abstractmethod
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
@@ -410,7 +413,7 @@ class DoubleType(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_double({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_fixed64({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 8
|
||||
@@ -434,7 +437,7 @@ class FloatType(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_float({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_float({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 4
|
||||
@@ -457,7 +460,7 @@ class Int64Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_int64")
|
||||
return self._get_simple_size_calculation(name, force, "int64")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -477,7 +480,7 @@ class UInt64Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_uint64")
|
||||
return self._get_simple_size_calculation(name, force, "uint64")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -497,7 +500,7 @@ class Int32Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_int32")
|
||||
return self._get_simple_size_calculation(name, force, "int32")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -518,7 +521,7 @@ class Fixed64Type(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_fixed64({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_fixed64({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 8
|
||||
@@ -542,7 +545,7 @@ class Fixed32Type(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_fixed32({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_fixed32({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 4
|
||||
@@ -563,7 +566,7 @@ class BoolType(TypeInfo):
|
||||
return f"out.append(YESNO({name}));"
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_bool")
|
||||
return self._get_simple_size_calculation(name, force, "bool")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 1 # field ID + 1 byte
|
||||
@@ -647,18 +650,18 @@ class StringType(TypeInfo):
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
# For SOURCE_CLIENT only messages, use the string field directly
|
||||
if not self._needs_encode:
|
||||
return self._get_simple_size_calculation(name, force, "add_length")
|
||||
return self._get_simple_size_calculation(name, force, "length")
|
||||
|
||||
# Check if this is being called from a repeated field context
|
||||
# In that case, 'name' will be 'it' and we need to use the repeated version
|
||||
if name == "it":
|
||||
# For repeated fields, we need to use add_length_force which includes field ID
|
||||
# For repeated fields, we need to use length_force which includes field ID
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_length_force({field_id_size}, it.size());"
|
||||
return f"size += ProtoSize::calc_length_force({field_id_size}, it.size());"
|
||||
|
||||
# For messages that need encoding, use the StringRef size
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_length({field_id_size}, this->{self.field_name}_ref_.size());"
|
||||
return f"size += ProtoSize::calc_length({field_id_size}, this->{self.field_name}_ref_.size());"
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
|
||||
@@ -721,7 +724,9 @@ class MessageType(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_message_object")
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
method = "calc_message_force" if force else "calc_message"
|
||||
return f"size += ProtoSize::{method}({field_id_size}, {name}.calculate_size());"
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
# For message types, we can't easily estimate the submessage size without
|
||||
@@ -822,7 +827,7 @@ class BytesType(TypeInfo):
|
||||
)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return f"size.add_length({self.calculate_field_id_size()}, this->{self.field_name}_len_);"
|
||||
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}_len_);"
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
|
||||
@@ -897,7 +902,7 @@ class PointerToBytesBufferType(PointerToBufferTypeBase):
|
||||
)
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return f"size.add_length({self.calculate_field_id_size()}, this->{self.field_name}_len);"
|
||||
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}_len);"
|
||||
|
||||
|
||||
class PointerToStringBufferType(PointerToBufferTypeBase):
|
||||
@@ -939,7 +944,7 @@ class PointerToStringBufferType(PointerToBufferTypeBase):
|
||||
return f'dump_field(out, "{self.name}", this->{self.field_name});'
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return f"size.add_length({self.calculate_field_id_size()}, this->{self.field_name}.size());"
|
||||
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}.size());"
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
|
||||
@@ -1103,9 +1108,9 @@ class FixedArrayBytesType(TypeInfo):
|
||||
|
||||
if force:
|
||||
# For repeated fields, always calculate size (no zero check)
|
||||
return f"size.add_length_force({field_id_size}, {length_field});"
|
||||
# For non-repeated fields, add_length already checks for zero
|
||||
return f"size.add_length({field_id_size}, {length_field});"
|
||||
return f"size += ProtoSize::calc_length_force({field_id_size}, {length_field});"
|
||||
# For non-repeated fields, length already checks for zero
|
||||
return f"size += ProtoSize::calc_length({field_id_size}, {length_field});"
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
# Estimate based on typical BLE advertisement size
|
||||
@@ -1132,7 +1137,7 @@ class UInt32Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_uint32")
|
||||
return self._get_simple_size_calculation(name, force, "uint32")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -1168,7 +1173,7 @@ class EnumType(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(
|
||||
name, force, "add_uint32", f"static_cast<uint32_t>({name})"
|
||||
name, force, "uint32", f"static_cast<uint32_t>({name})"
|
||||
)
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
@@ -1190,7 +1195,7 @@ class SFixed32Type(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_sfixed32({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_sfixed32({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 4
|
||||
@@ -1214,7 +1219,7 @@ class SFixed64Type(TypeInfo):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return f"size.add_sfixed64({field_id_size}, {name});"
|
||||
return f"size += ProtoSize::calc_sfixed64({field_id_size}, {name});"
|
||||
|
||||
def get_fixed_size_bytes(self) -> int:
|
||||
return 8
|
||||
@@ -1237,7 +1242,7 @@ class SInt32Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_sint32")
|
||||
return self._get_simple_size_calculation(name, force, "sint32")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -1257,7 +1262,7 @@ class SInt64Type(TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
return self._get_simple_size_calculation(name, force, "add_sint64")
|
||||
return self._get_simple_size_calculation(name, force, "sint64")
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
@@ -1694,11 +1699,17 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
# For repeated fields, we always need to pass force=True to the underlying type's calculation
|
||||
# This is because the encode method always sets force=true for repeated fields
|
||||
|
||||
# Handle message types separately as they use a dedicated helper
|
||||
# Handle message types separately - generate inline loop
|
||||
if isinstance(self._ti, MessageType):
|
||||
field_id_size = self._ti.calculate_field_id_size()
|
||||
container = f"*{name}" if self._use_pointer else name
|
||||
return f"size.add_repeated_message({field_id_size}, {container});"
|
||||
container_ref = f"*{name}" if self._use_pointer else name
|
||||
empty_check = f"{name}->empty()" if self._use_pointer else f"{name}.empty()"
|
||||
o = f"if (!{empty_check}) {{\n"
|
||||
o += f" for (const auto &it : {container_ref}) {{\n"
|
||||
o += f" size += ProtoSize::calc_message_force({field_id_size}, it.calculate_size());\n"
|
||||
o += " }\n"
|
||||
o += "}"
|
||||
return o
|
||||
|
||||
# For non-message types, generate size calculation with iteration
|
||||
container_ref = f"*{name}" if self._use_pointer else name
|
||||
@@ -1713,14 +1724,14 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
field_id_size = self._ti.calculate_field_id_size()
|
||||
bytes_per_element = field_id_size + num_bytes
|
||||
size_expr = f"{name}->size()" if self._use_pointer else f"{name}.size()"
|
||||
o += f" size.add_precalculated_size({size_expr} * {bytes_per_element});\n"
|
||||
o += f" size += {size_expr} * {bytes_per_element};\n"
|
||||
else:
|
||||
# Other types need the actual value
|
||||
# Special handling for const char* elements
|
||||
if self._use_pointer and "const char" in self._container_no_template:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
o += f" for (const char *it : {container_ref}) {{\n"
|
||||
o += f" size.add_length_force({field_id_size}, strlen(it));\n"
|
||||
o += f" size += ProtoSize::calc_length_force({field_id_size}, strlen(it));\n"
|
||||
else:
|
||||
auto_ref = "" if self._ti_is_bool else "&"
|
||||
o += f" for (const auto {auto_ref}it : {container_ref}) {{\n"
|
||||
@@ -2233,23 +2244,19 @@ def build_message_type(
|
||||
o += indent("\n".join(encode)) + "\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "void encode(ProtoWriteBuffer &buffer) const override;"
|
||||
prot = "void encode(ProtoWriteBuffer &buffer) const;"
|
||||
public_content.append(prot)
|
||||
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
||||
|
||||
# Add calculate_size method only if this message needs encoding and has fields
|
||||
if needs_encode and size_calc:
|
||||
o = f"void {desc.name}::calculate_size(ProtoSize &size) const {{"
|
||||
# For a single field, just inline it for simplicity
|
||||
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
|
||||
o += f" {size_calc[0]} }}\n"
|
||||
else:
|
||||
# For multiple fields
|
||||
o += "\n"
|
||||
o += indent("\n".join(size_calc)) + "\n"
|
||||
o += "}\n"
|
||||
o = f"uint32_t {desc.name}::calculate_size() const {{\n"
|
||||
o += " uint32_t size = 0;\n"
|
||||
o += indent("\n".join(size_calc)) + "\n"
|
||||
o += " return size;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "void calculate_size(ProtoSize &size) const override;"
|
||||
prot = "uint32_t calculate_size() const;"
|
||||
public_content.append(prot)
|
||||
# If no fields to calculate size for or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
||||
|
||||
@@ -2933,14 +2940,8 @@ static const char *const TAG = "api.service";
|
||||
hpp += " public:\n"
|
||||
hpp += "#endif\n\n"
|
||||
|
||||
# Add non-template send_message method
|
||||
hpp += " bool send_message(const ProtoMessage &msg, uint8_t message_type) {\n"
|
||||
hpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
|
||||
hpp += " DumpBuffer dump_buf;\n"
|
||||
hpp += " this->log_send_message_(msg.message_name(), msg.dump_to(dump_buf));\n"
|
||||
hpp += "#endif\n"
|
||||
hpp += " return this->send_message_impl(msg, message_type);\n"
|
||||
hpp += " }\n\n"
|
||||
# send_message is now a template on APIConnection directly
|
||||
# No non-template send_message method needed here
|
||||
|
||||
# Add logging helper method implementations to cpp
|
||||
cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
|
||||
|
||||
Reference in New Issue
Block a user