[mdns] Fix delete/malloc bug and store string constants in flash (#11105)

This commit is contained in:
J. Nick Koston
2025-10-07 18:19:29 -10:00
committed by GitHub
parent 0fe6e7169c
commit ec63247ae0
8 changed files with 61 additions and 51 deletions
+4 -4
View File
@@ -61,7 +61,7 @@ CONFIG_SCHEMA = cv.All(
def mdns_txt_record(key: str, value: str): def mdns_txt_record(key: str, value: str):
return cg.StructInitializer( return cg.StructInitializer(
MDNSTXTRecord, MDNSTXTRecord,
("key", key), ("key", cg.RawExpression(f"MDNS_STR({cg.safe_exp(key)})")),
("value", value), ("value", value),
) )
@@ -71,8 +71,8 @@ def mdns_service(
): ):
return cg.StructInitializer( return cg.StructInitializer(
MDNSService, MDNSService,
("service_type", service), ("service_type", cg.RawExpression(f"MDNS_STR({cg.safe_exp(service)})")),
("proto", proto), ("proto", cg.RawExpression(f"MDNS_STR({cg.safe_exp(proto)})")),
("port", port), ("port", port),
("txt_records", txt_records), ("txt_records", txt_records),
) )
@@ -114,7 +114,7 @@ async def to_code(config):
txt = [ txt = [
cg.StructInitializer( cg.StructInitializer(
MDNSTXTRecord, MDNSTXTRecord,
("key", txt_key), ("key", cg.RawExpression(f"MDNS_STR({cg.safe_exp(txt_key)})")),
("value", await cg.templatable(txt_value, [], cg.std_string)), ("value", await cg.templatable(txt_value, [], cg.std_string)),
) )
for txt_key, txt_value in service[CONF_TXT].items() for txt_key, txt_value in service[CONF_TXT].items()
+20 -23
View File
@@ -9,24 +9,21 @@
#include <pgmspace.h> #include <pgmspace.h>
// Macro to define strings in PROGMEM on ESP8266, regular memory on other platforms // Macro to define strings in PROGMEM on ESP8266, regular memory on other platforms
#define MDNS_STATIC_CONST_CHAR(name, value) static const char name[] PROGMEM = value #define MDNS_STATIC_CONST_CHAR(name, value) static const char name[] PROGMEM = value
// Helper to get string from PROGMEM - returns a temporary std::string // Helper to convert PROGMEM string to std::string for TemplatableValue
// Only define this function if we have services that will use it // Only define this function if we have services that will use it
#if defined(USE_API) || defined(USE_PROMETHEUS) || defined(USE_WEBSERVER) || defined(USE_MDNS_EXTRA_SERVICES) #if defined(USE_API) || defined(USE_PROMETHEUS) || defined(USE_WEBSERVER) || defined(USE_MDNS_EXTRA_SERVICES)
static std::string mdns_string_p(const char *src) { static std::string mdns_str_value(PGM_P str) {
char buf[64]; char buf[64];
strncpy_P(buf, src, sizeof(buf) - 1); strncpy_P(buf, str, sizeof(buf) - 1);
buf[sizeof(buf) - 1] = '\0'; buf[sizeof(buf) - 1] = '\0';
return std::string(buf); return std::string(buf);
} }
#define MDNS_STR(name) mdns_string_p(name) #define MDNS_STR_VALUE(name) mdns_str_value(name)
#else
// If no services are configured, we still need the fallback service but it uses string literals
#define MDNS_STR(name) std::string(name)
#endif #endif
#else #else
// On non-ESP8266 platforms, use regular const char* // On non-ESP8266 platforms, use regular const char*
#define MDNS_STATIC_CONST_CHAR(name, value) static constexpr const char *name = value #define MDNS_STATIC_CONST_CHAR(name, value) static constexpr const char name[] = value
#define MDNS_STR(name) name #define MDNS_STR_VALUE(name) std::string(name)
#endif #endif
#ifdef USE_API #ifdef USE_API
@@ -118,31 +115,31 @@ void MDNSComponent::compile_records_() {
txt_records.push_back({MDNS_STR(TXT_MAC), get_mac_address()}); txt_records.push_back({MDNS_STR(TXT_MAC), get_mac_address()});
#ifdef USE_ESP8266 #ifdef USE_ESP8266
txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_ESP8266)}); txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_ESP8266)});
#elif defined(USE_ESP32) #elif defined(USE_ESP32)
txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_ESP32)}); txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_ESP32)});
#elif defined(USE_RP2040) #elif defined(USE_RP2040)
txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR(PLATFORM_RP2040)}); txt_records.push_back({MDNS_STR(TXT_PLATFORM), MDNS_STR_VALUE(PLATFORM_RP2040)});
#elif defined(USE_LIBRETINY) #elif defined(USE_LIBRETINY)
txt_records.emplace_back(MDNSTXTRecord{"platform", lt_cpu_get_model_name()}); txt_records.push_back({MDNS_STR(TXT_PLATFORM), lt_cpu_get_model_name()});
#endif #endif
txt_records.push_back({MDNS_STR(TXT_BOARD), ESPHOME_BOARD}); txt_records.push_back({MDNS_STR(TXT_BOARD), ESPHOME_BOARD});
#if defined(USE_WIFI) #if defined(USE_WIFI)
txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_WIFI)}); txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_WIFI)});
#elif defined(USE_ETHERNET) #elif defined(USE_ETHERNET)
txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_ETHERNET)}); txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_ETHERNET)});
#elif defined(USE_OPENTHREAD) #elif defined(USE_OPENTHREAD)
txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR(NETWORK_THREAD)}); txt_records.push_back({MDNS_STR(TXT_NETWORK), MDNS_STR_VALUE(NETWORK_THREAD)});
#endif #endif
#ifdef USE_API_NOISE #ifdef USE_API_NOISE
MDNS_STATIC_CONST_CHAR(NOISE_ENCRYPTION, "Noise_NNpsk0_25519_ChaChaPoly_SHA256"); MDNS_STATIC_CONST_CHAR(NOISE_ENCRYPTION, "Noise_NNpsk0_25519_ChaChaPoly_SHA256");
if (api::global_api_server->get_noise_ctx()->has_psk()) { if (api::global_api_server->get_noise_ctx()->has_psk()) {
txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION), MDNS_STR(NOISE_ENCRYPTION)}); txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION), MDNS_STR_VALUE(NOISE_ENCRYPTION)});
} else { } else {
txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION_SUPPORTED), MDNS_STR(NOISE_ENCRYPTION)}); txt_records.push_back({MDNS_STR(TXT_API_ENCRYPTION_SUPPORTED), MDNS_STR_VALUE(NOISE_ENCRYPTION)});
} }
#endif #endif
@@ -175,10 +172,10 @@ void MDNSComponent::compile_records_() {
// Publish "http" service if not using native API or any other services // Publish "http" service if not using native API or any other services
// This is just to have *some* mDNS service so that .local resolution works // This is just to have *some* mDNS service so that .local resolution works
auto &fallback_service = this->services_.emplace_next(); auto &fallback_service = this->services_.emplace_next();
fallback_service.service_type = "_http"; fallback_service.service_type = MDNS_STR(SERVICE_HTTP);
fallback_service.proto = "_tcp"; fallback_service.proto = MDNS_STR(SERVICE_TCP);
fallback_service.port = USE_WEBSERVER_PORT; fallback_service.port = USE_WEBSERVER_PORT;
fallback_service.txt_records.emplace_back(MDNSTXTRecord{"version", ESPHOME_VERSION}); fallback_service.txt_records.push_back({MDNS_STR(TXT_VERSION), ESPHOME_VERSION});
#endif #endif
} }
@@ -190,10 +187,10 @@ void MDNSComponent::dump_config() {
#if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE #if ESPHOME_LOG_LEVEL >= ESPHOME_LOG_LEVEL_VERBOSE
ESP_LOGV(TAG, " Services:"); ESP_LOGV(TAG, " Services:");
for (const auto &service : this->services_) { for (const auto &service : this->services_) {
ESP_LOGV(TAG, " - %s, %s, %d", service.service_type.c_str(), service.proto.c_str(), ESP_LOGV(TAG, " - %s, %s, %d", MDNS_STR_ARG(service.service_type), MDNS_STR_ARG(service.proto),
const_cast<TemplatableValue<uint16_t> &>(service.port).value()); const_cast<TemplatableValue<uint16_t> &>(service.port).value());
for (const auto &record : service.txt_records) { for (const auto &record : service.txt_records) {
ESP_LOGV(TAG, " TXT: %s = %s", record.key.c_str(), ESP_LOGV(TAG, " TXT: %s = %s", MDNS_STR_ARG(record.key),
const_cast<TemplatableValue<std::string> &>(record.value).value().c_str()); const_cast<TemplatableValue<std::string> &>(record.value).value().c_str());
} }
} }
+16 -3
View File
@@ -9,21 +9,34 @@
namespace esphome { namespace esphome {
namespace mdns { namespace mdns {
// Helper struct that identifies strings that may be stored in flash storage (similar to LogString)
struct MDNSString;
// Macro to cast string literals to MDNSString* (works on all platforms)
#define MDNS_STR(name) (reinterpret_cast<const esphome::mdns::MDNSString *>(name))
#ifdef USE_ESP8266
#include <pgmspace.h>
#define MDNS_STR_ARG(s) ((PGM_P) (s))
#else
#define MDNS_STR_ARG(s) (reinterpret_cast<const char *>(s))
#endif
// Service count is calculated at compile time by Python codegen // Service count is calculated at compile time by Python codegen
// MDNS_SERVICE_COUNT will always be defined // MDNS_SERVICE_COUNT will always be defined
struct MDNSTXTRecord { struct MDNSTXTRecord {
std::string key; const MDNSString *key;
TemplatableValue<std::string> value; TemplatableValue<std::string> value;
}; };
struct MDNSService { struct MDNSService {
// service name _including_ underscore character prefix // service name _including_ underscore character prefix
// as defined in RFC6763 Section 7 // as defined in RFC6763 Section 7
std::string service_type; const MDNSString *service_type;
// second label indicating protocol _including_ underscore character prefix // second label indicating protocol _including_ underscore character prefix
// as defined in RFC6763 Section 7, like "_tcp" or "_udp" // as defined in RFC6763 Section 7, like "_tcp" or "_udp"
std::string proto; const MDNSString *proto;
TemplatableValue<uint16_t> port; TemplatableValue<uint16_t> port;
std::vector<MDNSTXTRecord> txt_records; std::vector<MDNSTXTRecord> txt_records;
}; };
+7 -7
View File
@@ -29,23 +29,23 @@ void MDNSComponent::setup() {
std::vector<mdns_txt_item_t> txt_records; std::vector<mdns_txt_item_t> txt_records;
for (const auto &record : service.txt_records) { for (const auto &record : service.txt_records) {
mdns_txt_item_t it{}; mdns_txt_item_t it{};
// dup strings to ensure the pointer is valid even after the record loop // key is a compile-time string literal in flash, no need to strdup
it.key = strdup(record.key.c_str()); it.key = MDNS_STR_ARG(record.key);
// value is a temporary from TemplatableValue, must strdup to keep it alive
it.value = strdup(const_cast<TemplatableValue<std::string> &>(record.value).value().c_str()); it.value = strdup(const_cast<TemplatableValue<std::string> &>(record.value).value().c_str());
txt_records.push_back(it); txt_records.push_back(it);
} }
uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value(); uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value();
err = mdns_service_add(nullptr, service.service_type.c_str(), service.proto.c_str(), port, txt_records.data(), err = mdns_service_add(nullptr, MDNS_STR_ARG(service.service_type), MDNS_STR_ARG(service.proto), port,
txt_records.size()); txt_records.data(), txt_records.size());
// free records // free records
for (const auto &it : txt_records) { for (const auto &it : txt_records) {
delete it.key; // NOLINT(cppcoreguidelines-owning-memory) free((void *) it.value); // NOLINT(cppcoreguidelines-no-malloc)
delete it.value; // NOLINT(cppcoreguidelines-owning-memory)
} }
if (err != ESP_OK) { if (err != ESP_OK) {
ESP_LOGW(TAG, "Failed to register service %s: %s", service.service_type.c_str(), esp_err_to_name(err)); ESP_LOGW(TAG, "Failed to register service %s: %s", MDNS_STR_ARG(service.service_type), esp_err_to_name(err));
} }
} }
} }
+6 -6
View File
@@ -21,18 +21,18 @@ void MDNSComponent::setup() {
// part of the wire protocol to have an underscore, and for example ESP-IDF // part of the wire protocol to have an underscore, and for example ESP-IDF
// expects the underscore to be there, the ESP8266 implementation always adds // expects the underscore to be there, the ESP8266 implementation always adds
// the underscore itself. // the underscore itself.
auto *proto = service.proto.c_str(); auto *proto = MDNS_STR_ARG(service.proto);
while (*proto == '_') { while (progmem_read_byte((const uint8_t *) proto) == '_') {
proto++; proto++;
} }
auto *service_type = service.service_type.c_str(); auto *service_type = MDNS_STR_ARG(service.service_type);
while (*service_type == '_') { while (progmem_read_byte((const uint8_t *) service_type) == '_') {
service_type++; service_type++;
} }
uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value(); uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value();
MDNS.addService(service_type, proto, port); MDNS.addService(FPSTR(service_type), FPSTR(proto), port);
for (const auto &record : service.txt_records) { for (const auto &record : service.txt_records) {
MDNS.addServiceTxt(service_type, proto, record.key.c_str(), MDNS.addServiceTxt(FPSTR(service_type), FPSTR(proto), FPSTR(MDNS_STR_ARG(record.key)),
const_cast<TemplatableValue<std::string> &>(record.value).value().c_str()); const_cast<TemplatableValue<std::string> &>(record.value).value().c_str());
} }
} }
+3 -3
View File
@@ -21,18 +21,18 @@ void MDNSComponent::setup() {
// part of the wire protocol to have an underscore, and for example ESP-IDF // part of the wire protocol to have an underscore, and for example ESP-IDF
// expects the underscore to be there, the ESP8266 implementation always adds // expects the underscore to be there, the ESP8266 implementation always adds
// the underscore itself. // the underscore itself.
auto *proto = service.proto.c_str(); auto *proto = MDNS_STR_ARG(service.proto);
while (*proto == '_') { while (*proto == '_') {
proto++; proto++;
} }
auto *service_type = service.service_type.c_str(); auto *service_type = MDNS_STR_ARG(service.service_type);
while (*service_type == '_') { while (*service_type == '_') {
service_type++; service_type++;
} }
uint16_t port_ = const_cast<TemplatableValue<uint16_t> &>(service.port).value(); uint16_t port_ = const_cast<TemplatableValue<uint16_t> &>(service.port).value();
MDNS.addService(service_type, proto, port_); MDNS.addService(service_type, proto, port_);
for (const auto &record : service.txt_records) { for (const auto &record : service.txt_records) {
MDNS.addServiceTxt(service_type, proto, record.key.c_str(), MDNS.addServiceTxt(service_type, proto, MDNS_STR_ARG(record.key),
const_cast<TemplatableValue<std::string> &>(record.value).value().c_str()); const_cast<TemplatableValue<std::string> &>(record.value).value().c_str());
} }
} }
+3 -3
View File
@@ -21,18 +21,18 @@ void MDNSComponent::setup() {
// part of the wire protocol to have an underscore, and for example ESP-IDF // part of the wire protocol to have an underscore, and for example ESP-IDF
// expects the underscore to be there, the ESP8266 implementation always adds // expects the underscore to be there, the ESP8266 implementation always adds
// the underscore itself. // the underscore itself.
auto *proto = service.proto.c_str(); auto *proto = MDNS_STR_ARG(service.proto);
while (*proto == '_') { while (*proto == '_') {
proto++; proto++;
} }
auto *service_type = service.service_type.c_str(); auto *service_type = MDNS_STR_ARG(service.service_type);
while (*service_type == '_') { while (*service_type == '_') {
service_type++; service_type++;
} }
uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value(); uint16_t port = const_cast<TemplatableValue<uint16_t> &>(service.port).value();
MDNS.addService(service_type, proto, port); MDNS.addService(service_type, proto, port);
for (const auto &record : service.txt_records) { for (const auto &record : service.txt_records) {
MDNS.addServiceTxt(service_type, proto, record.key.c_str(), MDNS.addServiceTxt(service_type, proto, MDNS_STR_ARG(record.key),
const_cast<TemplatableValue<std::string> &>(record.value).value().c_str()); const_cast<TemplatableValue<std::string> &>(record.value).value().c_str());
} }
} }
+2 -2
View File
@@ -155,7 +155,7 @@ void OpenThreadSrpComponent::setup() {
// Set service name // Set service name
char *string = otSrpClientBuffersGetServiceEntryServiceNameString(entry, &size); char *string = otSrpClientBuffersGetServiceEntryServiceNameString(entry, &size);
std::string full_service = service.service_type + "." + service.proto; std::string full_service = std::string(MDNS_STR_ARG(service.service_type)) + "." + MDNS_STR_ARG(service.proto);
if (full_service.size() > size) { if (full_service.size() > size) {
ESP_LOGW(TAG, "Service name too long: %s", full_service.c_str()); ESP_LOGW(TAG, "Service name too long: %s", full_service.c_str());
continue; continue;
@@ -181,7 +181,7 @@ void OpenThreadSrpComponent::setup() {
for (size_t i = 0; i < service.txt_records.size(); i++) { for (size_t i = 0; i < service.txt_records.size(); i++) {
const auto &txt = service.txt_records[i]; const auto &txt = service.txt_records[i];
auto value = const_cast<TemplatableValue<std::string> &>(txt.value).value(); auto value = const_cast<TemplatableValue<std::string> &>(txt.value).value();
txt_entries[i].mKey = strdup(txt.key.c_str()); txt_entries[i].mKey = MDNS_STR_ARG(txt.key);
txt_entries[i].mValue = reinterpret_cast<const uint8_t *>(strdup(value.c_str())); txt_entries[i].mValue = reinterpret_cast<const uint8_t *>(strdup(value.c_str()));
txt_entries[i].mValueLength = value.size(); txt_entries[i].mValueLength = value.size();
} }