mirror of
https://github.com/esphome/esphome.git
synced 2026-05-20 09:31:56 +08:00
[api] Add (inline_encode) proto option for sub-message inlining (#15599)
This commit is contained in:
committed by
Jesse Hills
parent
835ee456a5
commit
0f16d27a72
@@ -60,6 +60,10 @@ FILE_HEADER = """// This file was automatically generated with a tool.
|
||||
# Maps enum type name (e.g. ".BluetoothDeviceRequestType") to max enum value.
|
||||
_enum_max_values: dict[str, int] = {}
|
||||
|
||||
# Populated by main() before message generation.
|
||||
# Maps message name (e.g. "BluetoothLERawAdvertisement") to its descriptor.
|
||||
_message_desc_map: dict[str, Any] = {}
|
||||
|
||||
|
||||
def indent_list(text: str, padding: str = " ") -> list[str]:
|
||||
"""Indent each line of the given text with the specified padding."""
|
||||
@@ -427,6 +431,23 @@ class TypeInfo(ABC):
|
||||
Estimated size in bytes including field ID and typical data
|
||||
"""
|
||||
|
||||
def get_max_encoded_size(self) -> int | None:
|
||||
"""Get the maximum possible encoded size in bytes for this field.
|
||||
|
||||
Returns the worst-case encoded size including field ID and maximum
|
||||
possible value encoding. Returns None if the size is unbounded
|
||||
(e.g., variable-length strings without max_data_length).
|
||||
|
||||
Used by (inline_encode) validation to ensure sub-messages fit in a
|
||||
single-byte length varint (< 128 bytes).
|
||||
"""
|
||||
return None # Unbounded by default
|
||||
|
||||
|
||||
def _varint_max_size(bits: int) -> int:
|
||||
"""Return the maximum varint encoding size for a value with the given number of bits."""
|
||||
return (max(bits, 1) + 6) // 7 # ceil(bits / 7), min 1 byte for varint(0)
|
||||
|
||||
|
||||
TYPE_INFO: dict[int, TypeInfo] = {}
|
||||
|
||||
@@ -514,8 +535,30 @@ def register_type(name: int):
|
||||
return func
|
||||
|
||||
|
||||
class FixedSizeTypeMixin:
    """Mixin for field types with a known fixed encoded size (float, double, fixed32, fixed64)."""

    def get_max_encoded_size(self) -> int:
        """Return the field ID size plus the fixed payload size, in bytes."""
        tag_bytes = self.calculate_field_id_size()
        payload_bytes = self.get_fixed_size_bytes()
        return tag_bytes + payload_bytes
|
||||
|
||||
|
||||
class VarintTypeMixin:
    """Mixin for varint-encoded field types.

    Subclasses set ``_varint_max_bits`` to the widest value they put on the wire.
    """

    _varint_max_bits: int = 64  # Worst case unless a subclass narrows it

    def get_max_encoded_size(self) -> int:
        """Return the field ID size plus the largest varint this field can produce.

        When ``max_value`` is set, its bit length bounds the varint;
        otherwise the type-wide ``_varint_max_bits`` is used.
        """
        declared_max = self.max_value
        if declared_max is None:
            bits = self._varint_max_bits
        elif declared_max > 0:
            bits = declared_max.bit_length()
        else:
            # A non-positive maximum is sized as a single bit.
            bits = 1
        return self.calculate_field_id_size() + _varint_max_size(bits)
|
||||
|
||||
|
||||
@register_type(1)
|
||||
class DoubleType(TypeInfo):
|
||||
class DoubleType(FixedSizeTypeMixin, TypeInfo):
|
||||
# Unsupported but defined for completeness
|
||||
cpp_type = "double"
|
||||
default_value = "0.0"
|
||||
decode_64bit = "value.as_double()"
|
||||
@@ -541,7 +584,7 @@ class DoubleType(TypeInfo):
|
||||
|
||||
|
||||
@register_type(2)
|
||||
class FloatType(TypeInfo):
|
||||
class FloatType(FixedSizeTypeMixin, TypeInfo):
|
||||
cpp_type = "float"
|
||||
default_value = "0.0f"
|
||||
decode_32bit = "value.as_float()"
|
||||
@@ -567,8 +610,9 @@ class FloatType(TypeInfo):
|
||||
|
||||
|
||||
@register_type(3)
|
||||
class Int64Type(TypeInfo):
|
||||
class Int64Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "int64_t"
|
||||
_varint_max_bits = 64
|
||||
default_value = "0"
|
||||
decode_varint = "static_cast<int64_t>(value)"
|
||||
encode_func = "encode_int64"
|
||||
@@ -587,8 +631,9 @@ class Int64Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(4)
|
||||
class UInt64Type(TypeInfo):
|
||||
class UInt64Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "uint64_t"
|
||||
_varint_max_bits = 64
|
||||
default_value = "0"
|
||||
decode_varint = "value"
|
||||
encode_func = "encode_uint64"
|
||||
@@ -607,8 +652,9 @@ class UInt64Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(5)
|
||||
class Int32Type(TypeInfo):
|
||||
class Int32Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "int32_t"
|
||||
_varint_max_bits = 64 # int32 is sign-extended to 64 bits in protobuf
|
||||
default_value = "0"
|
||||
decode_varint = "static_cast<int32_t>(value)"
|
||||
encode_func = "encode_int32"
|
||||
@@ -627,7 +673,7 @@ class Int32Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(6)
|
||||
class Fixed64Type(TypeInfo):
|
||||
class Fixed64Type(FixedSizeTypeMixin, TypeInfo):
|
||||
cpp_type = "uint64_t"
|
||||
default_value = "0"
|
||||
decode_64bit = "value.as_fixed64()"
|
||||
@@ -653,7 +699,7 @@ class Fixed64Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(7)
|
||||
class Fixed32Type(TypeInfo):
|
||||
class Fixed32Type(FixedSizeTypeMixin, TypeInfo):
|
||||
cpp_type = "uint32_t"
|
||||
default_value = "0"
|
||||
decode_32bit = "value.as_fixed32()"
|
||||
@@ -689,7 +735,8 @@ class Fixed32Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(8)
|
||||
class BoolType(TypeInfo):
|
||||
class BoolType(VarintTypeMixin, TypeInfo):
|
||||
_varint_max_bits = 1
|
||||
cpp_type = "bool"
|
||||
default_value = "false"
|
||||
decode_varint = "value != 0"
|
||||
@@ -807,6 +854,16 @@ class StringType(TypeInfo):
|
||||
def get_estimated_size(self) -> int:
    """Estimate the encoded size: field ID plus 8 bytes for a typical string."""
    typical_string_bytes = 8
    return self.calculate_field_id_size() + typical_string_bytes
|
||||
|
||||
def get_max_encoded_size(self) -> int | None:
    """Return the worst-case encoded size, or ``None`` when no length cap is set.

    With ``max_data_length`` available the bound is:
    field ID + length-prefix varint + the maximum payload length.
    """
    limit = self.max_data_length
    if limit is None:
        return None  # Unbounded
    length_prefix = _varint_max_size(limit.bit_length())
    return self.calculate_field_id_size() + length_prefix + limit
|
||||
|
||||
|
||||
@register_type(11)
|
||||
class MessageType(TypeInfo):
|
||||
@@ -1122,6 +1179,16 @@ class PointerToStringBufferType(PointerToBufferTypeBase):
|
||||
def get_estimated_size(self) -> int:
    """Estimate the encoded size: field ID plus 8 bytes for a typical string."""
    typical_string_bytes = 8
    return self.calculate_field_id_size() + typical_string_bytes
|
||||
|
||||
def get_max_encoded_size(self) -> int | None:
    """Return the worst-case encoded size, or ``None`` when no length cap is set.

    With ``max_data_length`` available the bound is:
    field ID + length-prefix varint + the maximum payload length.
    """
    limit = self.max_data_length
    if limit is None:
        return None
    length_prefix = _varint_max_size(limit.bit_length())
    return self.calculate_field_id_size() + length_prefix + limit
|
||||
|
||||
|
||||
class PackedBufferTypeInfo(TypeInfo):
|
||||
"""Type for packed repeated fields that expose raw buffer instead of decoding.
|
||||
@@ -1299,14 +1366,23 @@ class FixedArrayBytesType(TypeInfo):
|
||||
self.calculate_field_id_size() + 1 + 31
|
||||
) # field ID + length byte + typical 31 bytes
|
||||
|
||||
def get_max_encoded_size(self) -> int:
    """Return the worst-case encoded size of the fixed-capacity byte array.

    The bound is: field ID + varint(array_size) + array_size.
    """
    capacity = self.array_size
    length_prefix = _varint_max_size(capacity.bit_length())
    return self.calculate_field_id_size() + length_prefix + capacity
|
||||
|
||||
@property
def wire_type(self) -> WireType:
    """Wire type used for this field: always length-delimited."""
    return WireType.LENGTH_DELIMITED
|
||||
|
||||
|
||||
@register_type(13)
|
||||
class UInt32Type(TypeInfo):
|
||||
class UInt32Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "uint32_t"
|
||||
_varint_max_bits = 32
|
||||
default_value = "0"
|
||||
decode_varint = "value"
|
||||
encode_func = "encode_uint32"
|
||||
@@ -1328,7 +1404,9 @@ class UInt32Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(14)
|
||||
class EnumType(TypeInfo):
|
||||
class EnumType(VarintTypeMixin, TypeInfo):
|
||||
_varint_max_bits = 32
|
||||
|
||||
@property
def cpp_type(self) -> str:
    """C++ type name: the field's type name, minus its first character, inside ``enums::``."""
    enum_name = self._field.type_name[1:]
    return f"enums::{enum_name}"
|
||||
@@ -1379,7 +1457,7 @@ class EnumType(TypeInfo):
|
||||
|
||||
|
||||
@register_type(15)
|
||||
class SFixed32Type(TypeInfo):
|
||||
class SFixed32Type(FixedSizeTypeMixin, TypeInfo):
|
||||
cpp_type = "int32_t"
|
||||
default_value = "0"
|
||||
decode_32bit = "value.as_sfixed32()"
|
||||
@@ -1405,7 +1483,7 @@ class SFixed32Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(16)
|
||||
class SFixed64Type(TypeInfo):
|
||||
class SFixed64Type(FixedSizeTypeMixin, TypeInfo):
|
||||
cpp_type = "int64_t"
|
||||
default_value = "0"
|
||||
decode_64bit = "value.as_sfixed64()"
|
||||
@@ -1431,8 +1509,9 @@ class SFixed64Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(17)
|
||||
class SInt32Type(TypeInfo):
|
||||
class SInt32Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "int32_t"
|
||||
_varint_max_bits = 32 # zigzag encoding keeps it 32-bit
|
||||
default_value = "0"
|
||||
decode_varint = "decode_zigzag32(static_cast<uint32_t>(value))"
|
||||
encode_func = "encode_sint32"
|
||||
@@ -1451,8 +1530,9 @@ class SInt32Type(TypeInfo):
|
||||
|
||||
|
||||
@register_type(18)
|
||||
class SInt64Type(TypeInfo):
|
||||
class SInt64Type(VarintTypeMixin, TypeInfo):
|
||||
cpp_type = "int64_t"
|
||||
_varint_max_bits = 64
|
||||
default_value = "0"
|
||||
decode_varint = "decode_zigzag64(value)"
|
||||
encode_func = "encode_sint64"
|
||||
@@ -1500,6 +1580,91 @@ def _generate_array_dump_content(
|
||||
return o
|
||||
|
||||
|
||||
def _is_inline_encode(sub_msg_name: str) -> bool:
    """Tell whether the named sub-message type has the (inline_encode) option set.

    Returns False when the message is not in the descriptor map or when
    the loaded options module does not define ``inline_encode``.
    """
    msg_desc = _message_desc_map.get(sub_msg_name)
    # Only consult the options module once a descriptor has been found.
    option = getattr(pb, "inline_encode", None) if msg_desc else None
    if option is None:
        return False
    return get_opt(msg_desc, option, False)
|
||||
|
||||
|
||||
def _generate_inline_encode_block(
|
||||
field_number: int, sub_msg_name: str, element: str
|
||||
) -> str:
|
||||
"""Generate inline encode code for a sub-message with (inline_encode) = true.
|
||||
|
||||
Instead of calling encode_sub_message (function pointer indirection),
|
||||
this inlines the sub-message's field encoding directly. Uses 1-byte
|
||||
backpatch for the length (validated to be < 128 at generation time).
|
||||
|
||||
Uses a local reference alias 'sub_msg' to avoid issues with this-> replacement
|
||||
on complex element expressions.
|
||||
|
||||
Args:
|
||||
field_number: The parent field number for this sub-message
|
||||
sub_msg_name: The sub-message type name
|
||||
element: C++ expression for the element (e.g., "it" or "this->field[i]")
|
||||
"""
|
||||
sub_desc = _message_desc_map[sub_msg_name]
|
||||
tag = (field_number << 3) | 2 # wire type 2 = LENGTH_DELIMITED
|
||||
assert tag < 128, f"inline_encode requires single-byte tag, got {tag}"
|
||||
|
||||
lines = []
|
||||
lines.append(f"auto &sub_msg = {element};")
|
||||
lines.append(f"ProtoEncode::write_raw_byte(pos, {tag});")
|
||||
lines.append("uint8_t *len_pos = pos;")
|
||||
lines.append("ProtoEncode::reserve_byte(pos);")
|
||||
|
||||
# Generate inline field encoding for each sub-message field
|
||||
for field in sub_desc.field:
|
||||
if field.options.deprecated:
|
||||
continue
|
||||
ti = create_field_type_info(field, needs_decode=False, needs_encode=True)
|
||||
encode_line = ti.encode_content
|
||||
# Replace this-> with sub_msg reference for the sub-message fields
|
||||
encode_line = encode_line.replace("this->", "sub_msg.")
|
||||
lines.extend(wrap_with_ifdef(encode_line, get_field_opt(field, pb.field_ifdef)))
|
||||
|
||||
lines.append("*len_pos = static_cast<uint8_t>(pos - len_pos - 1);")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_inline_size_block(
    field_number: int, sub_msg_name: str, element: str
) -> str:
    """Generate the inline size calculation for an (inline_encode) = true sub-message.

    A local reference alias 'sub_msg' is emitted so that this-> replacement
    stays safe on complex element expressions such as 'this->advertisements[i]'.

    Args:
        field_number: The parent field number for this sub-message
        sub_msg_name: The sub-message type name
        element: C++ expression for the element
    """
    message_desc = _message_desc_map[sub_msg_name]

    out = [
        f"auto &sub_msg = {element};",
        # 1 byte tag + 1 byte length (guaranteed < 128 by validation)
        "size += 2;",
    ]

    active_fields = (f for f in message_desc.field if not f.options.deprecated)
    for proto_field in active_fields:
        type_info = create_field_type_info(
            proto_field, needs_decode=False, needs_encode=True
        )
        forced = get_field_opt(proto_field, pb.force, False)
        calc = type_info.get_size_calculation(
            f"sub_msg.{type_info.field_name}", forced
        )
        # Some types hardcode this-> (e.g., FixedArrayBytesType uses this->field_len);
        # point those references at the alias instead.
        calc = calc.replace("this->", "sub_msg.")
        out.extend(wrap_with_ifdef(calc, get_field_opt(proto_field, pb.field_ifdef)))

    return "\n".join(out)
|
||||
|
||||
|
||||
class FixedArrayRepeatedType(TypeInfo):
|
||||
"""Special type for fixed-size repeated fields using std::array.
|
||||
|
||||
@@ -1526,6 +1691,10 @@ class FixedArrayRepeatedType(TypeInfo):
|
||||
return f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, static_cast<uint32_t>({element}), true);"
|
||||
# Repeated message elements use encode_sub_message (force=true is default)
|
||||
if isinstance(self._ti, MessageType):
|
||||
if _is_inline_encode(self._ti.cpp_type):
|
||||
return _generate_inline_encode_block(
|
||||
self.number, self._ti.cpp_type, element
|
||||
)
|
||||
return f"ProtoEncode::encode_sub_message(pos, buffer, {self.number}, {element});"
|
||||
return (
|
||||
f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, {element}, true);"
|
||||
@@ -1633,8 +1802,19 @@ class FixedArrayRepeatedType(TypeInfo):
|
||||
]
|
||||
return f"if ({non_zero_checks}) {{\n" + "\n".join(size_lines) + "\n}"
|
||||
|
||||
is_inline = isinstance(self._ti, MessageType) and _is_inline_encode(
|
||||
self._ti.cpp_type
|
||||
)
|
||||
|
||||
# When using a define, always use loop-based approach
|
||||
if self.is_define:
|
||||
if is_inline:
|
||||
o = f"for (const auto &it : {name}) {{\n"
|
||||
o += indent(
|
||||
_generate_inline_size_block(self.number, self._ti.cpp_type, "it")
|
||||
)
|
||||
o += "\n}"
|
||||
return o
|
||||
o = f"for (const auto &it : {name}) {{\n"
|
||||
o += f" {self._ti.get_size_calculation('it', True)}\n"
|
||||
o += "}"
|
||||
@@ -1642,6 +1822,14 @@ class FixedArrayRepeatedType(TypeInfo):
|
||||
|
||||
# For fixed arrays, we always encode all elements
|
||||
|
||||
if is_inline:
|
||||
o = f"for (const auto &it : {name}) {{\n"
|
||||
o += indent(
|
||||
_generate_inline_size_block(self.number, self._ti.cpp_type, "it")
|
||||
)
|
||||
o += "\n}"
|
||||
return o
|
||||
|
||||
# Special case for single-element arrays - no loop needed
|
||||
if self.array_size == 1:
|
||||
return self._ti.get_size_calculation(f"{name}[0]", True)
|
||||
@@ -1714,6 +1902,15 @@ class FixedArrayWithLengthRepeatedType(FixedArrayRepeatedType):
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
# Calculate size only for active elements
|
||||
if isinstance(self._ti, MessageType) and _is_inline_encode(self._ti.cpp_type):
|
||||
o = f"for (uint16_t i = 0; i < {name}_len; i++) {{\n"
|
||||
o += indent(
|
||||
_generate_inline_size_block(
|
||||
self.number, self._ti.cpp_type, f"{name}[i]"
|
||||
)
|
||||
)
|
||||
o += "\n}"
|
||||
return o
|
||||
o = f"for (uint16_t i = 0; i < {name}_len; i++) {{\n"
|
||||
o += f" {self._ti.get_size_calculation(f'{name}[i]', True)}\n"
|
||||
o += "}"
|
||||
@@ -2222,6 +2419,28 @@ def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
|
||||
return total_size
|
||||
|
||||
|
||||
def calculate_message_max_size(desc: descriptor.DescriptorProto) -> int | None:
    """Sum the worst-case encoded sizes of a message's non-deprecated fields.

    Returns None as soon as any field reports an unbounded size (e.g., a
    variable-length string). Used to validate that (inline_encode) messages
    fit in a single-byte length varint.
    """
    running_total = 0
    active_fields = (f for f in desc.field if not f.options.deprecated)

    for proto_field in active_fields:
        type_info = create_field_type_info(
            proto_field, needs_decode=False, needs_encode=True
        )
        field_max = type_info.get_max_encoded_size()
        if field_max is None:
            # One unbounded field makes the whole message unbounded.
            return None
        running_total += field_max

    return running_total
|
||||
|
||||
|
||||
def build_message_type(
|
||||
desc: descriptor.DescriptorProto,
|
||||
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]],
|
||||
@@ -2451,11 +2670,23 @@ def build_message_type(
|
||||
prot = "void decode(const uint8_t *buffer, size_t length);"
|
||||
public_content.append(prot)
|
||||
|
||||
# Check if this message uses inline_encode — if so, skip generating standalone
|
||||
# encode/calculate_size methods since the encoding is inlined into the parent.
|
||||
inline_opt = getattr(pb, "inline_encode", None)
|
||||
is_inline_only = (
|
||||
message_id is None # Not a service message (no id)
|
||||
and inline_opt is not None
|
||||
and get_opt(desc, inline_opt, False)
|
||||
)
|
||||
|
||||
# Only generate encode method if this message needs encoding and has fields
|
||||
if needs_encode and encode:
|
||||
if needs_encode and encode and not is_inline_only:
|
||||
# Add PROTO_ENCODE_DEBUG_ARG after pos in all proto_* calls
|
||||
encode_debug = [
|
||||
line.replace("(pos,", "(pos PROTO_ENCODE_DEBUG_ARG,") for line in encode
|
||||
line.replace("(pos,", "(pos PROTO_ENCODE_DEBUG_ARG,").replace(
|
||||
"(pos)", "(pos PROTO_ENCODE_DEBUG_ARG)"
|
||||
)
|
||||
for line in encode
|
||||
]
|
||||
o = f"uint8_t *{desc.name}::encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const {{\n"
|
||||
o += " uint8_t *__restrict__ pos = buffer.get_pos();\n"
|
||||
@@ -2470,7 +2701,7 @@ def build_message_type(
|
||||
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
||||
|
||||
# Add calculate_size method only if this message needs encoding and has fields
|
||||
if needs_encode and size_calc:
|
||||
if needs_encode and size_calc and not is_inline_only:
|
||||
o = f"uint32_t {desc.name}::calculate_size() const {{\n"
|
||||
o += " uint32_t size = 0;\n"
|
||||
o += indent("\n".join(size_calc)) + "\n"
|
||||
@@ -2830,6 +3061,32 @@ def main() -> None:
|
||||
if not enum.options.deprecated and enum.value:
|
||||
_enum_max_values[f".{enum.name}"] = max(v.number for v in enum.value)
|
||||
|
||||
# Build message descriptor map for inline_encode lookups
|
||||
mt = file.message_type
|
||||
_message_desc_map.update({m.name: m for m in mt if not m.options.deprecated})
|
||||
|
||||
# Validate inline_encode messages fit in single-byte length varint
|
||||
inline_encode_opt = getattr(pb, "inline_encode", None)
|
||||
if inline_encode_opt is not None:
|
||||
for m in mt:
|
||||
if m.options.deprecated:
|
||||
continue
|
||||
if not get_opt(m, inline_encode_opt, False):
|
||||
continue
|
||||
max_size = calculate_message_max_size(m)
|
||||
if max_size is None:
|
||||
raise ValueError(
|
||||
f"Message '{m.name}' has (inline_encode) = true but contains "
|
||||
f"fields with unbounded size. Inline encoding requires all "
|
||||
f"fields to have bounded maximum size."
|
||||
)
|
||||
if max_size >= 128:
|
||||
raise ValueError(
|
||||
f"Message '{m.name}' has (inline_encode) = true but max "
|
||||
f"encoded size is {max_size} bytes (>= 128). Inline encoding "
|
||||
f"requires sub-messages that fit in a single-byte length varint."
|
||||
)
|
||||
|
||||
# Build dynamic ifdef mappings early so we can emit USE_API_VARINT64 before includes
|
||||
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
|
||||
build_type_usage_map(file)
|
||||
@@ -3048,8 +3305,6 @@ static void dump_bytes_field(DumpBuffer &out, const char *field_name, const uint
|
||||
|
||||
content += "\n} // namespace enums\n\n"
|
||||
|
||||
mt = file.message_type
|
||||
|
||||
# Identify empty SOURCE_CLIENT messages that don't need class generation
|
||||
for m in mt:
|
||||
if m.options.deprecated:
|
||||
|
||||
Reference in New Issue
Block a user