Files
esphome/script/api_protobuf/api_protobuf.py
J. Nick Koston a5bd718fc9 [api] Make ProtoDecodableMessage::decode() non-virtual
decode() is never called polymorphically - all call sites in
read_message_() use concrete types. The only indirect call site was
decode_to_message(), which also always knows the concrete type.

Convert decode_to_message() to a template so the concrete type is
preserved, allowing decode() to be non-virtual. The two classes that
override decode() (ExecuteServiceArgument, ExecuteServiceRequest) now
hide the base method, which works since all calls use concrete types.

This removes one vtable slot (4 bytes) from each decodable message
class vtable, saving ~148 bytes of flash.
2026-03-21 20:51:28 -10:00

3147 lines
110 KiB
Python
Executable File

#!/usr/bin/env python3
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import IntEnum
from pathlib import Path
import re
from subprocess import call
import sys
from typing import Any
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
class WireType(IntEnum):
"""Protocol Buffer wire types as defined in the protobuf spec.
As specified in the Protocol Buffers encoding guide:
https://protobuf.dev/programming-guides/encoding/#structure
"""
VARINT = 0 # int32, int64, uint32, uint64, sint32, sint64, bool, enum
FIXED64 = 1 # fixed64, sfixed64, double
LENGTH_DELIMITED = 2 # string, bytes, embedded messages, packed repeated fields
START_GROUP = 3 # groups (deprecated)
END_GROUP = 4 # groups (deprecated)
FIXED32 = 5 # fixed32, sfixed32, float
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works.
you need to install protobuf-compiler:
running protoc --version should return
libprotoc 3.6.1
then run this script with python3 and the files
esphome/components/api/api_pb2_service.h
esphome/components/api/api_pb2_service.cpp
esphome/components/api/api_pb2.h
esphome/components/api/api_pb2.cpp
will be generated, they still need to be formatted
"""
FILE_HEADER = """// This file was automatically generated with a tool.
// See script/api_protobuf/api_protobuf.py
"""
def indent_list(text: str, padding: str = " ") -> list[str]:
"""Indent each line of the given text with the specified padding."""
lines = []
for line in text.splitlines():
if line == "" or line.startswith("#ifdef") or line.startswith("#endif"):
p = ""
else:
p = padding
lines.append(p + line)
return lines
def indent(text: str, padding: str = " ") -> str:
return "\n".join(indent_list(text, padding))
def wrap_with_ifdef(content: str | list[str], ifdef: str | None) -> list[str]:
"""Wrap content with #ifdef directives if ifdef is provided.
Args:
content: Single string or list of strings to wrap
ifdef: The ifdef condition, or None to skip wrapping
Returns:
List of strings with ifdef wrapping if needed
"""
if not ifdef:
if isinstance(content, str):
return [content]
return content
result = [f"#ifdef {ifdef}"]
if isinstance(content, str):
result.append(content)
else:
result.extend(content)
result.append("#endif")
return result
def camel_to_snake(name: str) -> str:
# https://stackoverflow.com/a/1176023
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def force_str(force: bool) -> str:
"""Convert a boolean force value to string format for C++ code."""
return str(force).lower()
class TypeInfo(ABC):
"""Base class for all type information."""
def __init__(
self,
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> None:
self._field = field
self._needs_decode = needs_decode
self._needs_encode = needs_encode
@property
def default_value(self) -> str:
"""Get the default value."""
return ""
@property
def name(self) -> str:
"""Get the name of the field."""
return self._field.name
@property
def arg_name(self) -> str:
"""Get the argument name."""
return self.name
@property
def field_name(self) -> str:
"""Get the field name."""
return self.name
@property
def number(self) -> int:
"""Get the field number."""
return self._field.number
@property
def repeated(self) -> bool:
"""Check if the field is repeated."""
return self._field.label == FieldDescriptorProto.LABEL_REPEATED
@property
def force(self) -> bool:
"""Check if this field should always be encoded (skip zero/empty check)."""
return get_field_opt(self._field, pb.force, False)
@property
def wire_type(self) -> WireType:
"""Get the wire type for the field."""
raise NotImplementedError
@property
def cpp_type(self) -> str:
raise NotImplementedError
@property
def reference_type(self) -> str:
return f"{self.cpp_type} "
@property
def const_reference_type(self) -> str:
return f"{self.cpp_type} "
@property
def public_content(self) -> str:
return [self.class_member]
@property
def protected_content(self) -> str:
return []
@property
def class_member(self) -> str:
return f"{self.cpp_type} {self.field_name}{{{self.default_value}}};"
@property
def decode_varint_content(self) -> str:
content = self.decode_varint
if content is None:
return None
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_varint = None
@property
def decode_length_content(self) -> str:
content = self.decode_length
if content is None:
return None
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_length = None
@property
def decode_32bit_content(self) -> str:
content = self.decode_32bit
if content is None:
return None
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_32bit = None
@property
def decode_64bit_content(self) -> str:
content = self.decode_64bit
if content is None:
return None
return f"case {self.number}: this->{self.field_name} = {content}; break;"
decode_64bit = None
@property
def encode_content(self) -> str:
if self.force:
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, true);"
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
encode_func = None
@classmethod
def can_use_dump_field(cls) -> bool:
"""Whether this type can use the dump_field helper functions.
Returns True for simple types that have dump_field overloads.
Complex types like messages and bytes should return False.
"""
return True
def dump_field_value(self, value: str) -> str:
"""Get the value expression to pass to dump_field.
Most types just pass the value directly, but some (like enums) need a cast.
"""
return value
@property
def dump_content(self) -> str:
# Default implementation - subclasses can override if they need special handling
return f'dump_field(out, "{self.name}", {self.dump_field_value(f"this->{self.field_name}")});'
@abstractmethod
def dump(self, name: str) -> str:
"""Dump the value to the output."""
def calculate_field_id_size(self) -> int:
"""Calculates the size of a field ID in bytes.
Returns:
The number of bytes needed to encode the field ID
"""
# Calculate the tag by combining field_id and wire_type
tag = (self.number << 3) | (self.wire_type & 0b111)
# Calculate the varint size
if tag < 128:
return 1 # 7 bits
if tag < 16384:
return 2 # 14 bits
if tag < 2097152:
return 3 # 21 bits
if tag < 268435456:
return 4 # 28 bits
return 5 # 32 bits (maximum for uint32_t)
def _get_simple_size_calculation(
self, name: str, force: bool, base_method: str, value_expr: str = None
) -> str:
"""Helper for simple size calculations using static ProtoSize methods.
Args:
name: Field name
force: Whether this is for a repeated field
base_method: Base method name (e.g., "int32")
value_expr: Optional value expression (defaults to name)
"""
field_id_size = self.calculate_field_id_size()
method = f"calc_{base_method}_force" if force else f"calc_{base_method}"
# calc_bool_force only takes field_id_size (no value needed - bool is always 1 byte)
if base_method == "bool" and force:
return f"size += ProtoSize::{method}({field_id_size});"
value = value_expr or name
return f"size += ProtoSize::{method}({field_id_size}, {value});"
@abstractmethod
def get_size_calculation(self, name: str, force: bool = False) -> str:
"""Calculate the size needed for encoding this field.
Args:
name: The name of the field
force: Whether to force encoding the field even if it has a default value
"""
def get_fixed_size_bytes(self) -> int | None:
"""Get the number of bytes for fixed-size fields (float, double, fixed32, etc).
Returns:
The number of bytes (4 or 8) for fixed-size fields, None for variable-size fields.
"""
return None
@abstractmethod
def get_estimated_size(self) -> int:
"""Get estimated size in bytes for this field with typical values.
Returns:
Estimated size in bytes including field ID and typical data
"""
TYPE_INFO: dict[int, TypeInfo] = {}
# Unsupported 64-bit types that would add overhead for embedded systems
# TYPE_DOUBLE = 1, TYPE_FIXED64 = 6, TYPE_SFIXED64 = 16, TYPE_SINT64 = 18
UNSUPPORTED_TYPES = {1: "double", 6: "fixed64", 16: "sfixed64", 18: "sint64"}
def validate_field_type(field_type: int, field_name: str = "") -> None:
"""Validate that the field type is supported by ESPHome API.
Raises ValueError for unsupported 64-bit types.
"""
if field_type in UNSUPPORTED_TYPES:
type_name = UNSUPPORTED_TYPES[field_type]
field_info = f" (field: {field_name})" if field_name else ""
raise ValueError(
f"64-bit type '{type_name}'{field_info} is not supported by ESPHome API. "
"These types add significant overhead for embedded systems. "
"If you need 64-bit support, please add the necessary encoding/decoding "
"functions to proto.h/proto.cpp first."
)
def create_field_type_info(
field: descriptor.FieldDescriptorProto,
needs_decode: bool = True,
needs_encode: bool = True,
) -> TypeInfo:
"""Create the appropriate TypeInfo instance for a field, handling repeated fields and custom options."""
if field.label == FieldDescriptorProto.LABEL_REPEATED:
# Check if this is a packed_buffer field (zero-copy packed repeated)
if get_field_opt(field, pb.packed_buffer, False):
return PackedBufferTypeInfo(field)
# Check if this repeated field has fixed_array_with_length_define option
if (
fixed_size := get_field_opt(field, pb.fixed_array_with_length_define)
) is not None:
return FixedArrayWithLengthRepeatedType(field, fixed_size)
# Check if this repeated field has fixed_array_size option
if (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None:
return FixedArrayRepeatedType(field, fixed_size)
# Check if this repeated field has fixed_array_size_define option
if (
size_define := get_field_opt(field, pb.fixed_array_size_define)
) is not None:
return FixedArrayRepeatedType(field, size_define)
return RepeatedTypeInfo(field)
# Special handling for bytes fields
if field.type == 12:
fixed_size = get_field_opt(field, pb.fixed_array_size, None)
if fixed_size is not None:
# Traditional fixed array approach with copy (takes priority)
return FixedArrayBytesType(field, fixed_size)
# For messages that decode (SOURCE_CLIENT or SOURCE_BOTH), use pointer
# for zero-copy access to the receive buffer
if needs_decode:
return PointerToBytesBufferType(field, None)
# For SOURCE_SERVER (encode only), explicit annotation is still needed
if get_field_opt(field, pb.pointer_to_buffer, False):
return PointerToBytesBufferType(field, None)
return BytesType(field, needs_decode, needs_encode)
# Special handling for string fields - use StringRef for zero-copy
if field.type == 9:
return PointerToStringBufferType(field, None)
validate_field_type(field.type, field.name)
return TYPE_INFO[field.type](field)
def register_type(name: int):
"""Decorator to register a type with a name and number."""
def func(value: TypeInfo) -> TypeInfo:
"""Register the type with the given name and number."""
TYPE_INFO[name] = value
return value
return func
@register_type(1)
class DoubleType(TypeInfo):
cpp_type = "double"
default_value = "0.0"
decode_64bit = "value.as_double()"
encode_func = "encode_double"
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%g", {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_fixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes for double
@register_type(2)
class FloatType(TypeInfo):
cpp_type = "float"
default_value = "0.0f"
decode_32bit = "value.as_float()"
encode_func = "encode_float"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%g", {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_float({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 4 # field ID + 4 bytes for float
@register_type(3)
class Int64Type(TypeInfo):
cpp_type = "int64_t"
default_value = "0"
decode_varint = "static_cast<int64_t>(value)"
encode_func = "encode_int64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId64, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "int64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(4)
class UInt64Type(TypeInfo):
cpp_type = "uint64_t"
default_value = "0"
decode_varint = "value"
encode_func = "encode_uint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRIu64, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "uint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(5)
class Int32Type(TypeInfo):
cpp_type = "int32_t"
default_value = "0"
decode_varint = "static_cast<int32_t>(value)"
encode_func = "encode_int32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "int32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(6)
class Fixed64Type(TypeInfo):
cpp_type = "uint64_t"
default_value = "0"
decode_64bit = "value.as_fixed64()"
encode_func = "encode_fixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRIu64, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_fixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes fixed
@register_type(7)
class Fixed32Type(TypeInfo):
cpp_type = "uint32_t"
default_value = "0"
decode_32bit = "value.as_fixed32()"
encode_func = "encode_fixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRIu32, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_fixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 4 # field ID + 4 bytes fixed
@register_type(8)
class BoolType(TypeInfo):
cpp_type = "bool"
default_value = "false"
decode_varint = "value != 0"
encode_func = "encode_bool"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
return f"out.append(YESNO({name}));"
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "bool")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 1 # field ID + 1 byte
@register_type(9)
class StringType(TypeInfo):
cpp_type = "std::string"
default_value = ""
reference_type = "std::string &"
const_reference_type = "const std::string &"
decode_length = "value.as_string()"
encode_func = "encode_string"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def public_content(self) -> list[str]:
content: list[str] = []
# Add std::string storage if message needs decoding
if self._needs_decode:
content.append(f"std::string {self.field_name}{{}};")
# Add StringRef if encoding is needed
if self._needs_encode:
content.extend(
[
# Add StringRef field if message needs encoding
f"StringRef {self.field_name}_ref_{{}};",
# Add setter method if message needs encoding
f"void set_{self.field_name}(const StringRef &ref) {{",
f" this->{self.field_name}_ref_ = ref;",
"}",
]
)
return content
@property
def encode_content(self) -> str:
# Use the StringRef
if self.force:
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_, true);"
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_);"
def dump(self, name):
# If name is 'it', this is a repeated field element - always use string
if name == "it":
return "append_quoted_string(out, StringRef(it));"
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return f'out.append("\'").append(this->{self.field_name}).append("\'");'
# For SOURCE_SERVER, always use StringRef
if not self._needs_decode:
return f"append_quoted_string(out, this->{self.field_name}_ref_);"
# For SOURCE_BOTH, check if StringRef is set (sending) or use string (received)
return (
f"if (!this->{self.field_name}_ref_.empty()) {{"
f' out.append("\'").append(this->{self.field_name}_ref_.c_str(), this->{self.field_name}_ref_.size()).append("\'");'
f"}} else {{"
f' out.append("\'").append(this->{self.field_name}).append("\'");'
f"}}"
)
@property
def dump_content(self) -> str:
# For SOURCE_CLIENT only, use std::string
if not self._needs_encode:
return f'dump_field(out, "{self.name}", this->{self.field_name});'
# For SOURCE_SERVER, use StringRef with _ref_ suffix
if not self._needs_decode:
return f'dump_field(out, "{self.name}", this->{self.field_name}_ref_);'
# For SOURCE_BOTH, we need custom logic
o = f'out.append(" {self.name}: ");\n'
o += self.dump(f"this->{self.field_name}") + "\n"
o += 'out.append("\\n");'
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
# For SOURCE_CLIENT only messages, use the string field directly
if not self._needs_encode:
return self._get_simple_size_calculation(name, force, "length")
# Check if this is being called from a repeated field context
# In that case, 'name' will be 'it' and we need to use the repeated version
if name == "it":
# For repeated fields, we need to use length_force which includes field ID
field_id_size = self.calculate_field_id_size()
return f"size += ProtoSize::calc_length_force({field_id_size}, it.size());"
# For messages that need encoding, use the StringRef size
field_id_size = self.calculate_field_id_size()
return f"size += ProtoSize::calc_length({field_id_size}, this->{self.field_name}_ref_.size());"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
@register_type(11)
class MessageType(TypeInfo):
@classmethod
def can_use_dump_field(cls) -> bool:
return False
@property
def cpp_type(self) -> str:
return self._field.type_name[1:]
default_value = ""
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def reference_type(self) -> str:
return f"{self.cpp_type} &"
@property
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
def encode_func(self) -> str:
return "encode_optional_sub_message"
@property
def encode_content(self) -> str:
# encode_sub_message always encodes (uses backpatch), no force needed
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
@property
def decode_length(self) -> str:
# Override to return None for message types because we can't use template-based
# decoding when the specific message type isn't known at compile time.
# Instead, we use the non-template decode_to_message() method which allows
# runtime polymorphism through virtual function calls.
return None
@property
def decode_length_content(self) -> str:
# Custom decode that doesn't use templates
return f"case {self.number}: value.decode_to_message(this->{self.field_name}); break;"
def dump(self, name: str) -> str:
return f"{name}.dump_to(out);"
@property
def dump_content(self) -> str:
o = f'out.append(" {self.name}: ");\n'
o += f"this->{self.field_name}.dump_to(out);\n"
o += 'out.append("\\n");'
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
method = "calc_message_force" if force else "calc_message"
return f"size += ProtoSize::{method}({field_id_size}, {name}.calculate_size());"
def get_estimated_size(self) -> int:
# For message types, we can't easily estimate the submessage size without
# access to the actual message definition. This is just a rough estimate.
return (
self.calculate_field_id_size() + 16
) # field ID + 16 bytes estimated submessage
@register_type(12)
class BytesType(TypeInfo):
@classmethod
def can_use_dump_field(cls) -> bool:
return False
cpp_type = "std::string"
default_value = ""
reference_type = "std::string &"
const_reference_type = "const std::string &"
encode_func = "encode_bytes"
decode_length = "value.as_string()"
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
@property
def public_content(self) -> list[str]:
content: list[str] = []
# Add std::string storage if message needs decoding
if self._needs_decode:
content.append(f"std::string {self.field_name}{{}};")
if self._needs_encode:
content.extend(
[
# Add pointer/length fields if message needs encoding
f"const uint8_t* {self.field_name}_ptr_{{nullptr}};",
f"size_t {self.field_name}_len_{{0}};",
# Add setter method if message needs encoding
f"void set_{self.field_name}(const uint8_t* data, size_t len) {{",
f" this->{self.field_name}_ptr_ = data;",
f" this->{self.field_name}_len_ = len;",
"}",
]
)
return content
@property
def encode_content(self) -> str:
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
def dump(self, name: str) -> str:
ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)"
str_dump = f"format_hex_pretty(reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), this->{self.field_name}.size())"
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return f"out.append({str_dump});"
# For SOURCE_SERVER, always use pointer/length
if not self._needs_decode:
return f"out.append({ptr_dump});"
# For SOURCE_BOTH, check if pointer is set (sending) or use string (received)
return (
f"if (this->{self.field_name}_ptr_ != nullptr) {{\n"
f" out.append({ptr_dump});\n"
f" }} else {{\n"
f" out.append({str_dump});\n"
f" }}"
)
@property
def dump_content(self) -> str:
# For SOURCE_CLIENT only, always use std::string
if not self._needs_encode:
return (
f'dump_bytes_field(out, "{self.name}", '
f"reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), "
f"this->{self.field_name}.size());"
)
# For SOURCE_SERVER, always use pointer/length
if not self._needs_decode:
return (
f'dump_bytes_field(out, "{self.name}", '
f"this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
)
# For SOURCE_BOTH, check if pointer is set (sending) or use string (received)
return (
f"if (this->{self.field_name}_ptr_ != nullptr) {{\n"
f' dump_bytes_field(out, "{self.name}", '
f"this->{self.field_name}_ptr_, this->{self.field_name}_len_);\n"
f"}} else {{\n"
f' dump_bytes_field(out, "{self.name}", '
f"reinterpret_cast<const uint8_t*>(this->{self.field_name}.data()), "
f"this->{self.field_name}.size());\n"
f"}}"
)
def get_size_calculation(self, name: str, force: bool = False) -> str:
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}_len_);"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical bytes
class PointerToBufferTypeBase(TypeInfo):
"""Base class for pointer_to_buffer types (bytes and strings) for zero-copy decoding."""
@classmethod
def can_use_dump_field(cls) -> bool:
return False
def __init__(
self, field: descriptor.FieldDescriptorProto, size: int | None = None
) -> None:
super().__init__(field)
self.array_size = 0
@property
def decode_length(self) -> str | None:
# This is handled in decode_length_content
return None
@property
def wire_type(self) -> WireType:
"""Get the wire type for this field."""
return WireType.LENGTH_DELIMITED # Uses wire type 2
def get_estimated_size(self) -> int:
# field ID + length varint + typical data (assume small for pointer fields)
return self.calculate_field_id_size() + 2 + 16
class PointerToBytesBufferType(PointerToBufferTypeBase):
"""Type for bytes fields that use pointer_to_buffer option for zero-copy."""
cpp_type = "const uint8_t*"
default_value = "nullptr"
reference_type = "const uint8_t*"
const_reference_type = "const uint8_t*"
@property
def public_content(self) -> list[str]:
# Use uint16_t for length - max packet size is well below 65535
return [
f"const uint8_t* {self.field_name}{{nullptr}};",
f"uint16_t {self.field_name}_len{{0}};",
]
@property
def encode_content(self) -> str:
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
@property
def decode_length_content(self) -> str | None:
return f"""case {self.number}: {{
this->{self.field_name} = value.data();
this->{self.field_name}_len = value.size();
break;
}}"""
def dump(self, name: str) -> str:
return (
f"format_hex_pretty(this->{self.field_name}, this->{self.field_name}_len)"
)
@property
def dump_content(self) -> str:
return (
f'dump_bytes_field(out, "{self.name}", '
f"this->{self.field_name}, this->{self.field_name}_len);"
)
def get_size_calculation(self, name: str, force: bool = False) -> str:
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}_len);"
class PointerToStringBufferType(PointerToBufferTypeBase):
"""Type for string fields that use pointer_to_buffer option for zero-copy.
Uses StringRef instead of separate pointer and length fields.
"""
cpp_type = "StringRef"
default_value = ""
reference_type = "StringRef &"
const_reference_type = "const StringRef &"
@classmethod
def can_use_dump_field(cls) -> bool:
return True
@property
def public_content(self) -> list[str]:
return [f"StringRef {self.field_name}{{}};"]
@property
def encode_content(self) -> str:
if self.force:
return (
f"buffer.encode_string({self.number}, this->{self.field_name}, true);"
)
return f"buffer.encode_string({self.number}, this->{self.field_name});"
@property
def decode_length_content(self) -> str | None:
return f"""case {self.number}: {{
this->{self.field_name} = StringRef(reinterpret_cast<const char *>(value.data()), value.size());
break;
}}"""
def dump(self, name: str) -> str:
# Not used since we use dump_field, but required by abstract base class
return f'out.append("\'").append({name}.c_str(), {name}.size()).append("\'");'
@property
def dump_content(self) -> str:
return f'dump_field(out, "{self.name}", this->{self.field_name});'
def get_size_calculation(self, name: str, force: bool = False) -> str:
return f"size += ProtoSize::calc_length({self.calculate_field_id_size()}, this->{self.field_name}.size());"
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes typical string
class PackedBufferTypeInfo(TypeInfo):
"""Type for packed repeated fields that expose raw buffer instead of decoding.
When a repeated field is marked with [(packed_buffer) = true], this type
generates code that stores a pointer to the raw protobuf buffer along with
its length and the count of values. This enables zero-copy passthrough when
the consumer can decode the packed varints on-demand.
"""
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
# packed_buffer is decode-only (SOURCE_CLIENT messages)
super().__init__(field, needs_decode=True, needs_encode=False)
@property
def cpp_type(self) -> str:
# Not used - we have multiple fields
return "const uint8_t*"
@property
def wire_type(self) -> WireType:
"""Packed fields use LENGTH_DELIMITED wire type."""
return WireType.LENGTH_DELIMITED
@property
def public_content(self) -> list[str]:
"""Generate three fields: data pointer, length, and count."""
return [
f"const uint8_t *{self.field_name}_data_{{nullptr}};",
f"uint16_t {self.field_name}_length_{{0}};",
f"uint16_t {self.field_name}_count_{{0}};",
]
@property
def decode_length_content(self) -> str:
"""Store pointer to buffer and calculate count of packed varints."""
return f"""case {self.number}: {{
this->{self.field_name}_data_ = value.data();
this->{self.field_name}_length_ = value.size();
this->{self.field_name}_count_ = count_packed_varints(value.data(), value.size());
break;
}}"""
@property
def encode_content(self) -> str:
"""No encoding - this is decode-only for SOURCE_CLIENT messages."""
return None
@property
def dump_content(self) -> str:
"""Dump shows buffer info but not decoded values."""
return (
f'out.append(" {self.name}: ");\n'
+ 'out.append("packed buffer [");\n'
+ f"append_uint(out, this->{self.field_name}_count_);\n"
+ 'out.append(" values, ");\n'
+ f"append_uint(out, this->{self.field_name}_length_);\n"
+ 'out.append(" bytes]\\n");'
)
def dump(self, name: str) -> str:
"""Dump method for packed buffer - not typically used but required by abstract base."""
return 'out.append("packed buffer");'
def get_size_calculation(self, name: str, force: bool = False) -> str:
"""No size calculation needed - decode-only."""
return ""
def get_estimated_size(self) -> int:
"""Estimate size for packed buffer field.
Typical IR/RF timing array has ~50-200 values, each encoded as 1-3 bytes.
Estimate 100 values * 2 bytes = 200 bytes typical.
"""
return (
self.calculate_field_id_size() + 2 + 200
) # field ID + length varint + data
@classmethod
def can_use_dump_field(cls) -> bool:
return False
class FixedArrayBytesType(TypeInfo):
"""Special type for fixed-size byte arrays."""
@classmethod
def can_use_dump_field(cls) -> bool:
return False
def __init__(self, field: descriptor.FieldDescriptorProto, size: int) -> None:
super().__init__(field)
self.array_size = size
@property
def cpp_type(self) -> str:
return "uint8_t"
@property
def default_value(self) -> str:
return "{}"
@property
def reference_type(self) -> str:
return f"uint8_t (&)[{self.array_size}]"
@property
def const_reference_type(self) -> str:
return f"const uint8_t (&)[{self.array_size}]"
@property
def public_content(self) -> list[str]:
len_type = (
"uint8_t"
if self.array_size <= 255
else "uint16_t"
if self.array_size <= 65535
else "size_t"
)
# Add both the array and length fields
return [
f"uint8_t {self.field_name}[{self.array_size}]{{}};",
f"{len_type} {self.field_name}_len{{0}};",
]
@property
def decode_length_content(self) -> str:
o = f"case {self.number}: {{\n"
o += " const std::string &data_str = value.as_string();\n"
o += f" this->{self.field_name}_len = data_str.size();\n"
o += f" if (this->{self.field_name}_len > {self.array_size}) {{\n"
o += f" this->{self.field_name}_len = {self.array_size};\n"
o += " }\n"
o += f" memcpy(this->{self.field_name}, data_str.data(), this->{self.field_name}_len);\n"
o += " break;\n"
o += "}"
return o
@property
def encode_content(self) -> str:
if self.force:
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
def dump(self, name: str) -> str:
return f"out.append(format_hex_pretty({name}, {name}_len));"
@property
def dump_content(self) -> str:
return (
f'dump_bytes_field(out, "{self.name}", '
f"this->{self.field_name}, this->{self.field_name}_len);"
)
def get_size_calculation(self, name: str, force: bool = False) -> str:
# Use the actual length stored in the _len field
length_field = f"this->{self.field_name}_len"
field_id_size = self.calculate_field_id_size()
if force:
# For repeated fields, always calculate size (no zero check)
return f"size += ProtoSize::calc_length_force({field_id_size}, {length_field});"
# For non-repeated fields, length already checks for zero
return f"size += ProtoSize::calc_length({field_id_size}, {length_field});"
def get_estimated_size(self) -> int:
# Estimate based on typical BLE advertisement size
return (
self.calculate_field_id_size() + 1 + 31
) # field ID + length byte + typical 31 bytes
@property
def wire_type(self) -> WireType:
return WireType.LENGTH_DELIMITED
@register_type(13)
class UInt32Type(TypeInfo):
cpp_type = "uint32_t"
default_value = "0"
decode_varint = "value"
encode_func = "encode_uint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRIu32, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "uint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(14)
class EnumType(TypeInfo):
@property
def cpp_type(self) -> str:
return f"enums::{self._field.type_name[1:]}"
@property
def decode_varint(self) -> str:
return f"static_cast<{self.cpp_type}>(value)"
default_value = ""
wire_type = WireType.VARINT # Uses wire type 0
@property
def encode_func(self) -> str:
return "encode_uint32"
@property
def encode_content(self) -> str:
if self.force:
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}), true);"
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}));"
def dump(self, name: str) -> str:
return f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
def dump_field_value(self, value: str) -> str:
# Enums need explicit cast for the template
return f"static_cast<{self.cpp_type}>({value})"
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(
name, force, "uint32", f"static_cast<uint32_t>({name})"
)
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 1 # field ID + 1 byte typical enum
@register_type(15)
class SFixed32Type(TypeInfo):
cpp_type = "int32_t"
default_value = "0"
decode_32bit = "value.as_sfixed32()"
encode_func = "encode_sfixed32"
wire_type = WireType.FIXED32 # Uses wire type 5
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_sfixed32({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 4
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 4 # field ID + 4 bytes fixed
@register_type(16)
class SFixed64Type(TypeInfo):
cpp_type = "int64_t"
default_value = "0"
decode_64bit = "value.as_sfixed64()"
encode_func = "encode_sfixed64"
wire_type = WireType.FIXED64 # Uses wire type 1
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId64, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
field_id_size = self.calculate_field_id_size()
if force:
return f"size += {field_id_size + self.get_fixed_size_bytes()};"
return f"size += ProtoSize::calc_sfixed64({field_id_size}, {name});"
def get_fixed_size_bytes(self) -> int:
return 8
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 8 # field ID + 8 bytes fixed
@register_type(17)
class SInt32Type(TypeInfo):
cpp_type = "int32_t"
default_value = "0"
decode_varint = "decode_zigzag32(static_cast<uint32_t>(value))"
encode_func = "encode_sint32"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId32, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "sint32")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
@register_type(18)
class SInt64Type(TypeInfo):
cpp_type = "int64_t"
default_value = "0"
decode_varint = "decode_zigzag64(value)"
encode_func = "encode_sint64"
wire_type = WireType.VARINT # Uses wire type 0
def dump(self, name: str) -> str:
o = f'snprintf(buffer, sizeof(buffer), "%" PRId64, {name});\n'
o += "out.append(buffer);"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
return self._get_simple_size_calculation(name, force, "sint64")
def get_estimated_size(self) -> int:
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
def _generate_array_dump_content(
ti,
field_name: str,
name: str,
is_bool: bool = False,
is_const_char_ptr: bool = False,
) -> str:
"""Generate dump content for array types (repeated or fixed array).
Shared helper to avoid code duplication between RepeatedTypeInfo and FixedArrayRepeatedType.
"""
o = f"for (const auto {'' if is_bool else '&'}it : {field_name}) {{\n"
# Check if underlying type can use dump_field
if is_const_char_ptr:
# Special case for const char* - use it directly
o += f' dump_field(out, "{name}", it, 4);\n'
elif ti.can_use_dump_field():
# For types that have dump_field overloads, use them with extra indent
# std::vector<bool> iterators return proxy objects, need explicit cast
value_expr = "static_cast<bool>(it)" if is_bool else ti.dump_field_value("it")
o += f' dump_field(out, "{name}", {value_expr}, 4);\n'
else:
# For complex types (messages, bytes), use the old pattern
o += f' out.append(" {name}: ");\n'
o += indent(ti.dump("it")) + "\n"
o += ' out.append("\\n");\n'
o += "}"
return o
class FixedArrayRepeatedType(TypeInfo):
"""Special type for fixed-size repeated fields using std::array.
Fixed arrays are only supported for encoding (SOURCE_SERVER) since we cannot
control how many items we receive when decoding.
"""
def __init__(self, field: descriptor.FieldDescriptorProto, size: int | str) -> None:
super().__init__(field)
self.array_size = size
self.is_define = isinstance(size, str)
# Check if we should skip encoding when all elements are zero
# Use getattr to handle older versions of api_options_pb2
self.skip_zero = get_field_opt(
field, getattr(pb, "fixed_array_skip_zero", None), False
)
# Create the element type info
validate_field_type(field.type, field.name)
self._ti: TypeInfo = TYPE_INFO[field.type](field)
def _encode_element(self, element: str) -> str:
"""Helper to generate encode statement for a single element."""
if isinstance(self._ti, EnumType):
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
# Repeated message elements use encode_sub_message (force=true is default)
if isinstance(self._ti, MessageType):
return f"buffer.encode_sub_message({self.number}, {element});"
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
@property
def cpp_type(self) -> str:
return f"std::array<{self._ti.cpp_type}, {self.array_size}>"
@property
def reference_type(self) -> str:
return f"{self.cpp_type} &"
@property
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
def wire_type(self) -> WireType:
"""Get the wire type for this fixed array field."""
return self._ti.wire_type
@property
def public_content(self) -> list[str]:
# Just the array member, no index needed since we don't decode
return [f"{self.cpp_type} {self.field_name}{{}};"]
# No decode methods needed - fixed arrays don't support decoding
# The base class TypeInfo already returns None for all decode properties
@property
def encode_content(self) -> str:
# If skip_zero is enabled, wrap encoding in a zero check
if self.skip_zero:
if self.is_define:
# When using a define, we need to use a loop-based approach
o = f"for (const auto &it : this->{self.field_name}) {{\n"
o += " if (it != 0) {\n"
o += f" {self._encode_element('it')}\n"
o += " }\n"
o += "}"
return o
# Build the condition to check if at least one element is non-zero
non_zero_checks = " || ".join(
[f"this->{self.field_name}[{i}] != 0" for i in range(self.array_size)]
)
encode_lines = [
f" {self._encode_element(f'this->{self.field_name}[{i}]')}"
for i in range(self.array_size)
]
return f"if ({non_zero_checks}) {{\n" + "\n".join(encode_lines) + "\n}"
# When using a define, always use loop-based approach
if self.is_define:
o = f"for (const auto &it : this->{self.field_name}) {{\n"
o += f" {self._encode_element('it')}\n"
o += "}"
return o
# Unroll small arrays for efficiency
if self.array_size == 1:
return self._encode_element(f"this->{self.field_name}[0]")
if self.array_size == 2:
return (
self._encode_element(f"this->{self.field_name}[0]")
+ "\n "
+ self._encode_element(f"this->{self.field_name}[1]")
)
# Use loops for larger arrays
o = f"for (const auto &it : this->{self.field_name}) {{\n"
o += f" {self._encode_element('it')}\n"
o += "}"
return o
@property
def dump_content(self) -> str:
return _generate_array_dump_content(
self._ti, f"this->{self.field_name}", self.name, is_bool=False
)
def dump(self, name: str) -> str:
# This is used when dumping the array itself (not its elements)
# Since dump_content handles the iteration, this is not used directly
return ""
def get_size_calculation(self, name: str, force: bool = False) -> str:
# If skip_zero is enabled, wrap size calculation in a zero check
if self.skip_zero:
if self.is_define:
# When using a define, we need to use a loop-based approach
o = f"for (const auto &it : {name}) {{\n"
o += " if (it != 0) {\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += " }\n"
o += "}"
return o
# Build the condition to check if at least one element is non-zero
non_zero_checks = " || ".join(
[f"{name}[{i}] != 0" for i in range(self.array_size)]
)
size_lines = [
f" {self._ti.get_size_calculation(f'{name}[{i}]', True)}"
for i in range(self.array_size)
]
return f"if ({non_zero_checks}) {{\n" + "\n".join(size_lines) + "\n}"
# When using a define, always use loop-based approach
if self.is_define:
o = f"for (const auto &it : {name}) {{\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += "}"
return o
# For fixed arrays, we always encode all elements
# Special case for single-element arrays - no loop needed
if self.array_size == 1:
return self._ti.get_size_calculation(f"{name}[0]", True)
# Special case for 2-element arrays - unroll the calculation
if self.array_size == 2:
return (
self._ti.get_size_calculation(f"{name}[0]", True)
+ "\n "
+ self._ti.get_size_calculation(f"{name}[1]", True)
)
# Use loops for larger arrays
o = f"for (const auto &it : {name}) {{\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += "}"
return o
def get_estimated_size(self) -> int:
# For fixed arrays, estimate underlying type size * array size
underlying_size = self._ti.get_estimated_size()
if self.is_define:
# When using a define, we don't know the actual size so just guess 3
# This is only used for documentation and never actually used since
# fixed arrays are only for SOURCE_SERVER (encode-only) messages
return underlying_size * 3
return underlying_size * self.array_size
class FixedArrayWithLengthRepeatedType(FixedArrayRepeatedType):
"""Special type for fixed-size repeated fields with variable length tracking.
Similar to FixedArrayRepeatedType but generates an additional length field
to track how many elements are actually in use. Only encodes/sends elements
up to the current length.
Fixed arrays with length are only supported for encoding (SOURCE_SERVER) since
we cannot control how many items we receive when decoding.
"""
@property
def public_content(self) -> list[str]:
# Return both the array and the length field
return [
f"{self.cpp_type} {self.field_name}{{}};",
f"uint16_t {self.field_name}_len{{0}};",
]
@property
def encode_content(self) -> str:
# Always use a loop up to the current length
o = f"for (uint16_t i = 0; i < this->{self.field_name}_len; i++) {{\n"
o += f" {self._encode_element(f'this->{self.field_name}[i]')}\n"
o += "}"
return o
@property
def dump_content(self) -> str:
# Dump only the active elements
o = f"for (uint16_t i = 0; i < this->{self.field_name}_len; i++) {{\n"
# Check if underlying type can use dump_field
if self._ti.can_use_dump_field():
o += f' dump_field(out, "{self.name}", {self._ti.dump_field_value(f"this->{self.field_name}[i]")}, 4);\n'
else:
o += f' out.append(" {self.name}: ");\n'
o += indent(self._ti.dump(f"this->{self.field_name}[i]")) + "\n"
o += ' out.append("\\n");\n'
o += "}"
return o
def get_size_calculation(self, name: str, force: bool = False) -> str:
# Calculate size only for active elements
o = f"for (uint16_t i = 0; i < {name}_len; i++) {{\n"
o += f" {self._ti.get_size_calculation(f'{name}[i]', True)}\n"
o += "}"
return o
def get_estimated_size(self) -> int:
# For fixed arrays with length, estimate based on typical usage
# Assume on average half the array is used
underlying_size = self._ti.get_estimated_size()
if self.is_define:
# When using a define, estimate 8 elements as typical
return underlying_size * 8
return underlying_size * (
self.array_size // 2 if self.array_size > 2 else self.array_size
)
class RepeatedTypeInfo(TypeInfo):
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
super().__init__(field)
# Check if this is a pointer field by looking for container_pointer option
self._container_type = get_field_opt(field, pb.container_pointer, "")
# Check for non-template container pointer
self._container_no_template = get_field_opt(
field, pb.container_pointer_no_template, ""
)
self._use_pointer = bool(self._container_type) or bool(
self._container_no_template
)
# Check if this should use FixedVector instead of std::vector
self._use_fixed_vector = get_field_opt(field, pb.fixed_vector, False)
# For repeated fields, we need to get the base type info
# but we can't call create_field_type_info as it would cause recursion
# So we extract just the type creation logic
if (
field.type == 12
and (fixed_size := get_field_opt(field, pb.fixed_array_size)) is not None
):
self._ti: TypeInfo = FixedArrayBytesType(field, fixed_size)
return
validate_field_type(field.type, field.name)
self._ti: TypeInfo = TYPE_INFO[field.type](field)
@property
def cpp_type(self) -> str:
if self._container_no_template:
# Non-template container: use type as-is without appending template parameters
return f"const {self._container_no_template}*"
if self._use_pointer and self._container_type:
# For pointer fields, use the specified container type
# Two cases:
# 1. "std::set<climate::ClimateMode>" - Full type with template params, use as-is
# 2. "std::set" - No <>, append the element type
if "<" in self._container_type and ">" in self._container_type:
# Has template parameters specified, use as-is
return f"const {self._container_type}*"
# No <> at all, append element type
return f"const {self._container_type}<{self._ti.cpp_type}>*"
if self._use_fixed_vector:
return f"FixedVector<{self._ti.cpp_type}>"
return f"std::vector<{self._ti.cpp_type}>"
@property
def reference_type(self) -> str:
return f"{self.cpp_type} &"
@property
def const_reference_type(self) -> str:
return f"const {self.cpp_type} &"
@property
def wire_type(self) -> WireType:
"""Get the wire type for this repeated field.
For repeated fields, we use the same wire type as the underlying field.
"""
return self._ti.wire_type
@property
def decode_varint_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_varint
if content is None:
return None
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
def decode_length_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_length
if content is None and isinstance(self._ti, MessageType):
# Special handling for non-template message decoding
return f"case {self.number}: this->{self.field_name}.emplace_back(); value.decode_to_message(this->{self.field_name}.back()); break;"
if content is None:
return None
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
def decode_32bit_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_32bit
if content is None:
return None
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
def decode_64bit_content(self) -> str:
# Pointer fields don't support decoding
if self._use_pointer:
return None
content = self._ti.decode_64bit
if content is None:
return None
return (
f"case {self.number}: this->{self.field_name}.push_back({content}); break;"
)
@property
def _ti_is_bool(self) -> bool:
# std::vector is specialized for bool, reference does not work
return isinstance(self._ti, BoolType)
def _encode_element_call(self, element: str) -> str:
"""Helper to generate encode call for a single element."""
if isinstance(self._ti, EnumType):
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
# Repeated message elements use encode_sub_message (force=true is default)
if isinstance(self._ti, MessageType):
return f"buffer.encode_sub_message({self.number}, {element});"
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
@property
def encode_content(self) -> str:
if self._use_pointer:
# For pointer fields, just dereference (pointer should never be null in our use case)
# Special handling for const char* elements (when container_no_template contains "const char")
if "const char" in self._container_no_template:
o = f"for (const char *it : *this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, strlen(it), true);\n"
else:
o = f"for (const auto &it : *this->{self.field_name}) {{\n"
o += f" {self._encode_element_call('it')}\n"
o += "}"
return o
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" {self._encode_element_call('it')}\n"
o += "}"
return o
@property
def dump_content(self) -> str:
# Check if this is const char* elements
is_const_char_ptr = (
self._use_pointer and "const char" in self._container_no_template
)
if self._use_pointer:
# For pointer fields, dereference and use the existing helper
return _generate_array_dump_content(
self._ti,
f"*this->{self.field_name}",
self.name,
is_bool=False,
is_const_char_ptr=is_const_char_ptr,
)
return _generate_array_dump_content(
self._ti, f"this->{self.field_name}", self.name, is_bool=self._ti_is_bool
)
def dump(self, _: str):
pass
def get_size_calculation(self, name: str, force: bool = False) -> str:
# For repeated fields, we always need to pass force=True to the underlying type's calculation
# This is because the encode method always sets force=true for repeated fields
# Handle message types separately - generate inline loop
if isinstance(self._ti, MessageType):
field_id_size = self._ti.calculate_field_id_size()
container_ref = f"*{name}" if self._use_pointer else name
empty_check = f"{name}->empty()" if self._use_pointer else f"{name}.empty()"
o = f"if (!{empty_check}) {{\n"
o += f" for (const auto &it : {container_ref}) {{\n"
o += f" size += ProtoSize::calc_message_force({field_id_size}, it.calculate_size());\n"
o += " }\n"
o += "}"
return o
# For non-message types, generate size calculation with iteration
container_ref = f"*{name}" if self._use_pointer else name
empty_check = f"{name}->empty()" if self._use_pointer else f"{name}.empty()"
o = f"if (!{empty_check}) {{\n"
# Check if this is a fixed-size type
num_bytes = self._ti.get_fixed_size_bytes()
if num_bytes is not None:
# Fixed types have constant size per element
field_id_size = self._ti.calculate_field_id_size()
bytes_per_element = field_id_size + num_bytes
size_expr = f"{name}->size()" if self._use_pointer else f"{name}.size()"
o += f" size += {size_expr} * {bytes_per_element};\n"
else:
# Other types need the actual value
# Special handling for const char* elements
if self._use_pointer and "const char" in self._container_no_template:
field_id_size = self.calculate_field_id_size()
o += f" for (const char *it : {container_ref}) {{\n"
o += f" size += ProtoSize::calc_length_force({field_id_size}, strlen(it));\n"
else:
auto_ref = "" if self._ti_is_bool else "&"
o += f" for (const auto {auto_ref}it : {container_ref}) {{\n"
o += f" {self._ti.get_size_calculation('it', True)}\n"
o += " }\n"
o += "}"
return o
def get_estimated_size(self) -> int:
# For repeated fields, estimate underlying type size * 2 (assume 2 items typically)
underlying_size = (
self._ti.get_estimated_size()
if hasattr(self._ti, "get_estimated_size")
else 8
)
return underlying_size * 2
def build_type_usage_map(
file_desc: descriptor.FileDescriptorProto,
) -> tuple[dict[str, str | None], dict[str, str | None], dict[str, int], set[str]]:
"""Build mappings for both enums and messages to their ifdefs based on usage.
Returns:
tuple: (enum_ifdef_map, message_ifdef_map, message_source_map, used_messages)
"""
enum_ifdef_map: dict[str, str | None] = {}
message_ifdef_map: dict[str, str | None] = {}
message_source_map: dict[str, int] = {}
# Build maps of which types are used by which messages
enum_usage: dict[
str, set[str]
] = {} # enum_name -> set of message names that use it
message_usage: dict[
str, set[str]
] = {} # message_name -> set of message names that use it
used_messages: set[str] = set() # Track which messages are actually used
# Build message name to ifdef mapping for quick lookup
message_to_ifdef: dict[str, str | None] = {
msg.name: get_opt(msg, pb.ifdef) for msg in file_desc.message_type
}
# Analyze field usage
# Also track field_ifdef for message types
message_field_ifdefs: dict[
str, set[str | None]
] = {} # message_name -> set of field_ifdefs that use it
for message in file_desc.message_type:
# Skip deprecated messages entirely
if message.options.deprecated:
continue
for field in message.field:
# Skip deprecated fields when tracking enum usage
if field.options.deprecated:
continue
type_name = field.type_name.split(".")[-1] if field.type_name else None
if not type_name:
continue
# Track enum usage (only from non-deprecated fields)
if field.type == 14: # TYPE_ENUM
enum_usage.setdefault(type_name, set()).add(message.name)
# Track message usage
elif field.type == 11: # TYPE_MESSAGE
message_usage.setdefault(type_name, set()).add(message.name)
# Also track the field_ifdef if present
field_ifdef = get_field_opt(field, pb.field_ifdef)
message_field_ifdefs.setdefault(type_name, set()).add(field_ifdef)
used_messages.add(type_name)
# Helper to get unique ifdef from a set of messages
def get_unique_ifdef(message_names: set[str]) -> str | None:
ifdefs: set[str] = {
message_to_ifdef[name]
for name in message_names
if message_to_ifdef.get(name)
}
return ifdefs.pop() if len(ifdefs) == 1 else None
# Build enum ifdef map
for enum in file_desc.enum_type:
if enum.name in enum_usage:
enum_ifdef_map[enum.name] = get_unique_ifdef(enum_usage[enum.name])
else:
enum_ifdef_map[enum.name] = None
# Build message ifdef map
for message in file_desc.message_type:
# Explicit ifdef takes precedence
explicit_ifdef = message_to_ifdef.get(message.name)
if explicit_ifdef:
message_ifdef_map[message.name] = explicit_ifdef
elif message.name in message_usage:
# Inherit ifdef if all parent messages have the same one
if parent_ifdef := get_unique_ifdef(message_usage[message.name]):
message_ifdef_map[message.name] = parent_ifdef
elif message.name in message_field_ifdefs:
# If no parent message ifdef, check if all fields using this message have the same field_ifdef
field_ifdefs = message_field_ifdefs[message.name] - {None}
message_ifdef_map[message.name] = (
field_ifdefs.pop() if len(field_ifdefs) == 1 else None
)
else:
message_ifdef_map[message.name] = None
else:
message_ifdef_map[message.name] = None
# Second pass: propagate ifdefs recursively
# Keep iterating until no more changes are made
changed = True
iterations = 0
while changed and iterations < 10: # Add safety limit
changed = False
iterations += 1
for message in file_desc.message_type:
# Skip if already has an ifdef
if message_ifdef_map.get(message.name):
continue
# Check if this message is used by other messages
if message.name not in message_usage:
continue
# Get ifdefs from all messages that use this one
parent_ifdefs: set[str] = {
message_ifdef_map.get(parent)
for parent in message_usage[message.name]
if message_ifdef_map.get(parent)
}
# If all parents have the same ifdef, inherit it
if len(parent_ifdefs) == 1 and None not in parent_ifdefs:
message_ifdef_map[message.name] = parent_ifdefs.pop()
changed = True
# Build message source map
# First pass: Get explicit sources for messages with source option or id
for msg in file_desc.message_type:
# Skip deprecated messages
if msg.options.deprecated:
continue
if msg.options.HasExtension(pb.source):
# Explicit source option takes precedence
message_source_map[msg.name] = get_opt(msg, pb.source, SOURCE_BOTH)
elif msg.options.HasExtension(pb.id):
# Service messages (with id) default to SOURCE_BOTH
message_source_map[msg.name] = SOURCE_BOTH
# Service messages are always used
used_messages.add(msg.name)
# Second pass: Determine sources for embedded messages based on their usage
for msg in file_desc.message_type:
if msg.name in message_source_map:
continue # Already has explicit source
if msg.name in message_usage:
# Get sources from all parent messages that use this one
parent_sources = {
message_source_map[parent]
for parent in message_usage[msg.name]
if parent in message_source_map
}
# Combine parent sources
if not parent_sources:
# No parent has explicit source, default to encode-only
message_source_map[msg.name] = SOURCE_SERVER
elif len(parent_sources) > 1:
# Multiple different sources or SOURCE_BOTH present
message_source_map[msg.name] = SOURCE_BOTH
else:
# Inherit single parent source
message_source_map[msg.name] = parent_sources.pop()
else:
# Not used by any message and no explicit source - default to encode-only
message_source_map[msg.name] = SOURCE_SERVER
return (
enum_ifdef_map,
message_ifdef_map,
message_source_map,
used_messages,
)
def get_varint64_ifdef(
file_desc: descriptor.FileDescriptorProto,
message_ifdef_map: dict[str, str | None],
) -> tuple[bool, str | None]:
"""Check if 64-bit varint fields exist and get their common ifdef guard.
Returns:
(has_varint64, ifdef_guard) - has_varint64 is True if any fields exist,
ifdef_guard is the common guard or None if unconditional.
"""
varint64_types = {
FieldDescriptorProto.TYPE_INT64,
FieldDescriptorProto.TYPE_UINT64,
FieldDescriptorProto.TYPE_SINT64,
}
ifdefs: set[str | None] = {
message_ifdef_map.get(msg.name)
for msg in file_desc.message_type
if not msg.options.deprecated
for field in msg.field
if not field.options.deprecated and field.type in varint64_types
}
if not ifdefs:
return False, None
if None in ifdefs:
# At least one 64-bit varint field is unconditional, so the guard must be unconditional.
return True, None
ifdefs.discard(None)
return True, ifdefs.pop() if len(ifdefs) == 1 else None
def build_enum_type(desc, enum_ifdef_map) -> tuple[str, str, str]:
"""Builds the enum type.
Args:
desc: The enum descriptor
enum_ifdef_map: Mapping of enum names to their ifdefs
Returns:
tuple: (header_content, cpp_content, dump_cpp_content)
"""
name = desc.name
out = f"enum {name} : uint32_t {{\n"
for v in desc.value:
out += f" {v.name} = {v.number},\n"
out += "};\n"
# Regular cpp file has no enum content anymore
cpp = ""
# Dump cpp content for enum string conversion
dump_cpp = f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n"
dump_cpp += " switch (value) {\n"
for v in desc.value:
dump_cpp += f" case enums::{v.name}:\n"
dump_cpp += f' return "{v.name}";\n'
dump_cpp += " default:\n"
dump_cpp += ' return "UNKNOWN";\n'
dump_cpp += " }\n"
dump_cpp += "}\n"
return out, cpp, dump_cpp
def calculate_message_estimated_size(desc: descriptor.DescriptorProto) -> int:
"""Calculate estimated size for a complete message based on typical values."""
total_size = 0
for field in desc.field:
# Skip deprecated fields
if field.options.deprecated:
continue
ti = create_field_type_info(field)
# Add estimated size for this field
total_size += ti.get_estimated_size()
return total_size
def build_message_type(
desc: descriptor.DescriptorProto,
base_class_fields: dict[str, list[descriptor.FieldDescriptorProto]],
message_source_map: dict[str, int],
) -> tuple[str, str, str]:
public_content: list[str] = []
protected_content: list[str] = []
decode_varint: list[str] = []
decode_length: list[str] = []
decode_32bit: list[str] = []
decode_64bit: list[str] = []
encode: list[str] = []
dump: list[str] = []
size_calc: list[str] = []
# Check if this message has a base class
base_class = get_base_class(desc)
common_field_names = set()
if base_class and base_class_fields and base_class in base_class_fields:
common_field_names = {f.name for f in base_class_fields[base_class]}
# Get message ID if it's a service message
message_id: int | None = get_opt(desc, pb.id)
# Get source direction to determine if we need decode/encode methods
source = message_source_map[desc.name]
needs_decode = source in (SOURCE_BOTH, SOURCE_CLIENT)
needs_encode = source in (SOURCE_BOTH, SOURCE_SERVER)
# Add MESSAGE_TYPE method if this is a service message
if message_id is not None:
# Validate that message_id fits in uint8_t
if message_id > 255:
raise ValueError(
f"Message ID {message_id} for {desc.name} exceeds uint8_t maximum (255)"
)
# Add static constexpr for message type
public_content.append(f"static constexpr uint8_t MESSAGE_TYPE = {message_id};")
# Add estimated size constant
estimated_size = calculate_message_estimated_size(desc)
# Use a type appropriate for estimated_size
estimated_size_type = (
"uint8_t"
if estimated_size <= 255
else "uint16_t"
if estimated_size <= 65535
else "size_t"
)
public_content.append(
f"static constexpr {estimated_size_type} ESTIMATED_SIZE = {estimated_size};"
)
# Add message_name method inline in header
public_content.append("#ifdef HAS_PROTO_MESSAGE_DUMP")
snake_name = camel_to_snake(desc.name)
public_content.append(
f'const char *message_name() const override {{ return "{snake_name}"; }}'
)
public_content.append("#endif")
# Collect fixed_vector fields for custom decode generation
fixed_vector_fields = []
for field in desc.field:
# Skip deprecated fields completely
if field.options.deprecated:
continue
# Validate that fixed_array_size is only used in encode-only messages
if (
needs_decode
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_array_size) is not None
):
raise ValueError(
f"Message '{desc.name}' uses fixed_array_size on field '{field.name}' "
f"but has source={SOURCE_NAMES[source]}. "
f"Fixed arrays are only supported for SOURCE_SERVER (encode-only) messages "
f"since we cannot trust or control the number of items received from clients."
)
# Validate that fixed_array_with_length_define is only used in encode-only messages
if (
needs_decode
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_array_with_length_define) is not None
):
raise ValueError(
f"Message '{desc.name}' uses fixed_array_with_length_define on field '{field.name}' "
f"but has source={SOURCE_NAMES[source]}. "
f"Fixed arrays with length are only supported for SOURCE_SERVER (encode-only) messages "
f"since we cannot trust or control the number of items received from clients."
)
# Collect fixed_vector repeated fields for custom decode generation
if (
needs_decode
and field.label == FieldDescriptorProto.LABEL_REPEATED
and get_field_opt(field, pb.fixed_vector, False)
):
fixed_vector_fields.append((field.name, field.number))
ti = create_field_type_info(field, needs_decode, needs_encode)
# Skip field declarations for fields that are in the base class
# but include their encode/decode logic
if field.name not in common_field_names:
# Check for field_ifdef option
field_ifdef = None
if field.options.HasExtension(pb.field_ifdef):
field_ifdef = field.options.Extensions[pb.field_ifdef]
if ti.protected_content:
protected_content.extend(
wrap_with_ifdef(ti.protected_content, field_ifdef)
)
if ti.public_content:
public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef))
# Only collect encode logic if this message needs it
if needs_encode:
# Check for field_ifdef option
field_ifdef = None
if field.options.HasExtension(pb.field_ifdef):
field_ifdef = field.options.Extensions[pb.field_ifdef]
encode.extend(wrap_with_ifdef(ti.encode_content, field_ifdef))
size_calc.extend(
wrap_with_ifdef(
ti.get_size_calculation(f"this->{ti.field_name}", ti.force),
field_ifdef,
)
)
# Only collect decode methods if this message needs them
if needs_decode:
# Check for field_ifdef option for decode as well
field_ifdef = None
if field.options.HasExtension(pb.field_ifdef):
field_ifdef = field.options.Extensions[pb.field_ifdef]
if ti.decode_varint_content:
decode_varint.extend(
wrap_with_ifdef(ti.decode_varint_content, field_ifdef)
)
if ti.decode_length_content:
decode_length.extend(
wrap_with_ifdef(ti.decode_length_content, field_ifdef)
)
if ti.decode_32bit_content:
decode_32bit.extend(
wrap_with_ifdef(ti.decode_32bit_content, field_ifdef)
)
if ti.decode_64bit_content:
decode_64bit.extend(
wrap_with_ifdef(ti.decode_64bit_content, field_ifdef)
)
if ti.dump_content:
# Check for field_ifdef option for dump as well
field_ifdef = None
if field.options.HasExtension(pb.field_ifdef):
field_ifdef = field.options.Extensions[pb.field_ifdef]
dump.extend(wrap_with_ifdef(ti.dump_content, field_ifdef))
cpp = ""
if decode_varint:
o = f"bool {desc.name}::decode_varint(uint32_t field_id, proto_varint_value_t value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_varint), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_varint(uint32_t field_id, proto_varint_value_t value) override;"
protected_content.insert(0, prot)
if decode_length:
o = f"bool {desc.name}::decode_length(uint32_t field_id, ProtoLengthDelimited value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_length), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_length(uint32_t field_id, ProtoLengthDelimited value) override;"
protected_content.insert(0, prot)
if decode_32bit:
o = f"bool {desc.name}::decode_32bit(uint32_t field_id, Proto32Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_32bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_32bit(uint32_t field_id, Proto32Bit value) override;"
protected_content.insert(0, prot)
if decode_64bit:
o = f"bool {desc.name}::decode_64bit(uint32_t field_id, Proto64Bit value) {{\n"
o += " switch (field_id) {\n"
o += indent("\n".join(decode_64bit), " ") + "\n"
o += " default: return false;\n"
o += " }\n"
o += " return true;\n"
o += "}\n"
cpp += o
prot = "bool decode_64bit(uint32_t field_id, Proto64Bit value) override;"
protected_content.insert(0, prot)
# Generate custom decode() override for messages with FixedVector fields
if fixed_vector_fields:
# Generate the decode() implementation in cpp
o = f"void {desc.name}::decode(const uint8_t *buffer, size_t length) {{\n"
# Count and init each FixedVector field
for field_name, field_number in fixed_vector_fields:
o += f" uint32_t count_{field_name} = ProtoDecodableMessage::count_repeated_field(buffer, length, {field_number});\n"
o += f" this->{field_name}.init(count_{field_name});\n"
# Call parent decode to populate the fields
o += " ProtoDecodableMessage::decode(buffer, length);\n"
o += "}\n"
cpp += o
# Generate the decode() declaration in header (public method)
prot = "void decode(const uint8_t *buffer, size_t length);"
public_content.append(prot)
# Only generate encode method if this message needs encoding and has fields
if needs_encode and encode:
o = f"void {desc.name}::encode(ProtoWriteBuffer &buffer) const {{"
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
o += f" {encode[0]} }}\n"
else:
o += "\n"
o += indent("\n".join(encode)) + "\n"
o += "}\n"
cpp += o
prot = "void encode(ProtoWriteBuffer &buffer) const;"
public_content.append(prot)
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
# Add calculate_size method only if this message needs encoding and has fields
if needs_encode and size_calc:
o = f"uint32_t {desc.name}::calculate_size() const {{\n"
o += " uint32_t size = 0;\n"
o += indent("\n".join(size_calc)) + "\n"
o += " return size;\n"
o += "}\n"
cpp += o
prot = "uint32_t calculate_size() const;"
public_content.append(prot)
# If no fields to calculate size for or message doesn't need encoding, the default implementation in ProtoMessage will be used
# dump_to method declaration in header
prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
prot += "const char *dump_to(DumpBuffer &out) const override;\n"
prot += "#endif\n"
public_content.append(prot)
# dump_to implementation will go in dump_cpp
dump_impl = f"const char *{desc.name}::dump_to(DumpBuffer &out) const {{"
if dump:
# Always use MessageDumpHelper for consistent output formatting
dump_impl += "\n"
dump_impl += f' MessageDumpHelper helper(out, "{desc.name}");\n'
dump_impl += indent("\n".join(dump)) + "\n"
dump_impl += " return out.c_str();\n"
else:
dump_impl += "\n"
dump_impl += f' out.append("{desc.name} {{}}");\n'
dump_impl += " return out.c_str();\n"
dump_impl += "}\n"
if base_class:
out = f"class {desc.name} final : public {base_class} {{\n"
else:
# Check if message has any non-deprecated fields
has_fields = any(not field.options.deprecated for field in desc.field)
# Determine inheritance based on whether the message needs decoding and has fields
if needs_decode and has_fields:
base_class = "ProtoDecodableMessage"
else:
base_class = "ProtoMessage"
out = f"class {desc.name} final : public {base_class} {{\n"
out += " public:\n"
out += indent("\n".join(public_content)) + "\n"
out += "\n"
out += " protected:\n"
out += indent("\n".join(protected_content))
if len(protected_content) > 0:
out += "\n"
out += "};\n"
# Build dump_cpp content with dump_to implementation
dump_cpp = dump_impl
return out, cpp, dump_cpp
SOURCE_BOTH = 0
SOURCE_SERVER = 1
SOURCE_CLIENT = 2
SOURCE_NAMES = {
SOURCE_BOTH: "SOURCE_BOTH",
SOURCE_SERVER: "SOURCE_SERVER",
SOURCE_CLIENT: "SOURCE_CLIENT",
}
RECEIVE_CASES: dict[int, tuple[str, str | None, str]] = {}
ifdefs: dict[str, str] = {}
# Track messages with no fields (empty messages) for parameter elision
EMPTY_MESSAGES: set[str] = set()
# Track empty SOURCE_CLIENT messages that don't need class generation
# These messages have no fields and are only received (never sent), so the
# class definition (vtable, dump_to, message_name, ESTIMATED_SIZE) is dead code
# that the compiler compiles but the linker strips away.
SKIP_CLASS_GENERATION: set[str] = set()
def get_opt(
desc: descriptor.DescriptorProto,
opt: descriptor.MessageOptions,
default: Any = None,
) -> Any:
"""Get the option from the descriptor."""
if not desc.options.HasExtension(opt):
return default
return desc.options.Extensions[opt]
def get_field_opt(
field: descriptor.FieldDescriptorProto,
opt: descriptor.FieldOptions,
default: Any = None,
) -> Any:
"""Get the option from a field descriptor."""
if not field.options.HasExtension(opt):
return default
return field.options.Extensions[opt]
def get_base_class(desc: descriptor.DescriptorProto) -> str | None:
"""Get the base_class option from a message descriptor."""
if not desc.options.HasExtension(pb.base_class):
return None
return desc.options.Extensions[pb.base_class]
def collect_messages_by_base_class(
messages: list[descriptor.DescriptorProto],
) -> dict[str, list[descriptor.DescriptorProto]]:
"""Group messages by their base_class option."""
base_class_groups = {}
for msg in messages:
base_class = get_base_class(msg)
if base_class:
if base_class not in base_class_groups:
base_class_groups[base_class] = []
base_class_groups[base_class].append(msg)
return base_class_groups
def find_common_fields(
messages: list[descriptor.DescriptorProto],
) -> list[descriptor.FieldDescriptorProto]:
"""Find fields that are common to all messages in the list."""
if not messages:
return []
# Start with fields from the first message (excluding deprecated fields)
first_msg_fields = {
field.name: field for field in messages[0].field if not field.options.deprecated
}
common_fields = []
# Check each field to see if it exists in all messages with same type
# Field numbers can vary between messages - derived classes handle the mapping
for field_name, field in first_msg_fields.items():
is_common = True
for msg in messages[1:]:
found = False
for other_field in msg.field:
# Skip deprecated fields
if other_field.options.deprecated:
continue
if (
other_field.name == field_name
and other_field.type == field.type
and other_field.label == field.label
):
found = True
break
if not found:
is_common = False
break
if is_common:
common_fields.append(field)
# Sort by field number to maintain order
common_fields.sort(key=lambda f: f.number)
return common_fields
def get_common_field_ifdef(
field_name: str, messages: list[descriptor.DescriptorProto]
) -> str | None:
"""Get the field_ifdef option if it's consistent across all messages.
Args:
field_name: Name of the field to check
messages: List of messages that contain this field
Returns:
The field_ifdef string if all messages have the same value, None otherwise
"""
field_ifdefs = {
get_field_opt(field, pb.field_ifdef)
for msg in messages
if (field := next((f for f in msg.field if f.name == field_name), None))
}
# Return the ifdef only if all messages agree on the same value
return field_ifdefs.pop() if len(field_ifdefs) == 1 else None
def build_base_class(
base_class_name: str,
common_fields: list[descriptor.FieldDescriptorProto],
messages: list[descriptor.DescriptorProto],
message_source_map: dict[str, int],
) -> tuple[str, str, str]:
"""Build the base class definition and implementation."""
public_content = []
protected_content = []
# Determine if any message using this base class needs decoding/encoding
needs_decode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_CLIENT)
for msg in messages
)
needs_encode = any(
message_source_map.get(msg.name, SOURCE_BOTH) in (SOURCE_BOTH, SOURCE_SERVER)
for msg in messages
)
# For base classes, we only declare the fields but don't handle encode/decode
# The derived classes will handle encoding/decoding with their specific field numbers
for field in common_fields:
ti = create_field_type_info(field, needs_decode, needs_encode)
# Get field_ifdef if it's consistent across all messages
field_ifdef = get_common_field_ifdef(field.name, messages)
# Only add field declarations, not encode/decode logic
if ti.protected_content:
protected_content.extend(wrap_with_ifdef(ti.protected_content, field_ifdef))
if ti.public_content:
public_content.extend(wrap_with_ifdef(ti.public_content, field_ifdef))
# Build header
parent_class = "ProtoDecodableMessage" if needs_decode else "ProtoMessage"
out = f"class {base_class_name} : public {parent_class} {{\n"
out += " public:\n"
# Base classes don't implement encode/decode/calculate_size
# Derived classes handle these with their specific field numbers
cpp = ""
out += indent("\n".join(public_content)) + "\n"
out += "\n"
out += " protected:\n"
# Non-virtual protected destructor prevents accidental polymorphic deletion
protected_content.insert(0, f"~{base_class_name}() = default;")
out += indent("\n".join(protected_content))
if protected_content:
out += "\n"
out += "};\n"
# No implementation needed for base classes
dump_cpp = ""
return out, cpp, dump_cpp
def generate_base_classes(
base_class_groups: dict[str, list[descriptor.DescriptorProto]],
message_source_map: dict[str, int],
) -> tuple[str, str, str]:
"""Generate all base classes."""
all_headers = []
all_cpp = []
all_dump_cpp = []
for base_class_name, messages in base_class_groups.items():
# Find common fields
common_fields = find_common_fields(messages)
if common_fields:
# Generate base class
header, cpp, dump_cpp = build_base_class(
base_class_name, common_fields, messages, message_source_map
)
all_headers.append(header)
all_cpp.append(cpp)
all_dump_cpp.append(dump_cpp)
return "\n".join(all_headers), "\n".join(all_cpp), "\n".join(all_dump_cpp)
def build_service_message_type(
mt: descriptor.DescriptorProto,
message_source_map: dict[str, int],
) -> tuple[str, str] | None:
"""Builds the service message type."""
# Skip deprecated messages
if mt.options.deprecated:
return None
snake = camel_to_snake(mt.name)
id_: int | None = get_opt(mt, pb.id)
if id_ is None:
return None
source: int = message_source_map.get(mt.name, SOURCE_BOTH)
ifdef: str | None = get_opt(mt, pb.ifdef)
log: bool = get_opt(mt, pb.log, True)
hout = ""
cout = ""
# Store ifdef for later use
if ifdef is not None:
ifdefs[str(mt.name)] = ifdef
if source in (SOURCE_BOTH, SOURCE_SERVER):
# Don't generate individual send methods anymore
# The generic send_message method will be used instead
pass
if source in (SOURCE_BOTH, SOURCE_CLIENT):
# Only add ifdef when we're actually generating content
if ifdef is not None:
hout += f"#ifdef {ifdef}\n"
# Generate receive handler and switch case
func = f"on_{snake}"
has_fields = any(not field.options.deprecated for field in mt.field)
is_empty = not has_fields
if is_empty:
EMPTY_MESSAGES.add(mt.name)
hout += f"void {func}({'' if is_empty else f'const {mt.name} &value'}){{}};\n"
case = ""
if not is_empty:
case += f"{mt.name} msg;\n"
case += "msg.decode(msg_data, msg_size);\n"
if log:
case += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
if is_empty:
case += f'this->log_receive_message_(LOG_STR("{func}"));\n'
else:
case += f'this->log_receive_message_(LOG_STR("{func}"), msg);\n'
case += "#endif\n"
case += f"this->{func}({'msg' if not is_empty else ''});\n"
case += "break;"
if mt.name in SKIP_CLASS_GENERATION:
case_label = f"{id_} /* {mt.name} is empty */"
else:
case_label = f"{mt.name}::MESSAGE_TYPE"
RECEIVE_CASES[id_] = (case, ifdef, case_label)
# Only close ifdef if we opened it
if ifdef is not None:
hout += "#endif\n"
return hout, cout
def main() -> None:
"""Main function to generate the C++ classes."""
cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api"
prot_file = root / "api.protoc"
call(["protoc", "-o", str(prot_file), "-I", str(root), "api.proto"])
proto_content = prot_file.read_bytes()
# pylint: disable-next=no-member
d = descriptor.FileDescriptorSet.FromString(proto_content)
file = d.file[0]
# Build dynamic ifdef mappings early so we can emit USE_API_VARINT64 before includes
enum_ifdef_map, message_ifdef_map, message_source_map, used_messages = (
build_type_usage_map(file)
)
# Find the ifdef guard for 64-bit varint fields (int64/uint64/sint64).
# Generated into api_pb2_defines.h so proto.h can include it, ensuring
# consistent ProtoVarInt layout across all translation units.
has_varint64, varint64_guard = get_varint64_ifdef(file, message_ifdef_map)
# Generate api_pb2_defines.h — included by proto.h to ensure all translation
# units see USE_API_VARINT64 consistently (avoids ODR violations in ProtoVarInt).
defines_content = FILE_HEADER
defines_content += "#pragma once\n\n"
defines_content += '#include "esphome/core/defines.h"\n'
if has_varint64:
lines = [
"#ifndef USE_API_VARINT64",
"#define USE_API_VARINT64",
"#endif",
]
defines_content += "\n".join(wrap_with_ifdef(lines, varint64_guard))
defines_content += "\n"
defines_content += "\nnamespace esphome::api {} // namespace esphome::api\n"
with open(root / "api_pb2_defines.h", "w", encoding="utf-8") as f:
f.write(defines_content)
content = FILE_HEADER
content += """\
#pragma once
#include "esphome/core/string_ref.h"
#include "proto.h"
#include "api_pb2_includes.h"
"""
content += """
namespace esphome::api {
"""
cpp = FILE_HEADER
cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"
#include "esphome/core/helpers.h"
#include <cstring>
namespace esphome::api {
"""
# Initialize dump cpp content
dump_cpp = FILE_HEADER
dump_cpp += """\
#include "api_pb2.h"
#include "esphome/core/helpers.h"
#include <cinttypes>
#ifdef HAS_PROTO_MESSAGE_DUMP
namespace esphome::api {
// Helper function to append a quoted string, handling empty StringRef
static inline void append_quoted_string(DumpBuffer &out, const StringRef &ref) {
out.append("'");
if (!ref.empty()) {
out.append(ref.c_str(), ref.size());
}
out.append("'");
}
// Common helpers for dump_field functions
static inline void append_field_prefix(DumpBuffer &out, const char *field_name, int indent) {
out.append(indent, ' ').append(field_name).append(": ");
}
static inline void append_uint(DumpBuffer &out, uint32_t value) {
out.set_pos(buf_append_printf(out.data(), DumpBuffer::CAPACITY, out.pos(), "%" PRIu32, value));
}
// RAII helper for message dump formatting
class MessageDumpHelper {
public:
MessageDumpHelper(DumpBuffer &out, const char *message_name) : out_(out) {
out_.append(message_name);
out_.append(" {\\n");
}
~MessageDumpHelper() { out_.append(" }"); }
private:
DumpBuffer &out_;
};
// Helper functions to reduce code duplication in dump methods
static void dump_field(DumpBuffer &out, const char *field_name, int32_t value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.set_pos(buf_append_printf(out.data(), DumpBuffer::CAPACITY, out.pos(), "%" PRId32 "\\n", value));
}
static void dump_field(DumpBuffer &out, const char *field_name, uint32_t value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.set_pos(buf_append_printf(out.data(), DumpBuffer::CAPACITY, out.pos(), "%" PRIu32 "\\n", value));
}
static void dump_field(DumpBuffer &out, const char *field_name, float value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.set_pos(buf_append_printf(out.data(), DumpBuffer::CAPACITY, out.pos(), "%g\\n", value));
}
static void dump_field(DumpBuffer &out, const char *field_name, uint64_t value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.set_pos(buf_append_printf(out.data(), DumpBuffer::CAPACITY, out.pos(), "%" PRIu64 "\\n", value));
}
static void dump_field(DumpBuffer &out, const char *field_name, bool value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.append(YESNO(value));
out.append("\\n");
}
static void dump_field(DumpBuffer &out, const char *field_name, const std::string &value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.append("'").append(value.c_str()).append("'");
out.append("\\n");
}
static void dump_field(DumpBuffer &out, const char *field_name, StringRef value, int indent = 2) {
append_field_prefix(out, field_name, indent);
append_quoted_string(out, value);
out.append("\\n");
}
static void dump_field(DumpBuffer &out, const char *field_name, const char *value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.append("'").append(value).append("'");
out.append("\\n");
}
template<typename T>
static void dump_field(DumpBuffer &out, const char *field_name, T value, int indent = 2) {
append_field_prefix(out, field_name, indent);
out.append(proto_enum_to_string<T>(value));
out.append("\\n");
}
// Helper for bytes fields - uses stack buffer to avoid heap allocation
// Buffer sized for 160 bytes of data (480 chars with separators) to fit typical log buffer
static void dump_bytes_field(DumpBuffer &out, const char *field_name, const uint8_t *data, size_t len, int indent = 2) {
char hex_buf[format_hex_pretty_size(160)];
append_field_prefix(out, field_name, indent);
format_hex_pretty_to(hex_buf, data, len);
out.append(hex_buf).append("\\n");
}
"""
content += "namespace enums {\n\n"
# Simple grouping of enums by ifdef
current_ifdef = None
for enum in file.enum_type:
# Skip deprecated enums
if enum.options.deprecated:
continue
s, c, dc = build_enum_type(enum, enum_ifdef_map)
enum_ifdef = enum_ifdef_map.get(enum.name)
# Handle ifdef changes
if enum_ifdef != current_ifdef:
if current_ifdef is not None:
content += "#endif\n"
dump_cpp += "#endif\n"
if enum_ifdef is not None:
content += f"#ifdef {enum_ifdef}\n"
dump_cpp += f"#ifdef {enum_ifdef}\n"
current_ifdef = enum_ifdef
content += s
cpp += c
dump_cpp += dc
# Close last ifdef
if current_ifdef is not None:
content += "#endif\n"
dump_cpp += "#endif\n"
content += "\n} // namespace enums\n\n"
mt = file.message_type
# Identify empty SOURCE_CLIENT messages that don't need class generation
for m in mt:
if m.options.deprecated:
continue
if not m.options.HasExtension(pb.id):
continue
source = message_source_map.get(m.name)
if source != SOURCE_CLIENT:
continue
has_fields = any(not field.options.deprecated for field in m.field)
if not has_fields:
SKIP_CLASS_GENERATION.add(m.name)
# Collect messages by base class
base_class_groups = collect_messages_by_base_class(mt)
# Find common fields for each base class
base_class_fields = {}
for base_class_name, messages in base_class_groups.items():
common_fields = find_common_fields(messages)
if common_fields:
base_class_fields[base_class_name] = common_fields
# Generate base classes
if base_class_fields:
base_headers, base_cpp, base_dump_cpp = generate_base_classes(
base_class_groups, message_source_map
)
content += base_headers
cpp += base_cpp
dump_cpp += base_dump_cpp
# Generate message types with base class information
# Simple grouping by ifdef
current_ifdef = None
for m in mt:
# Skip deprecated messages
if m.options.deprecated:
continue
# Skip messages that aren't used (unless they have an ID/service message)
if m.name not in used_messages and not m.options.HasExtension(pb.id):
continue
# Skip class generation for empty SOURCE_CLIENT messages
if m.name in SKIP_CLASS_GENERATION:
continue
s, c, dc = build_message_type(m, base_class_fields, message_source_map)
msg_ifdef = message_ifdef_map.get(m.name)
# Handle ifdef changes
if msg_ifdef != current_ifdef:
if current_ifdef is not None:
content += "#endif\n"
if cpp:
cpp += "#endif\n"
if dump_cpp:
dump_cpp += "#endif\n"
if msg_ifdef is not None:
content += f"#ifdef {msg_ifdef}\n"
cpp += f"#ifdef {msg_ifdef}\n"
dump_cpp += f"#ifdef {msg_ifdef}\n"
current_ifdef = msg_ifdef
content += s
cpp += c
dump_cpp += dc
# Close last ifdef
if current_ifdef is not None:
content += "#endif\n"
cpp += "#endif\n"
dump_cpp += "#endif\n"
content += """\
} // namespace esphome::api
"""
cpp += """\
} // namespace esphome::api
"""
dump_cpp += """\
} // namespace esphome::api
#endif // HAS_PROTO_MESSAGE_DUMP
"""
with open(root / "api_pb2.h", "w", encoding="utf-8") as f:
f.write(content)
with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
with open(root / "api_pb2_dump.cpp", "w", encoding="utf-8") as f:
f.write(dump_cpp)
hpp = FILE_HEADER
hpp += """\
#pragma once
#include "esphome/core/defines.h"
#include "api_pb2.h"
namespace esphome::api {
"""
cpp = FILE_HEADER
cpp += """\
#include "api_pb2_service.h"
#ifdef USE_API
#include "api_connection.h"
#endif
#include "esphome/core/log.h"
namespace esphome::api {
static const char *const TAG = "api.service";
"""
class_name = "APIServerConnectionBase"
hpp += f"class {class_name} {{\n"
hpp += " public:\n"
# Add logging helper method declarations
hpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
hpp += " protected:\n"
hpp += " void log_send_message_(const char *name, const char *dump);\n"
hpp += (
" void log_receive_message_(const LogString *name, const ProtoMessage &msg);\n"
)
hpp += " void log_receive_message_(const LogString *name);\n"
hpp += " public:\n"
hpp += "#endif\n\n"
# send_message is now a template on APIConnection directly
# No non-template send_message method needed here
# Add logging helper method implementations to cpp
cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += (
f"void {class_name}::log_send_message_(const char *name, const char *dump) {{\n"
)
cpp += ' ESP_LOGVV(TAG, "send_message %s: %s", name, dump);\n'
cpp += "}\n"
cpp += f"void {class_name}::log_receive_message_(const LogString *name, const ProtoMessage &msg) {{\n"
cpp += " DumpBuffer dump_buf;\n"
cpp += ' ESP_LOGVV(TAG, "%s: %s", LOG_STR_ARG(name), msg.dump_to(dump_buf));\n'
cpp += "}\n"
cpp += f"void {class_name}::log_receive_message_(const LogString *name) {{\n"
cpp += ' ESP_LOGVV(TAG, "%s: {}", LOG_STR_ARG(name));\n'
cpp += "}\n"
cpp += "#endif\n\n"
for mt in file.message_type:
obj = build_service_message_type(mt, message_source_map)
if obj is None:
continue
hout, cout = obj
hpp += indent(hout) + "\n"
cpp += cout
cases = list(RECEIVE_CASES.items())
cases.sort()
serv = file.service[0]
# Build a mapping of message input types to their authentication requirements
message_auth_map: dict[str, bool] = {}
message_conn_map: dict[str, bool] = {}
for m in serv.method:
inp = m.input_type[1:]
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
# Store authentication requirements for message types
message_auth_map[inp] = needs_auth
message_conn_map[inp] = needs_conn
# Categorize messages by their authentication requirements
no_conn_ids: set[int] = set()
conn_only_ids: set[int] = set()
# Build a reverse lookup from message id to message name for auth lookups
id_to_msg_name: dict[int, str] = {}
for mt in file.message_type:
id_ = get_opt(mt, pb.id)
if id_ is not None and not mt.options.deprecated:
id_to_msg_name[id_] = mt.name
for id_, (_, _, case_label) in cases:
msg_name = id_to_msg_name.get(id_, "")
if msg_name in message_auth_map:
needs_auth = message_auth_map[msg_name]
needs_conn = message_conn_map[msg_name]
if not needs_conn:
no_conn_ids.add(id_)
elif not needs_auth:
conn_only_ids.add(id_)
# Helper to generate case statements with ifdefs
def generate_cases(ids: set[int], comment: str) -> str:
result = ""
for id_ in sorted(ids):
_, ifdef, case_label = RECEIVE_CASES[id_]
if ifdef:
result += f"#ifdef {ifdef}\n"
result += f" case {case_label}: {comment}\n"
if ifdef:
result += "#endif\n"
return result
# Generate read_message_ as APIConnection method (not base class) so the compiler
# can devirtualize and inline the on_* handler calls within the same class.
# APIConnection declares this method in api_connection.h.
out = "#ifdef USE_API\n"
out += "void APIConnection::read_message_(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {\n"
# Auth check block before dispatch switch
out += " // Check authentication/connection requirements\n"
if no_conn_ids or conn_only_ids:
out += " switch (msg_type) {\n"
if no_conn_ids:
out += generate_cases(no_conn_ids, "// No setup required")
out += " break;\n"
if conn_only_ids:
out += generate_cases(conn_only_ids, "// Connection setup only")
out += " if (!this->check_connection_setup_()) {\n"
out += " return;\n"
out += " }\n"
out += " break;\n"
out += " default:\n"
out += " if (!this->check_authenticated_()) {\n"
out += " return;\n"
out += " }\n"
out += " break;\n"
out += " }\n"
else:
out += " if (!this->check_authenticated_()) {\n"
out += " return;\n"
out += " }\n"
# Dispatch switch
out += " switch (msg_type) {\n"
for i, (case, ifdef, case_label) in cases:
if ifdef is not None:
out += f"#ifdef {ifdef}\n"
c = f" case {case_label}: {{\n"
c += indent(case, " ") + "\n"
c += " }"
out += c + "\n"
if ifdef is not None:
out += "#endif\n"
out += " default:\n"
out += " break;\n"
out += " }\n"
out += "}\n"
out += "#endif // USE_API\n"
cpp += out
hpp += "};\n"
hpp += """\
} // namespace esphome::api
"""
cpp += """\
} // namespace esphome::api
"""
with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f:
f.write(hpp)
with open(root / "api_pb2_service.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
prot_file.unlink()
try:
import clang_format
def exec_clang_format(path: Path) -> None:
clang_format_path = (
Path(clang_format.__file__).parent / "data" / "bin" / "clang-format"
)
call([clang_format_path, "-i", path])
exec_clang_format(root / "api_pb2_service.h")
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
exec_clang_format(root / "api_pb2_dump.cpp")
except ImportError:
pass
if __name__ == "__main__":
sys.exit(main())