[key_collector] Add text sensor and allow multiple callbacks (#13617)

This commit is contained in:
Clyde Stubbs
2026-02-03 21:14:09 +11:00
committed by GitHub
parent d4110bf650
commit b3e09e5c68
5 changed files with 146 additions and 68 deletions

View File

@@ -1,4 +1,7 @@
from dataclasses import dataclass
from esphome import automation
from esphome.automation import Trigger
import esphome.codegen as cg
from esphome.components import key_provider
import esphome.config_validation as cv
@@ -10,7 +13,10 @@ from esphome.const import (
CONF_ON_TIMEOUT,
CONF_SOURCE_ID,
CONF_TIMEOUT,
CONF_TRIGGER_ID,
)
from esphome.cpp_generator import MockObj, literal
from esphome.types import TemplateArgsType
CODEOWNERS = ["@ssieb"]
@@ -32,22 +38,50 @@ KeyCollector = key_collector_ns.class_("KeyCollector", cg.Component)
EnableAction = key_collector_ns.class_("EnableAction", automation.Action)
DisableAction = key_collector_ns.class_("DisableAction", automation.Action)
X_TYPE = cg.std_string_ref.operator("const")
@dataclass
class Argument:
type: MockObj
name: str
TRIGGER_TYPES = {
CONF_ON_PROGRESS: [Argument(X_TYPE, "x"), Argument(cg.uint8, "start")],
CONF_ON_RESULT: [
Argument(X_TYPE, "x"),
Argument(cg.uint8, "start"),
Argument(cg.uint8, "end"),
],
CONF_ON_TIMEOUT: [Argument(X_TYPE, "x"), Argument(cg.uint8, "start")],
}
CONFIG_SCHEMA = cv.All(
cv.COMPONENT_SCHEMA.extend(
{
cv.GenerateID(): cv.declare_id(KeyCollector),
cv.Optional(CONF_SOURCE_ID): cv.use_id(key_provider.KeyProvider),
cv.Optional(CONF_MIN_LENGTH): cv.int_,
cv.Optional(CONF_MAX_LENGTH): cv.int_,
cv.Optional(CONF_SOURCE_ID): cv.ensure_list(
cv.use_id(key_provider.KeyProvider)
),
cv.Optional(CONF_MIN_LENGTH): cv.uint16_t,
cv.Optional(CONF_MAX_LENGTH): cv.uint16_t,
cv.Optional(CONF_START_KEYS): cv.string,
cv.Optional(CONF_END_KEYS): cv.string,
cv.Optional(CONF_END_KEY_REQUIRED): cv.boolean,
cv.Optional(CONF_BACK_KEYS): cv.string,
cv.Optional(CONF_CLEAR_KEYS): cv.string,
cv.Optional(CONF_ALLOWED_KEYS): cv.string,
cv.Optional(CONF_ON_PROGRESS): automation.validate_automation(single=True),
cv.Optional(CONF_ON_RESULT): automation.validate_automation(single=True),
cv.Optional(CONF_ON_TIMEOUT): automation.validate_automation(single=True),
**{
cv.Optional(trigger_type): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
Trigger.template(*[arg.type for arg in args])
),
}
)
for trigger_type, args in TRIGGER_TYPES.items()
},
cv.Optional(CONF_TIMEOUT): cv.positive_time_period_milliseconds,
cv.Optional(CONF_ENABLE_ON_BOOT, default=True): cv.boolean,
}
@@ -59,9 +93,9 @@ CONFIG_SCHEMA = cv.All(
async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID])
await cg.register_component(var, config)
if CONF_SOURCE_ID in config:
source = await cg.get_variable(config[CONF_SOURCE_ID])
cg.add(var.set_provider(source))
for source_conf in config.get(CONF_SOURCE_ID, ()):
source = await cg.get_variable(source_conf)
cg.add(var.add_provider(source))
if CONF_MIN_LENGTH in config:
cg.add(var.set_min_length(config[CONF_MIN_LENGTH]))
if CONF_MAX_LENGTH in config:
@@ -78,26 +112,25 @@ async def to_code(config):
cg.add(var.set_clear_keys(config[CONF_CLEAR_KEYS]))
if CONF_ALLOWED_KEYS in config:
cg.add(var.set_allowed_keys(config[CONF_ALLOWED_KEYS]))
if CONF_ON_PROGRESS in config:
await automation.build_automation(
var.get_progress_trigger(),
[(cg.std_string, "x"), (cg.uint8, "start")],
config[CONF_ON_PROGRESS],
)
if CONF_ON_RESULT in config:
await automation.build_automation(
var.get_result_trigger(),
[(cg.std_string, "x"), (cg.uint8, "start"), (cg.uint8, "end")],
config[CONF_ON_RESULT],
)
if CONF_ON_TIMEOUT in config:
await automation.build_automation(
var.get_timeout_trigger(),
[(cg.std_string, "x"), (cg.uint8, "start")],
config[CONF_ON_TIMEOUT],
)
if CONF_TIMEOUT in config:
cg.add(var.set_timeout(config[CONF_TIMEOUT]))
for trigger_name, args in TRIGGER_TYPES.items():
arglist: TemplateArgsType = [(arg.type, arg.name) for arg in args]
for conf in config.get(trigger_name, ()):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID])
add_trig = getattr(
var,
f"add_on_{trigger_name.rsplit('_', maxsplit=1)[-1].lower()}_callback",
)
await automation.build_automation(
trigger,
arglist,
conf,
)
lamb = trigger.trigger(*[literal(arg.name) for arg in args])
cg.add(add_trig(await cg.process_lambda(lamb, arglist)))
if timeout := config.get(CONF_TIMEOUT):
cg.add(var.set_timeout(timeout))
cg.add(var.set_enabled(config[CONF_ENABLE_ON_BOOT]))

