[automation] Eliminate trigger trampolines with deduplicated forwarder structs (#15174)

This commit is contained in:
J. Nick Koston
2026-03-26 13:50:50 -10:00
committed by GitHub
parent 6aafb521c1
commit fa8a609bcc
10 changed files with 226 additions and 166 deletions
+44
View File
@@ -137,6 +137,9 @@ UpdateComponentAction = cg.esphome_ns.class_("UpdateComponentAction", Action)
SuspendComponentAction = cg.esphome_ns.class_("SuspendComponentAction", Action)
ResumeComponentAction = cg.esphome_ns.class_("ResumeComponentAction", Action)
Automation = cg.esphome_ns.class_("Automation")
TriggerForwarder = cg.esphome_ns.class_("TriggerForwarder")
TriggerOnTrueForwarder = cg.esphome_ns.class_("TriggerOnTrueForwarder")
TriggerOnFalseForwarder = cg.esphome_ns.class_("TriggerOnFalseForwarder")
LambdaCondition = cg.esphome_ns.class_("LambdaCondition", Condition)
StatelessLambdaCondition = cg.esphome_ns.class_("StatelessLambdaCondition", Condition)
@@ -661,3 +664,44 @@ async def build_automation(
actions = await build_action_list(config[CONF_THEN], templ, args)
cg.add(obj.add_actions(actions))
return obj
async def build_callback_automation(
parent: MockObj,
callback_method: str,
args: TemplateArgsType,
config: ConfigType,
forwarder: MockObj | MockObjClass | None = None,
) -> None:
"""Build an Automation and register it as a callback on the parent.
Eliminates the need for a Trigger wrapper object by registering the
automation's trigger() directly as a callback on the parent component.
Uses template forwarder structs so the compiler deduplicates the operator()
body across all call sites with the same signature. The forwarder must be
pointer-sized (single Automation* field) to fit inline in Callback::ctx_
and avoid heap allocation.
:param parent: The component object (e.g., button, sensor).
:param callback_method: Name of the callback method (e.g., "add_on_press_callback").
:param args: Automation template args as list of (type, name) tuples.
:param config: The automation config dict.
:param forwarder: Optional forwarder type to use instead of the default
TriggerForwarder<Ts...>. Pass any struct type whose aggregate init takes
a single Automation pointer (e.g., TriggerOnTrueForwarder).
"""
arg_types = [arg[0] for arg in args]
templ = cg.TemplateArguments(*arg_types)
obj = cg.new_Pvariable(config[CONF_AUTOMATION_ID], templ)
actions = await build_action_list(config[CONF_THEN], templ, args)
cg.add(obj.add_actions(actions))
# Use template forwarder structs for deduplication. The compiler generates
# one operator() per forwarder type; different automation pointers are just
# data in the struct.
if forwarder is None:
forwarder = TriggerForwarder.template(templ)
# RawExpression for aggregate init — both forwarder and obj are codegen
# MockObjs (not user input), and there's no Expression type for positional
# aggregate initialization (StructInitializer uses named fields).
cg.add(getattr(parent, callback_method)(cg.RawExpression(f"{forwarder}{{{obj}}}")))
+18 -43
View File
@@ -120,10 +120,6 @@ BinarySensorInitiallyOff = binary_sensor_ns.class_(
BinarySensorPtr = BinarySensor.operator("ptr")
# Triggers
PressTrigger = binary_sensor_ns.class_("PressTrigger", automation.Trigger.template())
ReleaseTrigger = binary_sensor_ns.class_(
"ReleaseTrigger", automation.Trigger.template()
)
ClickTrigger = binary_sensor_ns.class_("ClickTrigger", automation.Trigger.template())
DoubleClickTrigger = binary_sensor_ns.class_(
"DoubleClickTrigger", automation.Trigger.template()
@@ -132,13 +128,6 @@ MultiClickTrigger = binary_sensor_ns.class_(
"MultiClickTrigger", automation.Trigger.template(), cg.Component
)
MultiClickTriggerEvent = binary_sensor_ns.struct("MultiClickTriggerEvent")
StateTrigger = binary_sensor_ns.class_(
"StateTrigger", automation.Trigger.template(bool)
)
StateChangeTrigger = binary_sensor_ns.class_(
"StateChangeTrigger",
automation.Trigger.template(cg.optional.template(bool), cg.optional.template(bool)),
)
BinarySensorPublishAction = binary_sensor_ns.class_(
"BinarySensorPublishAction", automation.Action
@@ -458,16 +447,8 @@ _BINARY_SENSOR_SCHEMA = (
): cv.boolean,
cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
cv.Optional(CONF_FILTERS): validate_filters,
cv.Optional(CONF_ON_PRESS): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(PressTrigger),
}
),
cv.Optional(CONF_ON_RELEASE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ReleaseTrigger),
}
),
cv.Optional(CONF_ON_PRESS): automation.validate_automation({}),
cv.Optional(CONF_ON_RELEASE): automation.validate_automation({}),
cv.Optional(CONF_ON_CLICK): cv.All(
automation.validate_automation(
{
@@ -509,16 +490,8 @@ _BINARY_SENSOR_SCHEMA = (
): cv.positive_time_period_milliseconds,
}
),
cv.Optional(CONF_ON_STATE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(StateTrigger),
}
),
cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(StateChangeTrigger),
}
),
cv.Optional(CONF_ON_STATE): automation.validate_automation({}),
cv.Optional(CONF_ON_STATE_CHANGE): automation.validate_automation({}),
}
)
)
@@ -556,13 +529,14 @@ def binary_sensor_schema(
@coroutine_with_priority(CoroPriority.AUTOMATION)
async def _build_binary_sensor_automations(var, config):
for conf in config.get(CONF_ON_PRESS, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
for conf in config.get(CONF_ON_RELEASE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
for conf_key, forwarder in (
(CONF_ON_PRESS, automation.TriggerOnTrueForwarder),
(CONF_ON_RELEASE, automation.TriggerOnFalseForwarder),
):
for conf in config.get(conf_key, []):
await automation.build_callback_automation(
var, "add_on_state_callback", [], conf, forwarder=forwarder
)
for conf in config.get(CONF_ON_CLICK, []):
trigger = cg.new_Pvariable(
@@ -593,13 +567,14 @@ async def _build_binary_sensor_automations(var, config):
await automation.build_automation(trigger, [], conf)
for conf in config.get(CONF_ON_STATE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(bool, "x")], conf)
await automation.build_callback_automation(
var, "add_on_state_callback", [(bool, "x")], conf
)
for conf in config.get(CONF_ON_STATE_CHANGE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(
trigger,
await automation.build_callback_automation(
var,
"add_full_state_callback",
[
(cg.optional.template(bool), "x_previous"),
(cg.optional.template(bool), "x"),
+4 -12
View File
@@ -10,7 +10,6 @@ from esphome.const import (
CONF_ID,
CONF_MQTT_ID,
CONF_ON_PRESS,
CONF_TRIGGER_ID,
CONF_WEB_SERVER,
DEVICE_CLASS_EMPTY,
DEVICE_CLASS_IDENTIFY,
@@ -41,10 +40,6 @@ ButtonPtr = Button.operator("ptr")
PressAction = button_ns.class_("PressAction", automation.Action)
ButtonPressTrigger = button_ns.class_(
"ButtonPressTrigger", automation.Trigger.template()
)
validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_")
@@ -55,11 +50,7 @@ _BUTTON_SCHEMA = (
{
cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTButtonComponent),
cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
cv.Optional(CONF_ON_PRESS): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ButtonPressTrigger),
}
),
cv.Optional(CONF_ON_PRESS): automation.validate_automation({}),
}
)
)
@@ -91,8 +82,9 @@ def button_schema(
@setup_entity("button")
async def setup_button_core_(var, config):
for conf in config.get(CONF_ON_PRESS, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
await automation.build_callback_automation(
var, "add_on_press_callback", [], conf
)
setup_device_class(config)
+4 -10
View File
@@ -10,7 +10,6 @@ from esphome.const import (
CONF_ID,
CONF_MQTT_ID,
CONF_ON_EVENT,
CONF_TRIGGER_ID,
CONF_WEB_SERVER,
DEVICE_CLASS_BUTTON,
DEVICE_CLASS_DOORBELL,
@@ -41,8 +40,6 @@ EventPtr = Event.operator("ptr")
TriggerEventAction = event_ns.class_("TriggerEventAction", automation.Action)
EventTrigger = event_ns.class_("EventTrigger", automation.Trigger.template())
validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True, space="_")
_EVENT_SCHEMA = (
@@ -53,11 +50,7 @@ _EVENT_SCHEMA = (
cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTEventComponent),
cv.GenerateID(): cv.declare_id(Event),
cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
cv.Optional(CONF_ON_EVENT): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(EventTrigger),
}
),
cv.Optional(CONF_ON_EVENT): automation.validate_automation({}),
}
)
)
@@ -92,8 +85,9 @@ def event_schema(
@setup_entity("event")
async def setup_event_core_(var, config, *, event_types: list[str]):
for conf in config.get(CONF_ON_EVENT, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(cg.StringRef, "event_type")], conf)
await automation.build_callback_automation(
var, "add_on_event_callback", [(cg.StringRef, "event_type")], conf
)
cg.add(var.set_event_types(event_types))
+4 -10
View File
@@ -155,9 +155,6 @@ Number = number_ns.class_("Number", cg.EntityBase)
NumberPtr = Number.operator("ptr")
# Triggers
NumberStateTrigger = number_ns.class_(
"NumberStateTrigger", automation.Trigger.template(cg.float_)
)
ValueRangeTrigger = number_ns.class_(
"ValueRangeTrigger", automation.Trigger.template(cg.float_), cg.Component
)
@@ -198,11 +195,7 @@ _NUMBER_SCHEMA = (
.extend(
{
cv.OnlyWith(CONF_MQTT_ID, "mqtt"): cv.declare_id(mqtt.MQTTNumberComponent),
cv.Optional(CONF_ON_VALUE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(NumberStateTrigger),
}
),
cv.Optional(CONF_ON_VALUE): automation.validate_automation({}),
cv.Optional(CONF_ON_VALUE_RANGE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ValueRangeTrigger),
@@ -248,8 +241,9 @@ def number_schema(
@coroutine_with_priority(CoroPriority.AUTOMATION)
async def _build_number_automations(var, config):
for conf in config.get(CONF_ON_VALUE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(float, "x")], conf)
await automation.build_callback_automation(
var, "add_on_state_callback", [(float, "x")], conf
)
for conf in config.get(CONF_ON_VALUE_RANGE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await cg.register_component(trigger, conf)
+10 -24
View File
@@ -238,12 +238,6 @@ Sensor = sensor_ns.class_("Sensor", cg.EntityBase)
SensorPtr = Sensor.operator("ptr")
# Triggers
SensorStateTrigger = sensor_ns.class_(
"SensorStateTrigger", automation.Trigger.template(cg.float_)
)
SensorRawStateTrigger = sensor_ns.class_(
"SensorRawStateTrigger", automation.Trigger.template(cg.float_)
)
ValueRangeTrigger = sensor_ns.class_(
"ValueRangeTrigger", automation.Trigger.template(cg.float_), cg.Component
)
@@ -316,18 +310,8 @@ _SENSOR_SCHEMA = (
cv.Any(None, cv.positive_time_period_milliseconds),
),
cv.Optional(CONF_FILTERS): validate_filters,
cv.Optional(CONF_ON_VALUE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SensorStateTrigger),
}
),
cv.Optional(CONF_ON_RAW_VALUE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
SensorRawStateTrigger
),
}
),
cv.Optional(CONF_ON_VALUE): automation.validate_automation({}),
cv.Optional(CONF_ON_RAW_VALUE): automation.validate_automation({}),
cv.Optional(CONF_ON_VALUE_RANGE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(ValueRangeTrigger),
@@ -897,12 +881,14 @@ async def build_filters(config):
@coroutine_with_priority(CoroPriority.AUTOMATION)
async def _build_sensor_automations(var, config):
for conf in config.get(CONF_ON_VALUE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(float, "x")], conf)
for conf in config.get(CONF_ON_RAW_VALUE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(float, "x")], conf)
for conf_key, callback in (
(CONF_ON_VALUE, "add_on_state_callback"),
(CONF_ON_RAW_VALUE, "add_on_raw_state_callback"),
):
for conf in config.get(conf_key, []):
await automation.build_callback_automation(
var, callback, [(float, "x")], conf
)
for conf in config.get(CONF_ON_VALUE_RANGE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await cg.register_component(trigger, conf)
+12 -36
View File
@@ -15,7 +15,6 @@ from esphome.const import (
CONF_ON_TURN_ON,
CONF_RESTORE_MODE,
CONF_STATE,
CONF_TRIGGER_ID,
CONF_WEB_SERVER,
DEVICE_CLASS_EMPTY,
DEVICE_CLASS_OUTLET,
@@ -61,17 +60,6 @@ TurnOnAction = switch_ns.class_("TurnOnAction", automation.Action)
SwitchPublishAction = switch_ns.class_("SwitchPublishAction", automation.Action)
SwitchCondition = switch_ns.class_("SwitchCondition", Condition)
SwitchStateTrigger = switch_ns.class_(
"SwitchStateTrigger", automation.Trigger.template(bool)
)
SwitchTurnOnTrigger = switch_ns.class_(
"SwitchTurnOnTrigger", automation.Trigger.template()
)
SwitchTurnOffTrigger = switch_ns.class_(
"SwitchTurnOffTrigger", automation.Trigger.template()
)
validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True)
@@ -86,21 +74,9 @@ _SWITCH_SCHEMA = (
cv.Optional(CONF_RESTORE_MODE, default="ALWAYS_OFF"): cv.enum(
RESTORE_MODES, upper=True, space="_"
),
cv.Optional(CONF_ON_STATE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SwitchStateTrigger),
}
),
cv.Optional(CONF_ON_TURN_ON): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SwitchTurnOnTrigger),
}
),
cv.Optional(CONF_ON_TURN_OFF): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SwitchTurnOffTrigger),
}
),
cv.Optional(CONF_ON_STATE): automation.validate_automation({}),
cv.Optional(CONF_ON_TURN_ON): automation.validate_automation({}),
cv.Optional(CONF_ON_TURN_OFF): automation.validate_automation({}),
cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
}
)
@@ -147,15 +123,15 @@ def switch_schema(
@coroutine_with_priority(CoroPriority.AUTOMATION)
async def _build_switch_automations(var, config):
for conf in config.get(CONF_ON_STATE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(bool, "x")], conf)
for conf in config.get(CONF_ON_TURN_ON, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
for conf in config.get(CONF_ON_TURN_OFF, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [], conf)
for conf_key, args, forwarder in (
(CONF_ON_STATE, [(bool, "x")], None),
(CONF_ON_TURN_ON, [], automation.TriggerOnTrueForwarder),
(CONF_ON_TURN_OFF, [], automation.TriggerOnFalseForwarder),
):
for conf in config.get(conf_key, []):
await automation.build_callback_automation(
var, "add_on_state_callback", args, conf, forwarder=forwarder
)
@setup_entity("switch")
+10 -28
View File
@@ -14,7 +14,6 @@ from esphome.const import (
CONF_ON_VALUE,
CONF_STATE,
CONF_TO,
CONF_TRIGGER_ID,
CONF_WEB_SERVER,
DEVICE_CLASS_DATE,
DEVICE_CLASS_EMPTY,
@@ -42,12 +41,6 @@ text_sensor_ns = cg.esphome_ns.namespace("text_sensor")
TextSensor = text_sensor_ns.class_("TextSensor", cg.EntityBase)
TextSensorPtr = TextSensor.operator("ptr")
TextSensorStateTrigger = text_sensor_ns.class_(
"TextSensorStateTrigger", automation.Trigger.template(cg.std_string)
)
TextSensorStateRawTrigger = text_sensor_ns.class_(
"TextSensorStateRawTrigger", automation.Trigger.template(cg.std_string)
)
TextSensorPublishAction = text_sensor_ns.class_(
"TextSensorPublishAction", automation.Action
)
@@ -150,20 +143,8 @@ _TEXT_SENSOR_SCHEMA = (
cv.GenerateID(): cv.declare_id(TextSensor),
cv.Optional(CONF_DEVICE_CLASS): validate_device_class,
cv.Optional(CONF_FILTERS): validate_filters,
cv.Optional(CONF_ON_VALUE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
TextSensorStateTrigger
),
}
),
cv.Optional(CONF_ON_RAW_VALUE): automation.validate_automation(
{
cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(
TextSensorStateRawTrigger
),
}
),
cv.Optional(CONF_ON_VALUE): automation.validate_automation({}),
cv.Optional(CONF_ON_RAW_VALUE): automation.validate_automation({}),
}
)
)
@@ -203,13 +184,14 @@ async def build_filters(config):
@coroutine_with_priority(CoroPriority.AUTOMATION)
async def _build_text_sensor_automations(var, config):
for conf in config.get(CONF_ON_VALUE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(cg.std_string, "x")], conf)
for conf in config.get(CONF_ON_RAW_VALUE, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], var)
await automation.build_automation(trigger, [(cg.std_string, "x")], conf)
for conf_key, callback in (
(CONF_ON_VALUE, "add_on_state_callback"),
(CONF_ON_RAW_VALUE, "add_on_raw_state_callback"),
):
for conf in config.get(conf_key, []):
await automation.build_callback_automation(
var, callback, [(cg.std_string, "x")], conf
)
@setup_entity("text_sensor")
+40 -2
View File
@@ -470,7 +470,9 @@ template<typename... Ts> class ActionList {
template<typename... Ts> class Automation {
public:
explicit Automation(Trigger<Ts...> *trigger) : trigger_(trigger) { this->trigger_->set_automation_parent(this); }
/// Default constructor for use with TriggerForwarder (no Trigger object needed).
Automation() = default;
explicit Automation(Trigger<Ts...> *trigger) { trigger->set_automation_parent(this); }
void add_action(Action<Ts...> *action) { this->actions_.add_action(action); }
void add_actions(const std::initializer_list<Action<Ts...> *> &actions) { this->actions_.add_actions(actions); }
@@ -487,8 +489,44 @@ template<typename... Ts> class Automation {
int num_running() { return this->actions_.num_running(); }
protected:
Trigger<Ts...> *trigger_;
ActionList<Ts...> actions_;
};
/// Callback forwarder that triggers an Automation directly.
/// One operator() instantiation per Automation<Ts...> signature, shared across all call sites.
/// Must stay pointer-sized to fit inline in Callback::ctx_ without heap allocation.
template<typename... Ts> struct TriggerForwarder {
Automation<Ts...> *automation;
void operator()(const Ts &...args) const { this->automation->trigger(args...); }
};
/// Callback forwarder that triggers an Automation<> only when the bool arg is true.
/// Must stay pointer-sized to fit inline in Callback::ctx_ without heap allocation.
struct TriggerOnTrueForwarder {
Automation<> *automation;
void operator()(bool state) const {
if (state)
this->automation->trigger();
}
};
/// Callback forwarder that triggers an Automation<> only when the bool arg is false.
/// Must stay pointer-sized to fit inline in Callback::ctx_ without heap allocation.
struct TriggerOnFalseForwarder {
Automation<> *automation;
void operator()(bool state) const {
if (!state)
this->automation->trigger();
}
};
// Ensure forwarders fit in Callback::ctx_ (pointer-sized inline storage).
// If these fail, the forwarder would heap-allocate in Callback::create().
static_assert(sizeof(TriggerForwarder<>) <= sizeof(void *));
static_assert(sizeof(TriggerOnTrueForwarder) <= sizeof(void *));
static_assert(sizeof(TriggerOnFalseForwarder) <= sizeof(void *));
static_assert(std::is_trivially_copyable_v<TriggerForwarder<>>);
static_assert(std::is_trivially_copyable_v<TriggerOnTrueForwarder>);
static_assert(std::is_trivially_copyable_v<TriggerOnFalseForwarder>);
} // namespace esphome
+80 -1
View File
@@ -5,7 +5,13 @@ from unittest.mock import patch
import pytest
from esphome.automation import has_non_synchronous_actions
from esphome.automation import (
TriggerForwarder,
TriggerOnFalseForwarder,
TriggerOnTrueForwarder,
has_non_synchronous_actions,
)
from esphome.cpp_generator import MockObj, RawExpression
from esphome.util import RegistryEntry
@@ -175,3 +181,76 @@ def test_has_non_synchronous_actions_dict_input(
"""Direct dict input (single action)."""
assert has_non_synchronous_actions({"delay": "1s"}) is True
assert has_non_synchronous_actions({"logger.log": "hello"}) is False
def _build_forwarder(
automation_name: str,
args: list[tuple[str, str]],
forwarder: MockObj | None = None,
) -> str:
"""Build a trigger forwarder expression the same way build_callback_automation does.
Mirrors the forwarder selection logic in automation.build_callback_automation.
"""
import esphome.codegen as cg
obj = MockObj(automation_name, "->")
if forwarder is None:
arg_types = [RawExpression(t) for t, _ in args]
templ = (
cg.TemplateArguments(*arg_types) if arg_types else cg.TemplateArguments()
)
forwarder = TriggerForwarder.template(templ)
return f"{forwarder}{{{obj}}}"
def test_trigger_forwarder_no_args() -> None:
"""Button on_press: TriggerForwarder<> with no args."""
result = _build_forwarder("auto_1", [])
assert result == "TriggerForwarder<>{auto_1}"
def test_trigger_forwarder_single_float_arg() -> None:
"""Sensor on_value: TriggerForwarder<float>."""
result = _build_forwarder("auto_1", [("float", "x")])
assert result == "TriggerForwarder<float>{auto_1}"
def test_trigger_forwarder_single_bool_arg() -> None:
"""Switch on_state: TriggerForwarder<bool>."""
result = _build_forwarder("auto_1", [("bool", "x")])
assert result == "TriggerForwarder<bool>{auto_1}"
def test_trigger_forwarder_on_true() -> None:
"""Binary_sensor on_press / switch on_turn_on: TriggerOnTrueForwarder."""
result = _build_forwarder("auto_1", [], forwarder=TriggerOnTrueForwarder)
assert result == "TriggerOnTrueForwarder{auto_1}"
def test_trigger_forwarder_on_false() -> None:
"""Binary_sensor on_release / switch on_turn_off: TriggerOnFalseForwarder."""
result = _build_forwarder("auto_1", [], forwarder=TriggerOnFalseForwarder)
assert result == "TriggerOnFalseForwarder{auto_1}"
def test_trigger_forwarder_multiple_args() -> None:
"""Binary_sensor on_state_change: TriggerForwarder with two args."""
result = _build_forwarder(
"auto_1",
[("optional<bool>", "x_previous"), ("optional<bool>", "x")],
)
assert result == "TriggerForwarder<optional<bool>, optional<bool>>{auto_1}"
def test_trigger_forwarder_string_arg() -> None:
"""Text_sensor on_value: TriggerForwarder<std::string>."""
result = _build_forwarder("auto_1", [("std::string", "x")])
assert result == "TriggerForwarder<std::string>{auto_1}"
def test_trigger_forwarder_custom_type() -> None:
"""Custom forwarder type passed directly."""
custom = MockObj("MyForwarder", "")
result = _build_forwarder("auto_1", [], forwarder=custom)
assert result == "MyForwarder{auto_1}"