mirror of
https://github.com/esphome/esphome.git
synced 2026-05-21 17:39:00 +08:00
[api] Speed up protobuf encode 17-20% with register-optimized write path (#15290)
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
@@ -232,29 +232,31 @@ class TypeInfo(ABC):
|
||||
# eliminating the zero-check branch and encode_field_raw indirection.
|
||||
# {value} is replaced with the actual field expression.
|
||||
RAW_ENCODE_MAP: dict[str, str] = {
|
||||
"encode_uint32": "buffer.encode_varint_raw({value});",
|
||||
"encode_uint64": "buffer.encode_varint_raw_64({value});",
|
||||
"encode_sint32": "buffer.encode_varint_raw(encode_zigzag32({value}));",
|
||||
"encode_sint64": "buffer.encode_varint_raw_64(encode_zigzag64({value}));",
|
||||
"encode_int64": "buffer.encode_varint_raw_64(static_cast<uint64_t>({value}));",
|
||||
"encode_bool": "buffer.write_raw_byte({value} ? 0x01 : 0x00);",
|
||||
"encode_uint32": "ProtoEncode::encode_varint_raw(pos, {value});",
|
||||
"encode_uint64": "ProtoEncode::encode_varint_raw_64(pos, {value});",
|
||||
"encode_sint32": "ProtoEncode::encode_varint_raw_short(pos, encode_zigzag32({value}));",
|
||||
"encode_sint64": "ProtoEncode::encode_varint_raw_64(pos, encode_zigzag64({value}));",
|
||||
"encode_int64": "ProtoEncode::encode_varint_raw_64(pos, static_cast<uint64_t>({value}));",
|
||||
"encode_bool": "ProtoEncode::write_raw_byte(pos, {value} ? 0x01 : 0x00);",
|
||||
}
|
||||
|
||||
# When max_value < 128, the varint is always 1 byte — use a direct byte write
|
||||
RAW_ENCODE_SMALL_MAP: dict[str, str] = {
|
||||
"encode_uint32": "buffer.write_raw_byte(static_cast<uint8_t>({value}));",
|
||||
"encode_uint64": "buffer.write_raw_byte(static_cast<uint8_t>({value}));",
|
||||
"encode_uint32": "ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({value}));",
|
||||
"encode_uint64": "ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({value}));",
|
||||
}
|
||||
|
||||
def _encode_with_precomputed_tag(self, value_expr: str) -> str | None:
|
||||
"""Try to emit a precomputed-tag encode for a forced field.
|
||||
"""Try to emit a precomputed-tag encode for a field.
|
||||
|
||||
For forced fields: emits raw tag + value unconditionally.
|
||||
For non-forced fields with single-byte tag: emits inline zero-check
|
||||
+ raw tag + value, avoiding an outlined function call.
|
||||
|
||||
Returns the raw encode string if the tag is a single byte and the
|
||||
encode_func has a known raw equivalent, or None otherwise.
|
||||
When max_value < 128, uses direct byte write instead of varint encoding.
|
||||
"""
|
||||
if not self.force:
|
||||
return None
|
||||
tag = self.calculate_tag()
|
||||
if tag >= 128:
|
||||
return None
|
||||
@@ -263,10 +265,17 @@ class TypeInfo(ABC):
|
||||
if max_val is not None and max_val < 128:
|
||||
raw_expr = self.RAW_ENCODE_SMALL_MAP.get(self.encode_func)
|
||||
if raw_expr is None:
|
||||
# Only use RAW_ENCODE_MAP for forced fields or fields with max_value
|
||||
if not self.force and max_val is None:
|
||||
return None
|
||||
raw_expr = self.RAW_ENCODE_MAP.get(self.encode_func)
|
||||
if raw_expr is None:
|
||||
return None
|
||||
return f"buffer.write_raw_byte({tag});\n{raw_expr.format(value=value_expr)}"
|
||||
body = f"ProtoEncode::write_raw_byte(pos, {tag});\n{raw_expr.format(value=value_expr)}"
|
||||
if self.force:
|
||||
return body
|
||||
# Non-forced with max_value: inline zero-check + raw encode
|
||||
return f"if ({value_expr}) {{\n {body}\n}}"
|
||||
|
||||
def _encode_bytes_with_precomputed_tag(
|
||||
self, data_expr: str, len_expr: str, max_len: int | None = None
|
||||
@@ -283,14 +292,14 @@ class TypeInfo(ABC):
|
||||
return None
|
||||
# When max_len < 128, length varint is always 1 byte
|
||||
len_encode = (
|
||||
f"buffer.write_raw_byte(static_cast<uint8_t>({len_expr}));"
|
||||
f"ProtoEncode::write_raw_byte(pos, static_cast<uint8_t>({len_expr}));"
|
||||
if max_len is not None and max_len < 128
|
||||
else f"buffer.encode_varint_raw({len_expr});"
|
||||
else f"ProtoEncode::encode_varint_raw(pos, {len_expr});"
|
||||
)
|
||||
return (
|
||||
f"buffer.write_raw_byte({tag});\n"
|
||||
f"ProtoEncode::write_raw_byte(pos, {tag});\n"
|
||||
f"{len_encode}\n"
|
||||
f"buffer.encode_raw({data_expr}, {len_expr});"
|
||||
f"ProtoEncode::encode_raw(pos, {data_expr}, {len_expr});"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -298,8 +307,8 @@ class TypeInfo(ABC):
|
||||
if result := self._encode_with_precomputed_tag(f"this->{self.field_name}"):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, true);"
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name}, true);"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name});"
|
||||
|
||||
encode_func = None
|
||||
|
||||
@@ -657,10 +666,10 @@ class Fixed32Type(TypeInfo):
|
||||
tag = self.calculate_tag()
|
||||
if self.force and tag < 128:
|
||||
# Emit combined tag+value write: precomputed tag + direct memcpy
|
||||
return f"buffer.write_tag_and_fixed32({tag}, this->{self.field_name});"
|
||||
return f"ProtoEncode::write_tag_and_fixed32(pos, {tag}, this->{self.field_name});"
|
||||
if self.force:
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name}, true);"
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name}, true);"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, this->{self.field_name});"
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
@@ -734,8 +743,8 @@ class StringType(TypeInfo):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_, true);"
|
||||
return f"buffer.encode_string({self.number}, this->{self.field_name}_ref_);"
|
||||
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}_ref_, true);"
|
||||
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}_ref_);"
|
||||
|
||||
def dump(self, name):
|
||||
# If name is 'it', this is a repeated field element - always use string
|
||||
@@ -822,8 +831,8 @@ class MessageType(TypeInfo):
|
||||
|
||||
@property
|
||||
def encode_content(self) -> str:
|
||||
# encode_sub_message always encodes (uses backpatch), no force needed
|
||||
return f"buffer.{self.encode_func}({self.number}, this->{self.field_name});"
|
||||
# Sub-message encoding needs buffer for backpatch/sync
|
||||
return f"ProtoEncode::{self.encode_func}(pos, buffer, {self.number}, this->{self.field_name});"
|
||||
|
||||
@property
|
||||
def decode_length(self) -> str:
|
||||
@@ -904,8 +913,8 @@ class BytesType(TypeInfo):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_, true);"
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_, true);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}_ptr_, this->{self.field_name}_len_);"
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
ptr_dump = f"format_hex_pretty(this->{self.field_name}_ptr_, this->{self.field_name}_len_)"
|
||||
@@ -1015,8 +1024,8 @@ class PointerToBytesBufferType(PointerToBufferTypeBase):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len);"
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str | None:
|
||||
@@ -1068,10 +1077,10 @@ class PointerToStringBufferType(PointerToBufferTypeBase):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return (
|
||||
f"buffer.encode_string({self.number}, this->{self.field_name}, true);"
|
||||
)
|
||||
return f"buffer.encode_string({self.number}, this->{self.field_name});"
|
||||
return f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name}, true);"
|
||||
return (
|
||||
f"ProtoEncode::encode_string(pos, {self.number}, this->{self.field_name});"
|
||||
)
|
||||
|
||||
@property
|
||||
def decode_length_content(self) -> str | None:
|
||||
@@ -1240,8 +1249,8 @@ class FixedArrayBytesType(TypeInfo):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
|
||||
return f"buffer.encode_bytes({self.number}, this->{self.field_name}, this->{self.field_name}_len);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len, true);"
|
||||
return f"ProtoEncode::encode_bytes(pos, {self.number}, this->{self.field_name}, this->{self.field_name}_len);"
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
return f"out.append(format_hex_pretty({name}, {name}_len));"
|
||||
@@ -1323,8 +1332,8 @@ class EnumType(TypeInfo):
|
||||
):
|
||||
return result
|
||||
if self.force:
|
||||
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}), true);"
|
||||
return f"buffer.{self.encode_func}({self.number}, static_cast<uint32_t>(this->{self.field_name}));"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, static_cast<uint32_t>(this->{self.field_name}), true);"
|
||||
return f"ProtoEncode::{self.encode_func}(pos, {self.number}, static_cast<uint32_t>(this->{self.field_name}));"
|
||||
|
||||
def dump(self, name: str) -> str:
|
||||
return f"out.append_p(proto_enum_to_string<{self.cpp_type}>({name}));"
|
||||
@@ -1487,11 +1496,13 @@ class FixedArrayRepeatedType(TypeInfo):
|
||||
def _encode_element(self, element: str) -> str:
|
||||
"""Helper to generate encode statement for a single element."""
|
||||
if isinstance(self._ti, EnumType):
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
|
||||
return f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, static_cast<uint32_t>({element}), true);"
|
||||
# Repeated message elements use encode_sub_message (force=true is default)
|
||||
if isinstance(self._ti, MessageType):
|
||||
return f"buffer.encode_sub_message({self.number}, {element});"
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
|
||||
return f"ProtoEncode::encode_sub_message(pos, buffer, {self.number}, {element});"
|
||||
return (
|
||||
f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, {element}, true);"
|
||||
)
|
||||
|
||||
@property
|
||||
def cpp_type(self) -> str:
|
||||
@@ -1815,11 +1826,13 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
def _encode_element_call(self, element: str) -> str:
|
||||
"""Helper to generate encode call for a single element."""
|
||||
if isinstance(self._ti, EnumType):
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, static_cast<uint32_t>({element}), true);"
|
||||
return f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, static_cast<uint32_t>({element}), true);"
|
||||
# Repeated message elements use encode_sub_message (force=true is default)
|
||||
if isinstance(self._ti, MessageType):
|
||||
return f"buffer.encode_sub_message({self.number}, {element});"
|
||||
return f"buffer.{self._ti.encode_func}({self.number}, {element}, true);"
|
||||
return f"ProtoEncode::encode_sub_message(pos, buffer, {self.number}, {element});"
|
||||
return (
|
||||
f"ProtoEncode::{self._ti.encode_func}(pos, {self.number}, {element}, true);"
|
||||
)
|
||||
|
||||
@property
|
||||
def encode_content(self) -> str:
|
||||
@@ -1828,7 +1841,7 @@ class RepeatedTypeInfo(TypeInfo):
|
||||
# Special handling for const char* elements (when container_no_template contains "const char")
|
||||
if "const char" in self._container_no_template:
|
||||
o = f"for (const char *it : *this->{self.field_name}) {{\n"
|
||||
o += f" buffer.{self._ti.encode_func}({self.number}, it, strlen(it), true);\n"
|
||||
o += f" ProtoEncode::{self._ti.encode_func}(pos, {self.number}, it, strlen(it), true);\n"
|
||||
else:
|
||||
o = f"for (const auto &it : *this->{self.field_name}) {{\n"
|
||||
o += f" {self._encode_element_call('it')}\n"
|
||||
@@ -2403,15 +2416,19 @@ def build_message_type(
|
||||
|
||||
# Only generate encode method if this message needs encoding and has fields
|
||||
if needs_encode and encode:
|
||||
o = f"void {desc.name}::encode(ProtoWriteBuffer &buffer) const {{"
|
||||
if len(encode) == 1 and len(encode[0]) + len(o) + 3 < 120:
|
||||
o += f" {encode[0]} }}\n"
|
||||
else:
|
||||
o += "\n"
|
||||
o += indent("\n".join(encode)) + "\n"
|
||||
o += "}\n"
|
||||
# Add PROTO_ENCODE_DEBUG_ARG after pos in all proto_* calls
|
||||
encode_debug = [
|
||||
line.replace("(pos,", "(pos PROTO_ENCODE_DEBUG_ARG,") for line in encode
|
||||
]
|
||||
o = f"uint8_t *{desc.name}::encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const {{\n"
|
||||
o += " uint8_t *__restrict__ pos = buffer.get_pos();\n"
|
||||
o += indent("\n".join(encode_debug)) + "\n"
|
||||
o += " return pos;\n"
|
||||
o += "}\n"
|
||||
cpp += o
|
||||
prot = "void encode(ProtoWriteBuffer &buffer) const;"
|
||||
prot = (
|
||||
"uint8_t *encode(ProtoWriteBuffer &buffer PROTO_ENCODE_DEBUG_PARAM) const;"
|
||||
)
|
||||
public_content.append(prot)
|
||||
# If no fields to encode or message doesn't need encoding, the default implementation in ProtoMessage will be used
|
||||
|
||||
|
||||
Reference in New Issue
Block a user