From b3e09e5c68b4734bb8cd896ce8aedd00523e1a02 Mon Sep 17 00:00:00 2001 From: Clyde Stubbs <2366188+clydebarrow@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:14:09 +1100 Subject: [PATCH] [key_collector] Add text sensor and allow multiple callbacks (#13617) --- esphome/components/key_collector/__init__.py | 91 +++++++++++++------ .../key_collector/key_collector.cpp | 45 +++++---- .../components/key_collector/key_collector.h | 36 ++++---- .../key_collector/text_sensor/__init__.py | 28 ++++++ tests/components/key_collector/common.yaml | 14 +++ 5 files changed, 146 insertions(+), 68 deletions(-) create mode 100644 esphome/components/key_collector/text_sensor/__init__.py diff --git a/esphome/components/key_collector/__init__.py b/esphome/components/key_collector/__init__.py index 17af40da1a..badb28c32c 100644 --- a/esphome/components/key_collector/__init__.py +++ b/esphome/components/key_collector/__init__.py @@ -1,4 +1,7 @@ +from dataclasses import dataclass + from esphome import automation +from esphome.automation import Trigger import esphome.codegen as cg from esphome.components import key_provider import esphome.config_validation as cv @@ -10,7 +13,10 @@ from esphome.const import ( CONF_ON_TIMEOUT, CONF_SOURCE_ID, CONF_TIMEOUT, + CONF_TRIGGER_ID, ) +from esphome.cpp_generator import MockObj, literal +from esphome.types import TemplateArgsType CODEOWNERS = ["@ssieb"] @@ -32,22 +38,50 @@ KeyCollector = key_collector_ns.class_("KeyCollector", cg.Component) EnableAction = key_collector_ns.class_("EnableAction", automation.Action) DisableAction = key_collector_ns.class_("DisableAction", automation.Action) +X_TYPE = cg.std_string_ref.operator("const") + + +@dataclass +class Argument: + type: MockObj + name: str + + +TRIGGER_TYPES = { + CONF_ON_PROGRESS: [Argument(X_TYPE, "x"), Argument(cg.uint8, "start")], + CONF_ON_RESULT: [ + Argument(X_TYPE, "x"), + Argument(cg.uint8, "start"), + Argument(cg.uint8, "end"), + ], + CONF_ON_TIMEOUT: [Argument(X_TYPE, "x"), Argument(cg.uint8, "start")], +} + CONFIG_SCHEMA = cv.All( cv.COMPONENT_SCHEMA.extend( { cv.GenerateID(): cv.declare_id(KeyCollector), - cv.Optional(CONF_SOURCE_ID): cv.use_id(key_provider.KeyProvider), - cv.Optional(CONF_MIN_LENGTH): cv.int_, - cv.Optional(CONF_MAX_LENGTH): cv.int_, + cv.Optional(CONF_SOURCE_ID): cv.ensure_list( + cv.use_id(key_provider.KeyProvider) + ), + cv.Optional(CONF_MIN_LENGTH): cv.uint16_t, + cv.Optional(CONF_MAX_LENGTH): cv.uint16_t, cv.Optional(CONF_START_KEYS): cv.string, cv.Optional(CONF_END_KEYS): cv.string, cv.Optional(CONF_END_KEY_REQUIRED): cv.boolean, cv.Optional(CONF_BACK_KEYS): cv.string, cv.Optional(CONF_CLEAR_KEYS): cv.string, cv.Optional(CONF_ALLOWED_KEYS): cv.string, - cv.Optional(CONF_ON_PROGRESS): automation.validate_automation(single=True), - cv.Optional(CONF_ON_RESULT): automation.validate_automation(single=True), - cv.Optional(CONF_ON_TIMEOUT): automation.validate_automation(single=True), + **{ + cv.Optional(trigger_type): automation.validate_automation( + { + cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id( + Trigger.template(*[arg.type for arg in args]) + ), + } + ) + for trigger_type, args in TRIGGER_TYPES.items() + }, cv.Optional(CONF_TIMEOUT): cv.positive_time_period_milliseconds, cv.Optional(CONF_ENABLE_ON_BOOT, default=True): cv.boolean, } @@ -59,9 +93,9 @@ CONFIG_SCHEMA = cv.All( async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) - if CONF_SOURCE_ID in config: - source = await cg.get_variable(config[CONF_SOURCE_ID]) - cg.add(var.set_provider(source)) + for source_conf in config.get(CONF_SOURCE_ID, ()): + source = await cg.get_variable(source_conf) + cg.add(var.add_provider(source)) if CONF_MIN_LENGTH in config: cg.add(var.set_min_length(config[CONF_MIN_LENGTH])) if CONF_MAX_LENGTH in config: @@ -78,26 +112,25 @@ async def to_code(config): cg.add(var.set_clear_keys(config[CONF_CLEAR_KEYS])) if CONF_ALLOWED_KEYS in config: cg.add(var.set_allowed_keys(config[CONF_ALLOWED_KEYS])) - if CONF_ON_PROGRESS in config: - await automation.build_automation( - var.get_progress_trigger(), - [(cg.std_string, "x"), (cg.uint8, "start")], - config[CONF_ON_PROGRESS], - ) - if CONF_ON_RESULT in config: - await automation.build_automation( - var.get_result_trigger(), - [(cg.std_string, "x"), (cg.uint8, "start"), (cg.uint8, "end")], - config[CONF_ON_RESULT], - ) - if CONF_ON_TIMEOUT in config: - await automation.build_automation( - var.get_timeout_trigger(), - [(cg.std_string, "x"), (cg.uint8, "start")], - config[CONF_ON_TIMEOUT], - ) - if CONF_TIMEOUT in config: - cg.add(var.set_timeout(config[CONF_TIMEOUT])) + + for trigger_name, args in TRIGGER_TYPES.items(): + arglist: TemplateArgsType = [(arg.type, arg.name) for arg in args] + for conf in config.get(trigger_name, ()): + trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) + add_trig = getattr( + var, + f"add_on_{trigger_name.rsplit('_', maxsplit=1)[-1].lower()}_callback", + ) + await automation.build_automation( + trigger, + arglist, + conf, + ) + lamb = trigger.trigger(*[literal(arg.name) for arg in args]) + cg.add(add_trig(await cg.process_lambda(lamb, arglist))) + + if timeout := config.get(CONF_TIMEOUT): + cg.add(var.set_timeout(timeout)) cg.add(var.set_enabled(config[CONF_ENABLE_ON_BOOT])) diff --git a/esphome/components/key_collector/key_collector.cpp b/esphome/components/key_collector/key_collector.cpp index 9cfc74f50e..68d1c60bf9 100644 --- a/esphome/components/key_collector/key_collector.cpp +++ b/esphome/components/key_collector/key_collector.cpp @@ -7,15 +7,10 @@ namespace key_collector { static const char *const TAG = "key_collector"; -KeyCollector::KeyCollector() - : progress_trigger_(new Trigger()), - result_trigger_(new Trigger()), - timeout_trigger_(new Trigger()) {} - void KeyCollector::loop() { if ((this->timeout_ == 0) || this->result_.empty() || (millis() - this->last_key_time_ < this->timeout_)) return; - this->timeout_trigger_->trigger(this->result_, this->start_key_); + this->timeout_callbacks_.call(this->result_, this->start_key_); this->clear(); } @@ -43,64 +38,68 @@ void KeyCollector::dump_config() { ESP_LOGCONFIG(TAG, " entry timeout: %0.1f", this->timeout_ / 1000.0); } -void KeyCollector::set_provider(key_provider::KeyProvider *provider) { - provider->add_on_key_callback([this](uint8_t key) { this->key_pressed_(key); }); +void KeyCollector::add_provider(key_provider::KeyProvider *provider) { + provider->add_on_key_callback([this](uint8_t key) { this->send_key(key); }); } void KeyCollector::set_enabled(bool enabled) { this->enabled_ = enabled; - if (!enabled) + if (!enabled) { this->clear(false); + } } void KeyCollector::clear(bool progress_update) { + auto had_state = !this->result_.empty() || this->start_key_ != 0; this->result_.clear(); this->start_key_ = 0; - if (progress_update) - this->progress_trigger_->trigger(this->result_, 0); + if (progress_update && had_state) { + this->progress_callbacks_.call(this->result_, 0); + } + this->disable_loop(); } -void KeyCollector::send_key(uint8_t key) { this->key_pressed_(key); } - -void KeyCollector::key_pressed_(uint8_t key) { +void KeyCollector::send_key(uint8_t key) { if (!this->enabled_) return; this->last_key_time_ = millis(); if (!this->start_keys_.empty() && !this->start_key_) { if (this->start_keys_.find(key) != std::string::npos) { this->start_key_ = key; - this->progress_trigger_->trigger(this->result_, this->start_key_); + this->progress_callbacks_.call(this->result_, this->start_key_); } return; } if (this->back_keys_.find(key) != std::string::npos) { if (!this->result_.empty()) { this->result_.pop_back(); - this->progress_trigger_->trigger(this->result_, this->start_key_); + this->progress_callbacks_.call(this->result_, this->start_key_); } return; } if (this->clear_keys_.find(key) != std::string::npos) { - if (!this->result_.empty()) - this->clear(); + this->clear(); return; } if (this->end_keys_.find(key) != std::string::npos) { if ((this->min_length_ == 0) || (this->result_.size() >= this->min_length_)) { - this->result_trigger_->trigger(this->result_, this->start_key_, key); + this->result_callbacks_.call(this->result_, this->start_key_, key); this->clear(); } return; } - if (!this->allowed_keys_.empty() && (this->allowed_keys_.find(key) == std::string::npos)) + if (!this->allowed_keys_.empty() && this->allowed_keys_.find(key) == std::string::npos) return; - if ((this->max_length_ == 0) || (this->result_.size() < this->max_length_)) + if ((this->max_length_ == 0) || (this->result_.size() < this->max_length_)) { + if (this->result_.empty()) + this->enable_loop(); this->result_.push_back(key); + } if ((this->max_length_ > 0) && (this->result_.size() == this->max_length_) && (!this->end_key_required_)) { - this->result_trigger_->trigger(this->result_, this->start_key_, 0); + this->result_callbacks_.call(this->result_, this->start_key_, 0); this->clear(false); } - this->progress_trigger_->trigger(this->result_, this->start_key_); + this->progress_callbacks_.call(this->result_, this->start_key_); } } // namespace key_collector diff --git a/esphome/components/key_collector/key_collector.h b/esphome/components/key_collector/key_collector.h index 735f396809..8e30c333df 100644 --- a/esphome/components/key_collector/key_collector.h +++ b/esphome/components/key_collector/key_collector.h @@ -3,27 +3,33 @@ #include #include "esphome/components/key_provider/key_provider.h" #include "esphome/core/automation.h" +#include "esphome/core/helpers.h" namespace esphome { namespace key_collector { class KeyCollector : public Component { public: - KeyCollector(); void loop() override; void dump_config() override; - void set_provider(key_provider::KeyProvider *provider); - void set_min_length(uint32_t min_length) { this->min_length_ = min_length; }; - void set_max_length(uint32_t max_length) { this->max_length_ = max_length; }; + void add_provider(key_provider::KeyProvider *provider); + void set_min_length(uint16_t min_length) { this->min_length_ = min_length; }; + void set_max_length(uint16_t max_length) { this->max_length_ = max_length; }; void set_start_keys(std::string start_keys) { this->start_keys_ = std::move(start_keys); }; void set_end_keys(std::string end_keys) { this->end_keys_ = std::move(end_keys); }; void set_end_key_required(bool end_key_required) { this->end_key_required_ = end_key_required; }; void set_back_keys(std::string back_keys) { this->back_keys_ = std::move(back_keys); }; void set_clear_keys(std::string clear_keys) { this->clear_keys_ = std::move(clear_keys); }; void set_allowed_keys(std::string allowed_keys) { this->allowed_keys_ = std::move(allowed_keys); }; - Trigger *get_progress_trigger() const { return this->progress_trigger_; }; - Trigger *get_result_trigger() const { return this->result_trigger_; }; - Trigger *get_timeout_trigger() const { return this->timeout_trigger_; }; + void add_on_progress_callback(std::function &&callback) { + this->progress_callbacks_.add(std::move(callback)); + } + void add_on_result_callback(std::function &&callback) { + this->result_callbacks_.add(std::move(callback)); + } + void add_on_timeout_callback(std::function &&callback) { + this->timeout_callbacks_.add(std::move(callback)); + } void set_timeout(int timeout) { this->timeout_ = timeout; }; void set_enabled(bool enabled); @@ -31,10 +37,8 @@ class KeyCollector : public Component { void send_key(uint8_t key); protected: - void key_pressed_(uint8_t key); - - uint32_t min_length_{0}; - uint32_t max_length_{0}; + uint16_t min_length_{0}; + uint16_t max_length_{0}; std::string start_keys_; std::string end_keys_; bool end_key_required_{false}; @@ -43,12 +47,12 @@ class KeyCollector : public Component { std::string allowed_keys_; std::string result_; uint8_t start_key_{0}; - Trigger *progress_trigger_; - Trigger *result_trigger_; - Trigger *timeout_trigger_; - uint32_t last_key_time_; + LazyCallbackManager progress_callbacks_; + LazyCallbackManager result_callbacks_; + LazyCallbackManager timeout_callbacks_; + uint32_t last_key_time_{}; uint32_t timeout_{0}; - bool enabled_; + bool enabled_{}; }; template class EnableAction : public Action, public Parented { diff --git a/esphome/components/key_collector/text_sensor/__init__.py b/esphome/components/key_collector/text_sensor/__init__.py new file mode 100644 index 0000000000..1676cf7bdf --- /dev/null +++ b/esphome/components/key_collector/text_sensor/__init__.py @@ -0,0 +1,28 @@ +import esphome.codegen as cg +from esphome.components import text_sensor +from esphome.components.text_sensor import TextSensor +import esphome.config_validation as cv +from esphome.const import CONF_ID +from esphome.cpp_generator import literal +from esphome.types import TemplateArgsType + +from .. import CONF_ON_RESULT, CONF_SOURCE_ID, TRIGGER_TYPES, KeyCollector + +CONFIG_SCHEMA = text_sensor.text_sensor_schema(TextSensor).extend( + { + cv.GenerateID(CONF_SOURCE_ID): cv.use_id(KeyCollector), + } +) + + +async def to_code(config): + parent = await cg.get_variable(config[CONF_SOURCE_ID]) + var = cg.new_Pvariable(config[CONF_ID]) + await text_sensor.register_text_sensor(var, config) + args = TRIGGER_TYPES[CONF_ON_RESULT] + arglist: TemplateArgsType = [(arg.type, arg.name) for arg in args] + cg.add( + parent.add_on_result_callback( + await cg.process_lambda(var.publish_state(literal(args[0].name)), arglist) + ) + ) diff --git a/tests/components/key_collector/common.yaml b/tests/components/key_collector/common.yaml index 12e541c865..43a0478a18 100644 --- a/tests/components/key_collector/common.yaml +++ b/tests/components/key_collector/common.yaml @@ -18,14 +18,23 @@ key_collector: - logger.log: format: "input progress: '%s', started by '%c'" args: ['x.c_str()', "(start == 0 ? '~' : start)"] + - logger.log: + format: "second listener - progress: '%s'" + args: ['x.c_str()'] on_result: - logger.log: format: "input result: '%s', started by '%c', ended by '%c'" args: ['x.c_str()', "(start == 0 ? '~' : start)", "(end == 0 ? '~' : end)"] + - logger.log: + format: "second listener - result: '%s'" + args: ['x.c_str()'] on_timeout: - logger.log: format: "input timeout: '%s', started by '%c'" args: ['x.c_str()', "(start == 0 ? '~' : start)"] + - logger.log: + format: "second listener - timeout: '%s'" + args: ['x.c_str()'] enable_on_boot: false button: @@ -34,3 +43,8 @@ button: on_press: - key_collector.enable: - key_collector.disable: + +text_sensor: + - platform: key_collector + id: collected_keys + source_id: reader