mirror of
https://github.com/esphome/esphome.git
synced 2026-06-02 03:02:19 +08:00
Reserve buffer space to avoid frequent realloc when generating protobuf messages (#8707)
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
|||||||
// This file was automatically generated with a tool.
|
// This file was automatically generated with a tool.
|
||||||
// See scripts/api_protobuf/api_protobuf.py
|
// See script/api_protobuf/api_protobuf.py
|
||||||
#include "api_pb2_service.h"
|
#include "api_pb2_service.h"
|
||||||
#include "esphome/core/log.h"
|
#include "esphome/core/log.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// This file was automatically generated with a tool.
|
// This file was automatically generated with a tool.
|
||||||
// See scripts/api_protobuf/api_protobuf.py
|
// See script/api_protobuf/api_protobuf.py
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "api_pb2.h"
|
#include "api_pb2.h"
|
||||||
|
|||||||
@@ -0,0 +1,361 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "proto.h"
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace esphome {
|
||||||
|
namespace api {
|
||||||
|
|
||||||
|
class ProtoSize {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief ProtoSize class for Protocol Buffer serialization size calculation
|
||||||
|
*
|
||||||
|
* This class provides static methods to calculate the exact byte counts needed
|
||||||
|
* for encoding various Protocol Buffer field types. All methods are designed to be
|
||||||
|
* efficient for the common case where many fields have default values.
|
||||||
|
*
|
||||||
|
* Implements Protocol Buffer encoding size calculation according to:
|
||||||
|
* https://protobuf.dev/programming-guides/encoding/
|
||||||
|
*
|
||||||
|
* Key features:
|
||||||
|
* - Early-return optimization for zero/default values
|
||||||
|
* - Direct total_size updates to avoid unnecessary additions
|
||||||
|
* - Specialized handling for different field types according to protobuf spec
|
||||||
|
* - Templated helpers for repeated fields and messages
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the size in bytes needed to encode a uint32_t value as a varint
|
||||||
|
*
|
||||||
|
* @param value The uint32_t value to calculate size for
|
||||||
|
* @return The number of bytes needed to encode the value
|
||||||
|
*/
|
||||||
|
static inline uint32_t varint(uint32_t value) {
|
||||||
|
// Optimized varint size calculation using leading zeros
|
||||||
|
// Each 7 bits requires one byte in the varint encoding
|
||||||
|
if (value < 128)
|
||||||
|
return 1; // 7 bits, common case for small values
|
||||||
|
|
||||||
|
// For larger values, count bytes needed based on the position of the highest bit set
|
||||||
|
if (value < 16384) {
|
||||||
|
return 2; // 14 bits
|
||||||
|
} else if (value < 2097152) {
|
||||||
|
return 3; // 21 bits
|
||||||
|
} else if (value < 268435456) {
|
||||||
|
return 4; // 28 bits
|
||||||
|
} else {
|
||||||
|
return 5; // 32 bits (maximum for uint32_t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the size in bytes needed to encode a uint64_t value as a varint
|
||||||
|
*
|
||||||
|
* @param value The uint64_t value to calculate size for
|
||||||
|
* @return The number of bytes needed to encode the value
|
||||||
|
*/
|
||||||
|
static inline uint32_t varint(uint64_t value) {
|
||||||
|
// Handle common case of values fitting in uint32_t (vast majority of use cases)
|
||||||
|
if (value <= UINT32_MAX) {
|
||||||
|
return varint(static_cast<uint32_t>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
// For larger values, determine size based on highest bit position
|
||||||
|
if (value < (1ULL << 35)) {
|
||||||
|
return 5; // 35 bits
|
||||||
|
} else if (value < (1ULL << 42)) {
|
||||||
|
return 6; // 42 bits
|
||||||
|
} else if (value < (1ULL << 49)) {
|
||||||
|
return 7; // 49 bits
|
||||||
|
} else if (value < (1ULL << 56)) {
|
||||||
|
return 8; // 56 bits
|
||||||
|
} else if (value < (1ULL << 63)) {
|
||||||
|
return 9; // 63 bits
|
||||||
|
} else {
|
||||||
|
return 10; // 64 bits (maximum for uint64_t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the size in bytes needed to encode an int32_t value as a varint
|
||||||
|
*
|
||||||
|
* Special handling is needed for negative values, which are sign-extended to 64 bits
|
||||||
|
* in Protocol Buffers, resulting in a 10-byte varint.
|
||||||
|
*
|
||||||
|
* @param value The int32_t value to calculate size for
|
||||||
|
* @return The number of bytes needed to encode the value
|
||||||
|
*/
|
||||||
|
static inline uint32_t varint(int32_t value) {
|
||||||
|
// Negative values are sign-extended to 64 bits in protocol buffers,
|
||||||
|
// which always results in a 10-byte varint for negative int32
|
||||||
|
if (value < 0) {
|
||||||
|
return 10; // Negative int32 is always 10 bytes long
|
||||||
|
}
|
||||||
|
// For non-negative values, use the uint32_t implementation
|
||||||
|
return varint(static_cast<uint32_t>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the size in bytes needed to encode an int64_t value as a varint
|
||||||
|
*
|
||||||
|
* @param value The int64_t value to calculate size for
|
||||||
|
* @return The number of bytes needed to encode the value
|
||||||
|
*/
|
||||||
|
static inline uint32_t varint(int64_t value) {
|
||||||
|
// For int64_t, we convert to uint64_t and calculate the size
|
||||||
|
// This works because the bit pattern determines the encoding size,
|
||||||
|
// and we've handled negative int32 values as a special case above
|
||||||
|
return varint(static_cast<uint64_t>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates the size in bytes needed to encode a field ID and wire type
|
||||||
|
*
|
||||||
|
* @param field_id The field identifier
|
||||||
|
* @param type The wire type value (from the WireType enum in the protobuf spec)
|
||||||
|
* @return The number of bytes needed to encode the field ID and wire type
|
||||||
|
*/
|
||||||
|
static inline uint32_t field(uint32_t field_id, uint32_t type) {
|
||||||
|
uint32_t tag = (field_id << 3) | (type & 0b111);
|
||||||
|
return varint(tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Common parameters for all add_*_field methods
|
||||||
|
*
|
||||||
|
* All add_*_field methods follow these common patterns:
|
||||||
|
*
|
||||||
|
* @param total_size Reference to the total message size to update
|
||||||
|
* @param field_id_size Pre-calculated size of the field ID in bytes
|
||||||
|
* @param value The value to calculate size for (type varies)
|
||||||
|
* @param force Whether to calculate size even if the value is default/zero/empty
|
||||||
|
*
|
||||||
|
* Each method follows this implementation pattern:
|
||||||
|
* 1. Skip calculation if value is default (0, false, empty) and not forced
|
||||||
|
* 2. Calculate the size based on the field's encoding rules
|
||||||
|
* 3. Add the field_id_size + calculated value size to total_size
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of an int32 field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_int32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
if (value < 0) {
|
||||||
|
// Negative values are encoded as 10-byte varints in protobuf
|
||||||
|
total_size += field_id_size + 10;
|
||||||
|
} else {
|
||||||
|
// For non-negative values, use the standard varint size
|
||||||
|
total_size += field_id_size + varint(static_cast<uint32_t>(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a uint32 field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_uint32_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value,
|
||||||
|
bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
total_size += field_id_size + varint(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a boolean field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_bool_field(uint32_t &total_size, uint32_t field_id_size, bool value, bool force = false) {
|
||||||
|
// Skip calculation if value is false and not forced
|
||||||
|
if (!value && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boolean fields always use 1 byte when true
|
||||||
|
total_size += field_id_size + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a fixed field to the total message size
|
||||||
|
*
|
||||||
|
* Fixed fields always take exactly N bytes (4 for fixed32/float, 8 for fixed64/double).
|
||||||
|
*
|
||||||
|
* @tparam NumBytes The number of bytes for this fixed field (4 or 8)
|
||||||
|
* @param is_nonzero Whether the value is non-zero
|
||||||
|
*/
|
||||||
|
template<uint32_t NumBytes>
|
||||||
|
static inline void add_fixed_field(uint32_t &total_size, uint32_t field_id_size, bool is_nonzero,
|
||||||
|
bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (!is_nonzero && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fixed fields always take exactly NumBytes
|
||||||
|
total_size += field_id_size + NumBytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of an enum field to the total message size
|
||||||
|
*
|
||||||
|
* Enum fields are encoded as uint32 varints.
|
||||||
|
*/
|
||||||
|
static inline void add_enum_field(uint32_t &total_size, uint32_t field_id_size, uint32_t value, bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enums are encoded as uint32
|
||||||
|
total_size += field_id_size + varint(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a sint32 field to the total message size
|
||||||
|
*
|
||||||
|
* Sint32 fields use ZigZag encoding, which is more efficient for negative values.
|
||||||
|
*/
|
||||||
|
static inline void add_sint32_field(uint32_t &total_size, uint32_t field_id_size, int32_t value, bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZigZag encoding for sint32: (n << 1) ^ (n >> 31)
|
||||||
|
uint32_t zigzag = (static_cast<uint32_t>(value) << 1) ^ (static_cast<uint32_t>(value >> 31));
|
||||||
|
total_size += field_id_size + varint(zigzag);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of an int64 field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_int64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
total_size += field_id_size + varint(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a uint64 field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_uint64_field(uint32_t &total_size, uint32_t field_id_size, uint64_t value,
|
||||||
|
bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
total_size += field_id_size + varint(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a sint64 field to the total message size
|
||||||
|
*
|
||||||
|
* Sint64 fields use ZigZag encoding, which is more efficient for negative values.
|
||||||
|
*/
|
||||||
|
static inline void add_sint64_field(uint32_t &total_size, uint32_t field_id_size, int64_t value, bool force = false) {
|
||||||
|
// Skip calculation if value is zero and not forced
|
||||||
|
if (value == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZigZag encoding for sint64: (n << 1) ^ (n >> 63)
|
||||||
|
uint64_t zigzag = (static_cast<uint64_t>(value) << 1) ^ (static_cast<uint64_t>(value >> 63));
|
||||||
|
total_size += field_id_size + varint(zigzag);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a string/bytes field to the total message size
|
||||||
|
*/
|
||||||
|
static inline void add_string_field(uint32_t &total_size, uint32_t field_id_size, const std::string &str,
|
||||||
|
bool force = false) {
|
||||||
|
// Skip calculation if string is empty and not forced
|
||||||
|
if (str.empty() && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
const uint32_t str_size = static_cast<uint32_t>(str.size());
|
||||||
|
total_size += field_id_size + varint(str_size) + str_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a nested message field to the total message size
|
||||||
|
*
|
||||||
|
* This helper function directly updates the total_size reference if the nested size
|
||||||
|
* is greater than zero or force is true.
|
||||||
|
*
|
||||||
|
* @param nested_size The pre-calculated size of the nested message
|
||||||
|
*/
|
||||||
|
static inline void add_message_field(uint32_t &total_size, uint32_t field_id_size, uint32_t nested_size,
|
||||||
|
bool force = false) {
|
||||||
|
// Skip calculation if nested message is empty and not forced
|
||||||
|
if (nested_size == 0 && !force) {
|
||||||
|
return; // No need to update total_size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and directly add to total_size
|
||||||
|
// Field ID + length varint + nested message content
|
||||||
|
total_size += field_id_size + varint(nested_size) + nested_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the size of a nested message field to the total message size
|
||||||
|
*
|
||||||
|
* This templated version directly takes a message object, calculates its size internally,
|
||||||
|
* and updates the total_size reference. This eliminates the need for a temporary variable
|
||||||
|
* at the call site.
|
||||||
|
*
|
||||||
|
* @tparam MessageType The type of the nested message (inferred from parameter)
|
||||||
|
* @param message The nested message object
|
||||||
|
*/
|
||||||
|
template<typename MessageType>
|
||||||
|
static inline void add_message_object(uint32_t &total_size, uint32_t field_id_size, const MessageType &message,
|
||||||
|
bool force = false) {
|
||||||
|
uint32_t nested_size = 0;
|
||||||
|
message.calculate_size(nested_size);
|
||||||
|
|
||||||
|
// Use the base implementation with the calculated nested_size
|
||||||
|
add_message_field(total_size, field_id_size, nested_size, force);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculates and adds the sizes of all messages in a repeated field to the total message size
|
||||||
|
*
|
||||||
|
* This helper processes a vector of message objects, calculating the size for each message
|
||||||
|
* and adding it to the total size.
|
||||||
|
*
|
||||||
|
* @tparam MessageType The type of the nested messages in the vector
|
||||||
|
* @param messages Vector of message objects
|
||||||
|
*/
|
||||||
|
template<typename MessageType>
|
||||||
|
static inline void add_repeated_message(uint32_t &total_size, uint32_t field_id_size,
|
||||||
|
const std::vector<MessageType> &messages) {
|
||||||
|
// Skip if the vector is empty
|
||||||
|
if (messages.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For repeated fields, always use force=true
|
||||||
|
for (const auto &message : messages) {
|
||||||
|
add_message_object(total_size, field_id_size, message, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace api
|
||||||
|
} // namespace esphome
|
||||||
@@ -276,6 +276,7 @@ class ProtoMessage {
|
|||||||
virtual ~ProtoMessage() = default;
|
virtual ~ProtoMessage() = default;
|
||||||
virtual void encode(ProtoWriteBuffer buffer) const = 0;
|
virtual void encode(ProtoWriteBuffer buffer) const = 0;
|
||||||
void decode(const uint8_t *buffer, size_t length);
|
void decode(const uint8_t *buffer, size_t length);
|
||||||
|
virtual void calculate_size(uint32_t &total_size) const = 0;
|
||||||
#ifdef HAS_PROTO_MESSAGE_DUMP
|
#ifdef HAS_PROTO_MESSAGE_DUMP
|
||||||
std::string dump() const;
|
std::string dump() const;
|
||||||
virtual void dump_to(std::string &out) const = 0;
|
virtual void dump_to(std::string &out) const = 0;
|
||||||
@@ -302,9 +303,19 @@ class ProtoService {
|
|||||||
virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0;
|
virtual bool send_buffer(ProtoWriteBuffer buffer, uint32_t message_type) = 0;
|
||||||
virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0;
|
virtual bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) = 0;
|
||||||
|
|
||||||
|
// Optimized method that pre-allocates buffer based on message size
|
||||||
template<class C> bool send_message_(const C &msg, uint32_t message_type) {
|
template<class C> bool send_message_(const C &msg, uint32_t message_type) {
|
||||||
|
uint32_t msg_size = 0;
|
||||||
|
msg.calculate_size(msg_size);
|
||||||
|
|
||||||
|
// Create a pre-sized buffer
|
||||||
auto buffer = this->create_buffer();
|
auto buffer = this->create_buffer();
|
||||||
|
buffer.get_buffer()->reserve(msg_size);
|
||||||
|
|
||||||
|
// Encode message into the buffer
|
||||||
msg.encode(buffer);
|
msg.encode(buffer);
|
||||||
|
|
||||||
|
// Send the buffer
|
||||||
return this->send_buffer(buffer, message_type);
|
return this->send_buffer(buffer, message_type);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import IntEnum
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
@@ -10,11 +11,29 @@ import sys
|
|||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
# Generate with
|
|
||||||
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
|
||||||
import aioesphomeapi.api_options_pb2 as pb
|
import aioesphomeapi.api_options_pb2 as pb
|
||||||
import google.protobuf.descriptor_pb2 as descriptor
|
import google.protobuf.descriptor_pb2 as descriptor
|
||||||
|
|
||||||
|
|
||||||
|
class WireType(IntEnum):
|
||||||
|
"""Protocol Buffer wire types as defined in the protobuf spec.
|
||||||
|
|
||||||
|
As specified in the Protocol Buffers encoding guide:
|
||||||
|
https://protobuf.dev/programming-guides/encoding/#structure
|
||||||
|
"""
|
||||||
|
|
||||||
|
VARINT = 0 # int32, int64, uint32, uint64, sint32, sint64, bool, enum
|
||||||
|
FIXED64 = 1 # fixed64, sfixed64, double
|
||||||
|
LENGTH_DELIMITED = 2 # string, bytes, embedded messages, packed repeated fields
|
||||||
|
START_GROUP = 3 # groups (deprecated)
|
||||||
|
END_GROUP = 4 # groups (deprecated)
|
||||||
|
FIXED32 = 5 # fixed32, sfixed32, float
|
||||||
|
|
||||||
|
|
||||||
|
# Generate with
|
||||||
|
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
|
||||||
|
|
||||||
|
|
||||||
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
|
||||||
|
|
||||||
It's pretty crappy spaghetti code, but it works.
|
It's pretty crappy spaghetti code, but it works.
|
||||||
@@ -35,7 +54,7 @@ will be generated, they still need to be formatted
|
|||||||
|
|
||||||
|
|
||||||
FILE_HEADER = """// This file was automatically generated with a tool.
|
FILE_HEADER = """// This file was automatically generated with a tool.
|
||||||
// See scripts/api_protobuf/api_protobuf.py
|
// See script/api_protobuf/api_protobuf.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -63,6 +82,11 @@ def camel_to_snake(name: str) -> str:
|
|||||||
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def force_str(force: bool) -> str:
|
||||||
|
"""Convert a boolean force value to string format for C++ code."""
|
||||||
|
return str(force).lower()
|
||||||
|
|
||||||
|
|
||||||
class TypeInfo(ABC):
|
class TypeInfo(ABC):
|
||||||
"""Base class for all type information."""
|
"""Base class for all type information."""
|
||||||
|
|
||||||
@@ -99,6 +123,11 @@ class TypeInfo(ABC):
|
|||||||
"""Check if the field is repeated."""
|
"""Check if the field is repeated."""
|
||||||
return self._field.label == 3
|
return self._field.label == 3
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wire_type(self) -> WireType:
|
||||||
|
"""Get the wire type for the field."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cpp_type(self) -> str:
|
def cpp_type(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -200,6 +229,35 @@ class TypeInfo(ABC):
|
|||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
"""Dump the value to the output."""
|
"""Dump the value to the output."""
|
||||||
|
|
||||||
|
def calculate_field_id_size(self) -> int:
|
||||||
|
"""Calculates the size of a field ID in bytes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of bytes needed to encode the field ID
|
||||||
|
"""
|
||||||
|
# Calculate the tag by combining field_id and wire_type
|
||||||
|
tag = (self.number << 3) | (self.wire_type & 0b111)
|
||||||
|
|
||||||
|
# Calculate the varint size
|
||||||
|
if tag < 128:
|
||||||
|
return 1 # 7 bits
|
||||||
|
if tag < 16384:
|
||||||
|
return 2 # 14 bits
|
||||||
|
if tag < 2097152:
|
||||||
|
return 3 # 21 bits
|
||||||
|
if tag < 268435456:
|
||||||
|
return 4 # 28 bits
|
||||||
|
return 5 # 32 bits (maximum for uint32_t)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
"""Calculate the size needed for encoding this field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the field
|
||||||
|
force: Whether to force encoding the field even if it has a default value
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
TYPE_INFO: dict[int, TypeInfo] = {}
|
TYPE_INFO: dict[int, TypeInfo] = {}
|
||||||
|
|
||||||
@@ -221,12 +279,18 @@ class DoubleType(TypeInfo):
|
|||||||
default_value = "0.0"
|
default_value = "0.0"
|
||||||
decode_64bit = "value.as_double()"
|
decode_64bit = "value.as_double()"
|
||||||
encode_func = "encode_double"
|
encode_func = "encode_double"
|
||||||
|
wire_type = WireType.FIXED64 # Uses wire type 1 according to protobuf spec
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%g", {name});\n'
|
o = f'sprintf(buffer, "%g", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0.0, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(2)
|
@register_type(2)
|
||||||
class FloatType(TypeInfo):
|
class FloatType(TypeInfo):
|
||||||
@@ -234,12 +298,18 @@ class FloatType(TypeInfo):
|
|||||||
default_value = "0.0f"
|
default_value = "0.0f"
|
||||||
decode_32bit = "value.as_float()"
|
decode_32bit = "value.as_float()"
|
||||||
encode_func = "encode_float"
|
encode_func = "encode_float"
|
||||||
|
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%g", {name});\n'
|
o = f'sprintf(buffer, "%g", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0.0f, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(3)
|
@register_type(3)
|
||||||
class Int64Type(TypeInfo):
|
class Int64Type(TypeInfo):
|
||||||
@@ -247,12 +317,18 @@ class Int64Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_int64()"
|
decode_varint = "value.as_int64()"
|
||||||
encode_func = "encode_int64"
|
encode_func = "encode_int64"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_int64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(4)
|
@register_type(4)
|
||||||
class UInt64Type(TypeInfo):
|
class UInt64Type(TypeInfo):
|
||||||
@@ -260,12 +336,18 @@ class UInt64Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_uint64()"
|
decode_varint = "value.as_uint64()"
|
||||||
encode_func = "encode_uint64"
|
encode_func = "encode_uint64"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_uint64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(5)
|
@register_type(5)
|
||||||
class Int32Type(TypeInfo):
|
class Int32Type(TypeInfo):
|
||||||
@@ -273,12 +355,18 @@ class Int32Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_int32()"
|
decode_varint = "value.as_int32()"
|
||||||
encode_func = "encode_int32"
|
encode_func = "encode_int32"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_int32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(6)
|
@register_type(6)
|
||||||
class Fixed64Type(TypeInfo):
|
class Fixed64Type(TypeInfo):
|
||||||
@@ -286,12 +374,18 @@ class Fixed64Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_64bit = "value.as_fixed64()"
|
decode_64bit = "value.as_fixed64()"
|
||||||
encode_func = "encode_fixed64"
|
encode_func = "encode_fixed64"
|
||||||
|
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%llu", {name});\n'
|
o = f'sprintf(buffer, "%llu", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(7)
|
@register_type(7)
|
||||||
class Fixed32Type(TypeInfo):
|
class Fixed32Type(TypeInfo):
|
||||||
@@ -299,12 +393,18 @@ class Fixed32Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_32bit = "value.as_fixed32()"
|
decode_32bit = "value.as_fixed32()"
|
||||||
encode_func = "encode_fixed32"
|
encode_func = "encode_fixed32"
|
||||||
|
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(8)
|
@register_type(8)
|
||||||
class BoolType(TypeInfo):
|
class BoolType(TypeInfo):
|
||||||
@@ -312,11 +412,17 @@ class BoolType(TypeInfo):
|
|||||||
default_value = "false"
|
default_value = "false"
|
||||||
decode_varint = "value.as_bool()"
|
decode_varint = "value.as_bool()"
|
||||||
encode_func = "encode_bool"
|
encode_func = "encode_bool"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f"out.append(YESNO({name}));"
|
o = f"out.append(YESNO({name}));"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_bool_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(9)
|
@register_type(9)
|
||||||
class StringType(TypeInfo):
|
class StringType(TypeInfo):
|
||||||
@@ -326,11 +432,17 @@ class StringType(TypeInfo):
|
|||||||
const_reference_type = "const std::string &"
|
const_reference_type = "const std::string &"
|
||||||
decode_length = "value.as_string()"
|
decode_length = "value.as_string()"
|
||||||
encode_func = "encode_string"
|
encode_func = "encode_string"
|
||||||
|
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||||
|
|
||||||
def dump(self, name):
|
def dump(self, name):
|
||||||
o = f'out.append("\'").append({name}).append("\'");'
|
o = f'out.append("\'").append({name}).append("\'");'
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(11)
|
@register_type(11)
|
||||||
class MessageType(TypeInfo):
|
class MessageType(TypeInfo):
|
||||||
@@ -339,6 +451,7 @@ class MessageType(TypeInfo):
|
|||||||
return self._field.type_name[1:]
|
return self._field.type_name[1:]
|
||||||
|
|
||||||
default_value = ""
|
default_value = ""
|
||||||
|
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reference_type(self) -> str:
|
def reference_type(self) -> str:
|
||||||
@@ -360,6 +473,11 @@ class MessageType(TypeInfo):
|
|||||||
o = f"{name}.dump_to(out);"
|
o = f"{name}.dump_to(out);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_message_object(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(12)
|
@register_type(12)
|
||||||
class BytesType(TypeInfo):
|
class BytesType(TypeInfo):
|
||||||
@@ -369,11 +487,17 @@ class BytesType(TypeInfo):
|
|||||||
const_reference_type = "const std::string &"
|
const_reference_type = "const std::string &"
|
||||||
decode_length = "value.as_string()"
|
decode_length = "value.as_string()"
|
||||||
encode_func = "encode_string"
|
encode_func = "encode_string"
|
||||||
|
wire_type = WireType.LENGTH_DELIMITED # Uses wire type 2
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'out.append("\'").append({name}).append("\'");'
|
o = f'out.append("\'").append({name}).append("\'");'
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_string_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(13)
|
@register_type(13)
|
||||||
class UInt32Type(TypeInfo):
|
class UInt32Type(TypeInfo):
|
||||||
@@ -381,12 +505,18 @@ class UInt32Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_uint32()"
|
decode_varint = "value.as_uint32()"
|
||||||
encode_func = "encode_uint32"
|
encode_func = "encode_uint32"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_uint32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(14)
|
@register_type(14)
|
||||||
class EnumType(TypeInfo):
|
class EnumType(TypeInfo):
|
||||||
@@ -399,6 +529,7 @@ class EnumType(TypeInfo):
|
|||||||
return f"value.as_enum<{self.cpp_type}>()"
|
return f"value.as_enum<{self.cpp_type}>()"
|
||||||
|
|
||||||
default_value = ""
|
default_value = ""
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encode_func(self) -> str:
|
def encode_func(self) -> str:
|
||||||
@@ -408,6 +539,11 @@ class EnumType(TypeInfo):
|
|||||||
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
|
o = f"out.append(proto_enum_to_string<{self.cpp_type}>({name}));"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_enum_field(total_size, {field_id_size}, static_cast<uint32_t>({name}), {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(15)
|
@register_type(15)
|
||||||
class SFixed32Type(TypeInfo):
|
class SFixed32Type(TypeInfo):
|
||||||
@@ -415,12 +551,18 @@ class SFixed32Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_32bit = "value.as_sfixed32()"
|
decode_32bit = "value.as_sfixed32()"
|
||||||
encode_func = "encode_sfixed32"
|
encode_func = "encode_sfixed32"
|
||||||
|
wire_type = WireType.FIXED32 # Uses wire type 5
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<4>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(16)
|
@register_type(16)
|
||||||
class SFixed64Type(TypeInfo):
|
class SFixed64Type(TypeInfo):
|
||||||
@@ -428,12 +570,18 @@ class SFixed64Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_64bit = "value.as_sfixed64()"
|
decode_64bit = "value.as_sfixed64()"
|
||||||
encode_func = "encode_sfixed64"
|
encode_func = "encode_sfixed64"
|
||||||
|
wire_type = WireType.FIXED64 # Uses wire type 1
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_fixed_field<8>(total_size, {field_id_size}, {name} != 0, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(17)
|
@register_type(17)
|
||||||
class SInt32Type(TypeInfo):
|
class SInt32Type(TypeInfo):
|
||||||
@@ -441,12 +589,18 @@ class SInt32Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_sint32()"
|
decode_varint = "value.as_sint32()"
|
||||||
encode_func = "encode_sint32"
|
encode_func = "encode_sint32"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
o = f'sprintf(buffer, "%" PRId32, {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_sint32_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
@register_type(18)
|
@register_type(18)
|
||||||
class SInt64Type(TypeInfo):
|
class SInt64Type(TypeInfo):
|
||||||
@@ -454,12 +608,18 @@ class SInt64Type(TypeInfo):
|
|||||||
default_value = "0"
|
default_value = "0"
|
||||||
decode_varint = "value.as_sint64()"
|
decode_varint = "value.as_sint64()"
|
||||||
encode_func = "encode_sint64"
|
encode_func = "encode_sint64"
|
||||||
|
wire_type = WireType.VARINT # Uses wire type 0
|
||||||
|
|
||||||
def dump(self, name: str) -> str:
|
def dump(self, name: str) -> str:
|
||||||
o = f'sprintf(buffer, "%lld", {name});\n'
|
o = f'sprintf(buffer, "%lld", {name});\n'
|
||||||
o += "out.append(buffer);"
|
o += "out.append(buffer);"
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
field_id_size = self.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_sint64_field(total_size, {field_id_size}, {name}, {force_str(force)});"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
class RepeatedTypeInfo(TypeInfo):
|
class RepeatedTypeInfo(TypeInfo):
|
||||||
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
def __init__(self, field: descriptor.FieldDescriptorProto) -> None:
|
||||||
@@ -478,6 +638,14 @@ class RepeatedTypeInfo(TypeInfo):
|
|||||||
def const_reference_type(self) -> str:
|
def const_reference_type(self) -> str:
|
||||||
return f"const {self.cpp_type} &"
|
return f"const {self.cpp_type} &"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def wire_type(self) -> WireType:
|
||||||
|
"""Get the wire type for this repeated field.
|
||||||
|
|
||||||
|
For repeated fields, we use the same wire type as the underlying field.
|
||||||
|
"""
|
||||||
|
return self._ti.wire_type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def decode_varint_content(self) -> str:
|
def decode_varint_content(self) -> str:
|
||||||
content = self._ti.decode_varint
|
content = self._ti.decode_varint
|
||||||
@@ -554,6 +722,22 @@ class RepeatedTypeInfo(TypeInfo):
|
|||||||
def dump(self, _: str):
|
def dump(self, _: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||||
|
# For repeated fields, we always need to pass force=True to the underlying type's calculation
|
||||||
|
# This is because the encode method always sets force=true for repeated fields
|
||||||
|
if isinstance(self._ti, MessageType):
|
||||||
|
# For repeated messages, use the dedicated helper that handles iteration internally
|
||||||
|
field_id_size = self._ti.calculate_field_id_size()
|
||||||
|
o = f"ProtoSize::add_repeated_message(total_size, {field_id_size}, {name});"
|
||||||
|
return o
|
||||||
|
# For other repeated types, use the underlying type's size calculation with force=True
|
||||||
|
o = f"if (!{name}.empty()) {{\n"
|
||||||
|
o += f" for (const auto {'' if self._ti_is_bool else '&'}it : {name}) {{\n"
|
||||||
|
o += f" {self._ti.get_size_calculation('it', True)}\n"
|
||||||
|
o += " }\n"
|
||||||
|
o += "}"
|
||||||
|
return o
|
||||||
|
|
||||||
|
|
||||||
def build_enum_type(desc) -> tuple[str, str]:
|
def build_enum_type(desc) -> tuple[str, str]:
|
||||||
"""Builds the enum type."""
|
"""Builds the enum type."""
|
||||||
@@ -587,6 +771,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
|||||||
decode_64bit: list[str] = []
|
decode_64bit: list[str] = []
|
||||||
encode: list[str] = []
|
encode: list[str] = []
|
||||||
dump: list[str] = []
|
dump: list[str] = []
|
||||||
|
size_calc: list[str] = []
|
||||||
|
|
||||||
for field in desc.field:
|
for field in desc.field:
|
||||||
if field.label == 3:
|
if field.label == 3:
|
||||||
@@ -596,6 +781,7 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
|||||||
protected_content.extend(ti.protected_content)
|
protected_content.extend(ti.protected_content)
|
||||||
public_content.extend(ti.public_content)
|
public_content.extend(ti.public_content)
|
||||||
encode.append(ti.encode_content)
|
encode.append(ti.encode_content)
|
||||||
|
size_calc.append(ti.get_size_calculation(f"this->{ti.field_name}"))
|
||||||
|
|
||||||
if ti.decode_varint_content:
|
if ti.decode_varint_content:
|
||||||
decode_varint.append(ti.decode_varint_content)
|
decode_varint.append(ti.decode_varint_content)
|
||||||
@@ -662,6 +848,25 @@ def build_message_type(desc: descriptor.DescriptorProto) -> tuple[str, str]:
|
|||||||
prot = "void encode(ProtoWriteBuffer buffer) const override;"
|
prot = "void encode(ProtoWriteBuffer buffer) const override;"
|
||||||
public_content.append(prot)
|
public_content.append(prot)
|
||||||
|
|
||||||
|
# Add calculate_size method
|
||||||
|
o = f"void {desc.name}::calculate_size(uint32_t &total_size) const {{"
|
||||||
|
|
||||||
|
# Add a check for empty/default objects to short-circuit the calculation
|
||||||
|
# Only add this optimization if we have fields to check
|
||||||
|
if size_calc:
|
||||||
|
# For a single field, just inline it for simplicity
|
||||||
|
if len(size_calc) == 1 and len(size_calc[0]) + len(o) + 3 < 120:
|
||||||
|
o += f" {size_calc[0]} "
|
||||||
|
else:
|
||||||
|
# For multiple fields, add a short-circuit check
|
||||||
|
o += "\n"
|
||||||
|
# Performance optimization: add all the size calculations
|
||||||
|
o += indent("\n".join(size_calc)) + "\n"
|
||||||
|
o += "}\n"
|
||||||
|
cpp += o
|
||||||
|
prot = "void calculate_size(uint32_t &total_size) const override;"
|
||||||
|
public_content.append(prot)
|
||||||
|
|
||||||
o = f"void {desc.name}::dump_to(std::string &out) const {{"
|
o = f"void {desc.name}::dump_to(std::string &out) const {{"
|
||||||
if dump:
|
if dump:
|
||||||
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
|
if len(dump) == 1 and len(dump[0]) + len(o) + 3 < 120:
|
||||||
@@ -796,6 +1001,7 @@ def main() -> None:
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "proto.h"
|
#include "proto.h"
|
||||||
|
#include "api_pb2_size.h"
|
||||||
|
|
||||||
namespace esphome {
|
namespace esphome {
|
||||||
namespace api {
|
namespace api {
|
||||||
@@ -805,6 +1011,7 @@ def main() -> None:
|
|||||||
cpp = FILE_HEADER
|
cpp = FILE_HEADER
|
||||||
cpp += """\
|
cpp += """\
|
||||||
#include "api_pb2.h"
|
#include "api_pb2.h"
|
||||||
|
#include "api_pb2_size.h"
|
||||||
#include "esphome/core/log.h"
|
#include "esphome/core/log.h"
|
||||||
|
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
|
|||||||
Reference in New Issue
Block a user