View File

@@ -7,15 +7,10 @@ namespace key_collector {
static const char *const TAG = "key_collector";
KeyCollector::KeyCollector()
: progress_trigger_(new Trigger<std::string, uint8_t>()),
result_trigger_(new Trigger<std::string, uint8_t, uint8_t>()),
timeout_trigger_(new Trigger<std::string, uint8_t>()) {}
void KeyCollector::loop() {
if ((this->timeout_ == 0) || this->result_.empty() || (millis() - this->last_key_time_ < this->timeout_))
return;
this->timeout_trigger_->trigger(this->result_, this->start_key_);
this->timeout_callbacks_.call(this->result_, this->start_key_);
this->clear();
}
@@ -43,64 +38,68 @@ void KeyCollector::dump_config() {
ESP_LOGCONFIG(TAG, " entry timeout: %0.1f", this->timeout_ / 1000.0);
}
void KeyCollector::set_provider(key_provider::KeyProvider *provider) {
provider->add_on_key_callback([this](uint8_t key) { this->key_pressed_(key); });
void KeyCollector::add_provider(key_provider::KeyProvider *provider) {
provider->add_on_key_callback([this](uint8_t key) { this->send_key(key); });
}
void KeyCollector::set_enabled(bool enabled) {
this->enabled_ = enabled;
if (!enabled)
if (!enabled) {
this->clear(false);
}
}
void KeyCollector::clear(bool progress_update) {
auto had_state = !this->result_.empty() || this->start_key_ != 0;
this->result_.clear();
this->start_key_ = 0;
if (progress_update)
this->progress_trigger_->trigger(this->result_, 0);
if (progress_update && had_state) {
this->progress_callbacks_.call(this->result_, 0);
}
this->disable_loop();
}
void KeyCollector::send_key(uint8_t key) { this->key_pressed_(key); }
void KeyCollector::key_pressed_(uint8_t key) {
void KeyCollector::send_key(uint8_t key) {
if (!this->enabled_)
return;
this->last_key_time_ = millis();
if (!this->start_keys_.empty() && !this->start_key_) {
if (this->start_keys_.find(key) != std::string::npos) {
this->start_key_ = key;
this->progress_trigger_->trigger(this->result_, this->start_key_);
this->progress_callbacks_.call(this->result_, this->start_key_);
}
return;
}
if (this->back_keys_.find(key) != std::string::npos) {
if (!this->result_.empty()) {
this->result_.pop_back();
this->progress_trigger_->trigger(this->result_, this->start_key_);
this->progress_callbacks_.call(this->result_, this->start_key_);
}
return;
}
if (this->clear_keys_.find(key) != std::string::npos) {
if (!this->result_.empty())
this->clear();
this->clear();
return;
}
if (this->end_keys_.find(key) != std::string::npos) {
if ((this->min_length_ == 0) || (this->result_.size() >= this->min_length_)) {
this->result_trigger_->trigger(this->result_, this->start_key_, key);
this->result_callbacks_.call(this->result_, this->start_key_, key);
this->clear();
}
return;
}
if (!this->allowed_keys_.empty() && (this->allowed_keys_.find(key) == std::string::npos))
if (!this->allowed_keys_.empty() && this->allowed_keys_.find(key) == std::string::npos)
return;
if ((this->max_length_ == 0) || (this->result_.size() < this->max_length_))
if ((this->max_length_ == 0) || (this->result_.size() < this->max_length_)) {
if (this->result_.empty())
this->enable_loop();
this->result_.push_back(key);
}
if ((this->max_length_ > 0) && (this->result_.size() == this->max_length_) && (!this->end_key_required_)) {
this->result_trigger_->trigger(this->result_, this->start_key_, 0);
this->result_callbacks_.call(this->result_, this->start_key_, 0);
this->clear(false);
}
this->progress_trigger_->trigger(this->result_, this->start_key_);
this->progress_callbacks_.call(this->result_, this->start_key_);
}
} // namespace key_collector

View File

@@ -3,27 +3,33 @@
#include <utility>
#include "esphome/components/key_provider/key_provider.h"
#include "esphome/core/automation.h"
#include "esphome/core/helpers.h"
namespace esphome {
namespace key_collector {
class KeyCollector : public Component {
public:
KeyCollector();
void loop() override;
void dump_config() override;
void set_provider(key_provider::KeyProvider *provider);
void set_min_length(uint32_t min_length) { this->min_length_ = min_length; };
void set_max_length(uint32_t max_length) { this->max_length_ = max_length; };
void add_provider(key_provider::KeyProvider *provider);
void set_min_length(uint16_t min_length) { this->min_length_ = min_length; };
void set_max_length(uint16_t max_length) { this->max_length_ = max_length; };
void set_start_keys(std::string start_keys) { this->start_keys_ = std::move(start_keys); };
void set_end_keys(std::string end_keys) { this->end_keys_ = std::move(end_keys); };
void set_end_key_required(bool end_key_required) { this->end_key_required_ = end_key_required; };
void set_back_keys(std::string back_keys) { this->back_keys_ = std::move(back_keys); };
void set_clear_keys(std::string clear_keys) { this->clear_keys_ = std::move(clear_keys); };
void set_allowed_keys(std::string allowed_keys) { this->allowed_keys_ = std::move(allowed_keys); };
Trigger<std::string, uint8_t> *get_progress_trigger() const { return this->progress_trigger_; };
Trigger<std::string, uint8_t, uint8_t> *get_result_trigger() const { return this->result_trigger_; };
Trigger<std::string, uint8_t> *get_timeout_trigger() const { return this->timeout_trigger_; };
void add_on_progress_callback(std::function<void(const std::string &, uint8_t)> &&callback) {
this->progress_callbacks_.add(std::move(callback));
}
void add_on_result_callback(std::function<void(const std::string &, uint8_t, uint8_t)> &&callback) {
this->result_callbacks_.add(std::move(callback));
}
void add_on_timeout_callback(std::function<void(const std::string &, uint8_t)> &&callback) {
this->timeout_callbacks_.add(std::move(callback));
}
void set_timeout(int timeout) { this->timeout_ = timeout; };
void set_enabled(bool enabled);
@@ -31,10 +37,8 @@ class KeyCollector : public Component {
void send_key(uint8_t key);
protected:
void key_pressed_(uint8_t key);
uint32_t min_length_{0};
uint32_t max_length_{0};
uint16_t min_length_{0};
uint16_t max_length_{0};
std::string start_keys_;
std::string end_keys_;
bool end_key_required_{false};
@@ -43,12 +47,12 @@ class KeyCollector : public Component {
std::string allowed_keys_;
std::string result_;
uint8_t start_key_{0};
Trigger<std::string, uint8_t> *progress_trigger_;
Trigger<std::string, uint8_t, uint8_t> *result_trigger_;
Trigger<std::string, uint8_t> *timeout_trigger_;
uint32_t last_key_time_;
LazyCallbackManager<void(const std::string &, uint8_t)> progress_callbacks_;
LazyCallbackManager<void(const std::string &, uint8_t, uint8_t)> result_callbacks_;
LazyCallbackManager<void(const std::string &, uint8_t)> timeout_callbacks_;
uint32_t last_key_time_{};
uint32_t timeout_{0};
bool enabled_;
bool enabled_{};
};
template<typename... Ts> class EnableAction : public Action<Ts...>, public Parented<KeyCollector> {

View File

@@ -0,0 +1,28 @@
import esphome.codegen as cg
from esphome.components import text_sensor
from esphome.components.text_sensor import TextSensor
import esphome.config_validation as cv
from esphome.const import CONF_ID
from esphome.cpp_generator import literal
from esphome.types import TemplateArgsType
from .. import CONF_ON_RESULT, CONF_SOURCE_ID, TRIGGER_TYPES, KeyCollector
CONFIG_SCHEMA = text_sensor.text_sensor_schema(TextSensor).extend(
{
cv.GenerateID(CONF_SOURCE_ID): cv.use_id(KeyCollector),
}
)
async def to_code(config):
parent = await cg.get_variable(config[CONF_SOURCE_ID])
var = cg.new_Pvariable(config[CONF_ID])
await text_sensor.register_text_sensor(var, config)
args = TRIGGER_TYPES[CONF_ON_RESULT]
arglist: TemplateArgsType = [(arg.type, arg.name) for arg in args]
cg.add(
parent.add_on_result_callback(
await cg.process_lambda(var.publish_state(literal(args[0].name)), arglist)
)
)

View File

@@ -18,14 +18,23 @@ key_collector:
- logger.log:
format: "input progress: '%s', started by '%c'"
args: ['x.c_str()', "(start == 0 ? '~' : start)"]
- logger.log:
format: "second listener - progress: '%s'"
args: ['x.c_str()']
on_result:
- logger.log:
format: "input result: '%s', started by '%c', ended by '%c'"
args: ['x.c_str()', "(start == 0 ? '~' : start)", "(end == 0 ? '~' : end)"]
- logger.log:
format: "second listener - result: '%s'"
args: ['x.c_str()']
on_timeout:
- logger.log:
format: "input timeout: '%s', started by '%c'"
args: ['x.c_str()', "(start == 0 ? '~' : start)"]
- logger.log:
format: "second listener - timeout: '%s'"
args: ['x.c_str()']
enable_on_boot: false
button:
@@ -34,3 +43,8 @@ button:
on_press:
- key_collector.enable:
- key_collector.disable:
text_sensor:
- platform: key_collector
id: collected_keys
source_id: reader