diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index bba4bbf487..d0eab4e44e 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -24,6 +24,7 @@ except ImportError: from esphome import core from esphome.config_helpers import Extend, Remove +from esphome.const import CONF_DEFAULTS from esphome.core import ( CORE, DocumentRange, @@ -88,6 +89,47 @@ def make_data_base( return value +class ConfigContext: + """This is a mixin class that holds substitution vars that should be applied + to the tagged node and its children. During configuration loading, context vars can + be added to nodes using `add_context` function, which applies the mixin storing + the captured values and unevaluated expressions. + The substitution pass then recreates the effective context by merging the context vars + from this node and parent nodes. + """ + + @property + def vars(self) -> dict[str, Any]: + return self._context_vars + + def set_context(self, vars: dict[str, Any]) -> None: + # pylint: disable=attribute-defined-outside-init + self._context_vars = vars + + +def add_context(value: Any, context_vars: dict[str, Any] | None) -> Any: + """Tags a list/string/dict value with context vars that must be applied to it and its children + during the substitution pass. If no vars are given, no tagging is done. + If the value is already tagged, the new context vars are merged with existing ones, + with new vars taking precedence. Returns the value tagged with ConfigContext. Returns + the original value if value is not a list/string/dict. + """ + if isinstance(value, dict) and CONF_DEFAULTS in value: + context_vars = { + **value.pop(CONF_DEFAULTS), + **(context_vars or {}), + } + + if isinstance(value, ConfigContext): + value.set_context({**value.vars, **(context_vars or {})}) + return value + + if context_vars and isinstance(value, (dict, list, str)): + value = add_class_to_obj(value, ConfigContext) + value.set_context(context_vars) + return value + + def _add_data_ref(fn): @functools.wraps(fn) def wrapped(loader, node): @@ -455,7 +497,7 @@ def parse_yaml( def substitute_vars(config, vars): from esphome.components import substitutions - from esphome.const import CONF_DEFAULTS, CONF_SUBSTITUTIONS + from esphome.const import CONF_SUBSTITUTIONS org_subs = None result = config @@ -612,6 +654,12 @@ class ESPHomeDumper(yaml.SafeDumper): return self.represent_secret(value.value) return self.represent_scalar(tag="!lambda", value=value.value, style="|") + def represent_extend(self, value): + return self.represent_scalar(tag="!extend", value=value.value) + + def represent_remove(self, value): + return self.represent_scalar(tag="!remove", value=value.value) + def represent_id(self, value): if is_secret(value.id): return self.represent_secret(value.id) @@ -638,6 +686,8 @@ ESPHomeDumper.add_multi_representer(_BaseNetwork, ESPHomeDumper.represent_string ESPHomeDumper.add_multi_representer(MACAddress, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(TimePeriod, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(Lambda, ESPHomeDumper.represent_lambda) +ESPHomeDumper.add_multi_representer(Extend, ESPHomeDumper.represent_extend) +ESPHomeDumper.add_multi_representer(Remove, ESPHomeDumper.represent_remove) ESPHomeDumper.add_multi_representer(core.ID, ESPHomeDumper.represent_id) ESPHomeDumper.add_multi_representer(uuid.UUID, ESPHomeDumper.represent_stringify) ESPHomeDumper.add_multi_representer(Path, ESPHomeDumper.represent_stringify) diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index c8cb3e144f..adb7658bfd 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -6,6 +6,7 @@ import pytest from esphome import core, yaml_util from esphome.components import substitutions +from esphome.config_helpers import Extend, Remove from esphome.core import EsphomeError from esphome.util import OrderedDict @@ -306,3 +307,128 @@ def test_dump_sort_keys() -> None: # nested keys should also be sorted assert "a_key:" in sorted_dump assert sorted_dump.index("a_key:") < sorted_dump.index("z_key:") + + +@pytest.mark.parametrize( + "data", + [ + { + "key1": "value1", + "key2": 42, + }, + [1, 2, 3], + "simple string", + ], +) +def test_config_context_mixin(data) -> None: + """Test that ConfigContext mixin correctly stores and retrieves context vars in a dict.""" + + context_vars = { + "var1": "context_value1", + "var2": 100, + } + + # Add context to the data + tagged_data = yaml_util.add_context(data, context_vars) + + # Check that tagged_data has ConfigContext and correct vars + assert isinstance(tagged_data, type(data)) + assert isinstance(tagged_data, yaml_util.ConfigContext) + assert tagged_data.vars == context_vars + + # Check that original data is preserved + assert tagged_data == data + + +def test_config_context_mixin_no_context() -> None: + """Test that add_context does not tag data when no context vars are provided.""" + data = {"key": "value"} + + # Add context with None + tagged_data = yaml_util.add_context(data, None) + + # Should return original data without tagging + assert tagged_data is data + assert not isinstance(tagged_data, yaml_util.ConfigContext) + + +def test_config_context_mixin_merge_contexts() -> None: + """Test that add_context merges new context vars with existing ones.""" + data = {"key": "value"} + + initial_context = { + "var1": "initial_value", + } + + # First, add initial context + tagged_data = yaml_util.add_context(data, initial_context) + + assert isinstance(tagged_data, yaml_util.ConfigContext) + assert tagged_data.vars == initial_context + + # Now, add more context vars + new_context = { + "var2": "new_value", + "var1": "overridden_value", # This should override the initial var1 + } + + merged_tagged_data = yaml_util.add_context(tagged_data, new_context) + + # Check that merged_tagged_data has merged context vars + expected_context = { + "var1": "overridden_value", + "var2": "new_value", + } + assert isinstance(merged_tagged_data, yaml_util.ConfigContext) + assert merged_tagged_data.vars == expected_context + + # Check that original data is preserved + assert merged_tagged_data == data + + +@pytest.mark.parametrize("data", [42, 3.14, True, None]) +def test_config_context_non_taggable(data) -> None: + """Test that add_context ignores non-string scalar values.""" + + context_vars = { + "var1": "context_value", + } + + # Add context to the scalar data + tagged_data = yaml_util.add_context(data, context_vars) + + # Check that tagged_data has ConfigContext and correct vars + assert not isinstance(tagged_data, yaml_util.ConfigContext) + + # Check that original data is preserved + assert tagged_data == data + + +def test_config_context_defaults_only() -> None: + """Test that defaults: key is popped and used as context vars when no explicit vars given.""" + data = {"defaults": {"x": "1", "y": "2"}, "key": "value"} + tagged = yaml_util.add_context(data, None) + + assert isinstance(tagged, yaml_util.ConfigContext) + assert tagged.vars == {"x": "1", "y": "2"} + assert "defaults" not in tagged + + +def test_config_context_defaults_explicit_vars_override() -> None: + """Test that explicit vars take precedence over defaults: values.""" + data = {"defaults": {"x": "default_x", "z": "default_z"}, "key": "value"} + tagged = yaml_util.add_context(data, {"x": "explicit_x", "w": "explicit_w"}) + + assert isinstance(tagged, yaml_util.ConfigContext) + assert tagged.vars == {"x": "explicit_x", "z": "default_z", "w": "explicit_w"} + assert "defaults" not in tagged + + +def test_represent_extend() -> None: + """Test that Extend objects are dumped as plain !extend scalars.""" + assert yaml_util.dump({"key": Extend("my_id")}) == "key: !extend 'my_id'\n" + + +def test_represent_remove() -> None: + """Test that Remove objects are dumped as plain !remove scalars.""" + assert yaml_util.dump({"key": Remove("my_id")}) == "key: !remove 'my_id'\n"