mirror of
https://github.com/esphome/esphome.git
synced 2026-05-22 18:56:40 +08:00
[api] Add 48-bit MAC address varint fast path for BLE advertisements (#15988)
This commit is contained in:
@@ -1639,7 +1639,7 @@ message BluetoothLEAdvertisementResponse {
|
||||
|
||||
message BluetoothLERawAdvertisement {
|
||||
option (inline_encode) = true;
|
||||
uint64 address = 1 [(force) = true];
|
||||
uint64 address = 1 [(force) = true, (mac_address) = true];
|
||||
sint32 rssi = 2 [(force) = true];
|
||||
uint32 address_type = 3 [(max_value) = 4];
|
||||
|
||||
|
||||
@@ -110,4 +110,10 @@ extend google.protobuf.FieldOptions {
|
||||
// length varint calculations and direct byte writes, since the length
|
||||
// varint is guaranteed to be 1 byte.
|
||||
optional uint32 max_data_length = 50018;
|
||||
|
||||
// mac_address: Field is a 48-bit MAC address stored in a uint64.
|
||||
// Emits encode_varint_raw_48bit which has a 7-byte fast path that avoids
|
||||
// the per-byte loop when the upper bits are non-zero (the common case
|
||||
// for real MAC addresses, since OUIs occupy the top 24 bits).
|
||||
optional bool mac_address = 50019 [default=false];
|
||||
}
|
||||
|
||||
@@ -2352,7 +2352,7 @@ BluetoothLERawAdvertisementsResponse::encode(ProtoWriteBuffer &buffer PROTO_ENCO
|
||||
uint8_t *len_pos = pos;
|
||||
ProtoEncode::reserve_byte(pos PROTO_ENCODE_DEBUG_ARG);
|
||||
ProtoEncode::write_raw_byte(pos PROTO_ENCODE_DEBUG_ARG, 8);
|
||||
ProtoEncode::encode_varint_raw_64(pos PROTO_ENCODE_DEBUG_ARG, sub_msg.address);
|
||||
ProtoEncode::encode_varint_raw_48bit(pos PROTO_ENCODE_DEBUG_ARG, sub_msg.address);
|
||||
ProtoEncode::write_raw_byte(pos PROTO_ENCODE_DEBUG_ARG, 16);
|
||||
ProtoEncode::encode_varint_raw_short(pos PROTO_ENCODE_DEBUG_ARG, encode_zigzag32(sub_msg.rssi));
|
||||
if (sub_msg.address_type) {
|
||||
@@ -2373,7 +2373,7 @@ BluetoothLERawAdvertisementsResponse::calculate_size() const {
|
||||
for (uint16_t i = 0; i < this->advertisements_len; i++) {
|
||||
auto &sub_msg = this->advertisements[i];
|
||||
size += 2;
|
||||
size += ProtoSize::calc_uint64_force(1, sub_msg.address);
|
||||
size += ProtoSize::calc_uint64_48bit_force(1, sub_msg.address);
|
||||
size += ProtoSize::calc_sint32_force(1, sub_msg.rssi);
|
||||
size += sub_msg.address_type ? 2 : 0;
|
||||
size += 2 + sub_msg.data_len;
|
||||
|
||||
@@ -21,6 +21,7 @@ void APIServerConnectionBase::log_receive_message_(const LogString *name) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_API
|
||||
void APIConnection::read_message_(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {
|
||||
// Check authentication/connection requirements
|
||||
switch (msg_type) {
|
||||
@@ -706,5 +707,6 @@ void APIConnection::read_message_(uint32_t msg_size, uint32_t msg_type, const ui
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif // USE_API
|
||||
|
||||
} // namespace esphome::api
|
||||
|
||||
@@ -342,6 +342,32 @@ class ProtoEncode {
|
||||
}
|
||||
encode_varint_raw_loop(pos PROTO_ENCODE_DEBUG_ARG, value);
|
||||
}
|
||||
/// Encode a 48-bit MAC address (stored in a uint64) as varint.
|
||||
/// Real MAC addresses occupy the full 48 bits (OUI in upper 24), so the
|
||||
/// fast path -- any non-zero bit in the top 6 of 48 -- emits exactly 7 bytes
|
||||
/// with no per-byte branch. Falls back to the general loop otherwise.
|
||||
/// Caller must guarantee value fits in 48 bits (checked in debug builds).
|
||||
static inline void ESPHOME_ALWAYS_INLINE encode_varint_raw_48bit(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
|
||||
uint64_t value) {
|
||||
#ifdef ESPHOME_DEBUG_API
|
||||
assert(value < (1ULL << (MAC_ADDRESS_SIZE * 8)) && "encode_varint_raw_48bit: value exceeds 48 bits");
|
||||
#endif
|
||||
// 7-byte varint holds 49 bits (7 * 7), so a 48-bit value needs all 7 bytes
|
||||
// whenever bit 42 or higher is set (i.e. value >= 1 << (48 - 6)).
|
||||
if (value >= (1ULL << (MAC_ADDRESS_SIZE * 8 - 6))) [[likely]] {
|
||||
PROTO_ENCODE_CHECK_BOUNDS(pos, 7);
|
||||
pos[0] = static_cast<uint8_t>(value | 0x80);
|
||||
pos[1] = static_cast<uint8_t>((value >> 7) | 0x80);
|
||||
pos[2] = static_cast<uint8_t>((value >> 14) | 0x80);
|
||||
pos[3] = static_cast<uint8_t>((value >> 21) | 0x80);
|
||||
pos[4] = static_cast<uint8_t>((value >> 28) | 0x80);
|
||||
pos[5] = static_cast<uint8_t>((value >> 35) | 0x80);
|
||||
pos[6] = static_cast<uint8_t>(value >> 42);
|
||||
pos += 7;
|
||||
return;
|
||||
}
|
||||
encode_varint_raw_64(pos PROTO_ENCODE_DEBUG_ARG, value);
|
||||
}
|
||||
static inline void ESPHOME_ALWAYS_INLINE encode_field_raw(uint8_t *__restrict__ &pos PROTO_ENCODE_DEBUG_PARAM,
|
||||
uint32_t field_id, uint32_t type) {
|
||||
encode_varint_raw(pos PROTO_ENCODE_DEBUG_ARG, (field_id << 3) | type);
|
||||
@@ -817,6 +843,14 @@ class ProtoSize {
|
||||
static constexpr inline uint32_t ESPHOME_ALWAYS_INLINE calc_uint64_force(uint32_t field_id_size, uint64_t value) {
|
||||
return field_id_size + varint(value);
|
||||
}
|
||||
/// 48-bit MAC address variant: matches encode_varint_raw_48bit's fast path.
|
||||
/// When any of the top 6 of 48 bits is set the encoded varint is 7 bytes;
|
||||
/// otherwise fall back to the general size calculation.
|
||||
/// Caller must guarantee value fits in 48 bits (encoder asserts in debug).
|
||||
static constexpr inline uint32_t ESPHOME_ALWAYS_INLINE calc_uint64_48bit_force(uint32_t field_id_size,
|
||||
uint64_t value) {
|
||||
return field_id_size + (value >= (1ULL << (MAC_ADDRESS_SIZE * 8 - 6)) ? 7 : varint(value));
|
||||
}
|
||||
static constexpr uint32_t calc_length(uint32_t field_id_size, size_t len) {
|
||||
return len ? field_id_size + varint(static_cast<uint32_t>(len)) + static_cast<uint32_t>(len) : 0;
|
||||
}
|
||||
|
||||
@@ -184,6 +184,11 @@ class TypeInfo(ABC):
|
||||
"""Check if this field should always be encoded (skip zero/empty check)."""
|
||||
return get_field_opt(self._field, pb.force, False)
|
||||
|
||||
@property
|
||||
def mac_address(self) -> bool:
|
||||
"""Check if this uint64 field is a 48-bit MAC address (use 7-byte fast path)."""
|
||||
return get_field_opt(self._field, pb.mac_address, False)
|
||||
|
||||
@property
|
||||
def max_value(self) -> int | None:
|
||||
"""Get the max_value option for this field, or None if not set."""
|
||||
@@ -665,8 +670,22 @@ class UInt64Type(VarintTypeMixin, TypeInfo):
|
||||
return o
|
||||
|
||||
def get_size_calculation(self, name: str, force: bool = False) -> str:
|
||||
if self.mac_address and force:
|
||||
field_id_size = self.calculate_field_id_size()
|
||||
return (
|
||||
f"size += ProtoSize::calc_uint64_48bit_force({field_id_size}, {name});"
|
||||
)
|
||||
return self._get_simple_size_calculation(name, force, "uint64")
|
||||
|
||||
@property
|
||||
def RAW_ENCODE_MAP(self) -> dict[str, str]: # noqa: N802
|
||||
if self.mac_address:
|
||||
return {
|
||||
**TypeInfo.RAW_ENCODE_MAP,
|
||||
"encode_uint64": "ProtoEncode::encode_varint_raw_48bit(pos, {value});",
|
||||
}
|
||||
return TypeInfo.RAW_ENCODE_MAP
|
||||
|
||||
def get_estimated_size(self) -> int:
|
||||
return self.calculate_field_id_size() + 3 # field ID + 3 bytes typical varint
|
||||
|
||||
@@ -3558,8 +3577,13 @@ static const char *const TAG = "api.service";
|
||||
# Generate read_message_ as APIConnection method (not base class) so the compiler
|
||||
# can devirtualize and inline the on_* handler calls within the same class.
|
||||
# APIConnection declares this method in api_connection.h.
|
||||
# Guard with #ifdef USE_API since APIConnection itself is only defined when
|
||||
# USE_API is set; without this, builds that compile this .cpp without
|
||||
# USE_API (e.g. C++ unit tests for api dependencies) fail to find the
|
||||
# class declaration.
|
||||
|
||||
out = "void APIConnection::read_message_(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {\n"
|
||||
out = "#ifdef USE_API\n"
|
||||
out += "void APIConnection::read_message_(uint32_t msg_size, uint32_t msg_type, const uint8_t *msg_data) {\n"
|
||||
|
||||
# Auth check block before dispatch switch
|
||||
out += " // Check authentication/connection requirements\n"
|
||||
@@ -3604,6 +3628,7 @@ static const char *const TAG = "api.service";
|
||||
out += " break;\n"
|
||||
out += " }\n"
|
||||
out += "}\n"
|
||||
out += "#endif // USE_API\n"
|
||||
cpp += out
|
||||
hpp += "};\n"
|
||||
|
||||
|
||||
+17
-2
@@ -324,8 +324,23 @@ def compile_and_get_binary(
|
||||
domain_list.append({CONF_PLATFORM: component})
|
||||
# Skip "core" — it's a pseudo-component handled by the build
|
||||
# system, not a real loadable component (get_component returns None)
|
||||
elif get_component(component_name) is not None:
|
||||
config.setdefault(component_name, [])
|
||||
elif (component := get_component(component_name)) is not None:
|
||||
# MULTI_CONF components store their config as a list of dicts,
|
||||
# everything else stores a single dict. Run the component's
|
||||
# schema with {} so defaults get populated -- code paths like
|
||||
# socket.FILTER_SOURCE_FILES expect a fully-populated mapping.
|
||||
if component.multi_conf:
|
||||
config.setdefault(component_name, [])
|
||||
elif component_name not in config:
|
||||
schema = component.config_schema
|
||||
try:
|
||||
config[component_name] = schema({}) if schema is not None else {}
|
||||
except Exception: # noqa: BLE001
|
||||
# Schema requires explicit input we can't synthesize; fall
|
||||
# back to an empty mapping so subscripting at least returns
|
||||
# KeyError on missing keys rather than crashing on the
|
||||
# wrong type.
|
||||
config[component_name] = {}
|
||||
|
||||
# Register platforms from the extra config (benchmark.yaml) so
|
||||
# USE_SENSOR, USE_LIGHT, etc. defines are emitted without needing
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <ios>
|
||||
#include <random>
|
||||
|
||||
#include "esphome/components/api/api_buffer.h"
|
||||
#include "esphome/components/api/proto.h"
|
||||
|
||||
namespace esphome::api::testing {
|
||||
|
||||
// Generic varint decoder, used to verify the encoded bytes round-trip back to
|
||||
// the original 48-bit MAC value, independent of the specialized encoder under
|
||||
// test.
|
||||
static uint64_t decode_varint(const uint8_t *buf, size_t len, size_t *consumed) {
|
||||
uint64_t value = 0;
|
||||
int shift = 0;
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
value |= static_cast<uint64_t>(buf[i] & 0x7F) << shift;
|
||||
if ((buf[i] & 0x80) == 0) {
|
||||
*consumed = i + 1;
|
||||
return value;
|
||||
}
|
||||
shift += 7;
|
||||
}
|
||||
*consumed = 0;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Reference encoder mirroring ProtoEncode::encode_varint_raw_64.
|
||||
static size_t reference_encode(uint64_t value, uint8_t *out) {
|
||||
uint8_t *p = out;
|
||||
if (value < 128) {
|
||||
*p++ = static_cast<uint8_t>(value);
|
||||
return p - out;
|
||||
}
|
||||
do {
|
||||
*p++ = static_cast<uint8_t>(value | 0x80);
|
||||
value >>= 7;
|
||||
} while (value > 0x7F);
|
||||
*p++ = static_cast<uint8_t>(value);
|
||||
return p - out;
|
||||
}
|
||||
|
||||
// Encode `mac` via the 48-bit fast path and verify:
|
||||
// - byte-identical output to the reference loop
|
||||
// - encoded byte length matches `expected_bytes`
|
||||
// - calc_uint64_48bit_force agrees on the size
|
||||
// - the bytes round-trip through a generic varint decoder
|
||||
static void verify_mac(uint64_t mac, size_t expected_bytes) {
|
||||
ASSERT_LT(mac, 1ULL << 48) << "test fixture mac exceeds 48 bits";
|
||||
|
||||
uint8_t ref_buf[16] = {0};
|
||||
size_t ref_len = reference_encode(mac, ref_buf);
|
||||
|
||||
APIBuffer api_buf;
|
||||
api_buf.resize(16);
|
||||
uint8_t *pos = api_buf.data();
|
||||
#ifdef ESPHOME_DEBUG_API
|
||||
uint8_t *proto_debug_end_ = api_buf.data() + api_buf.size();
|
||||
#endif
|
||||
ProtoEncode::encode_varint_raw_48bit(pos PROTO_ENCODE_DEBUG_ARG, mac);
|
||||
size_t new_len = pos - api_buf.data();
|
||||
|
||||
EXPECT_EQ(new_len, expected_bytes) << "mac=0x" << std::hex << mac << std::dec;
|
||||
EXPECT_EQ(ref_len, expected_bytes) << "reference disagrees on length for mac=0x" << std::hex << mac << std::dec;
|
||||
|
||||
for (size_t i = 0; i < new_len; i++) {
|
||||
EXPECT_EQ(api_buf.data()[i], ref_buf[i])
|
||||
<< "byte " << i << " differs for mac=0x" << std::hex << mac << " (got 0x" << static_cast<int>(api_buf.data()[i])
|
||||
<< ", expected 0x" << static_cast<int>(ref_buf[i]) << ")" << std::dec;
|
||||
}
|
||||
|
||||
size_t consumed = 0;
|
||||
uint64_t decoded = decode_varint(api_buf.data(), new_len, &consumed);
|
||||
EXPECT_EQ(consumed, new_len) << "decoder did not consume all bytes for mac=0x" << std::hex << mac << std::dec;
|
||||
EXPECT_EQ(decoded, mac) << "round-trip mismatch for mac=0x" << std::hex << mac << std::dec;
|
||||
|
||||
// Verify the size helper agrees. field_id_size = 1 (typical 1-byte tag).
|
||||
uint32_t calc_size = ProtoSize::calc_uint64_48bit_force(1, mac);
|
||||
EXPECT_EQ(calc_size, 1 + expected_bytes)
|
||||
<< "calc_uint64_48bit_force size mismatch for mac=0x" << std::hex << mac << std::dec;
|
||||
}
|
||||
|
||||
// Compute the canonical varint byte length for a value < 1<<48.
|
||||
static size_t expected_varint_len(uint64_t v) {
|
||||
if (v < (1ULL << 7))
|
||||
return 1;
|
||||
if (v < (1ULL << 14))
|
||||
return 2;
|
||||
if (v < (1ULL << 21))
|
||||
return 3;
|
||||
if (v < (1ULL << 28))
|
||||
return 4;
|
||||
if (v < (1ULL << 35))
|
||||
return 5;
|
||||
if (v < (1ULL << 42))
|
||||
return 6;
|
||||
return 7;
|
||||
}
|
||||
|
||||
// --- Specific MACs requested for verification ---
|
||||
|
||||
TEST(ProtoMacVarint, AllZeros) { verify_mac(0x000000000000ULL, 1); } // 00:00:00:00:00:00
|
||||
TEST(ProtoMacVarint, FirstByteOnly) { verify_mac(0x110000000000ULL, 7); } // 11:00:00:00:00:00
|
||||
TEST(ProtoMacVarint, SecondByteOnly) { verify_mac(0x00AA00000000ULL, 6); } // 00:AA:00:00:00:00
|
||||
TEST(ProtoMacVarint, ThirdByteOnly) { verify_mac(0x0000BB000000ULL, 5); } // 00:00:BB:00:00:00
|
||||
TEST(ProtoMacVarint, FourthByteOnly) { verify_mac(0x000000CC0000ULL, 4); } // 00:00:00:CC:00:00
|
||||
TEST(ProtoMacVarint, FifthByteOnly) { verify_mac(0x00000000DD00ULL, 3); } // 00:00:00:00:DD:00
|
||||
TEST(ProtoMacVarint, SixthByteOnly) { verify_mac(0x0000000000EEULL, 2); } // 00:00:00:00:00:EE
|
||||
TEST(ProtoMacVarint, AllOnes) { verify_mac(0xFFFFFFFFFFFFULL, 7); } // FF:FF:FF:FF:FF:FF
|
||||
|
||||
// 100 deterministic-random 48-bit MACs to catch regressions across the space.
|
||||
TEST(ProtoMacVarint, RandomSample) {
|
||||
// NOLINTNEXTLINE(cert-msc32-c,cert-msc51-cpp) -- intentional fixed seed for reproducibility.
|
||||
std::mt19937_64 rng(0xC0FFEE);
|
||||
for (int i = 0; i < 100; i++) {
|
||||
uint64_t mac = rng() & 0xFFFFFFFFFFFFULL;
|
||||
verify_mac(mac, expected_varint_len(mac));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace esphome::api::testing
|
||||
@@ -0,0 +1,9 @@
|
||||
from tests.testing_helpers import ComponentManifestOverride
|
||||
|
||||
|
||||
def override_manifest(manifest: ComponentManifestOverride) -> None:
|
||||
# json's to_code calls cg.add_library("bblanchon/ArduinoJson", ...). C++
|
||||
# unit test builds that pull json in transitively (e.g. api) need that
|
||||
# library registration to happen, otherwise json_util.cpp fails to find
|
||||
# ArduinoJson.h.
|
||||
manifest.enable_codegen()
|
||||
Reference in New Issue
Block a user