From d20d613c1da39c5c0bcddb1b426fbac9a52a68fd Mon Sep 17 00:00:00 2001 From: Javier Peletier Date: Wed, 8 Apr 2026 03:12:55 +0200 Subject: [PATCH] [substitutions] `!include ${filename}`, Substitutions in include filename paths (package refactor part 5) (#12213) Co-authored-by: J. Nick Koston Co-authored-by: J. Nick Koston --- esphome/components/packages/__init__.py | 61 +++++-- esphome/components/substitutions/__init__.py | 58 ++++++ esphome/components/substitutions/jinja.py | 16 +- esphome/config_validation.py | 7 +- esphome/expression.py | 25 +++ esphome/yaml_util.py | 148 ++++++++++++--- .../11-include_path.approved.yaml | 15 ++ .../substitutions/11-include_path.input.yaml | 21 +++ .../substitutions/12-yaml-merge.approved.yaml | 9 + .../substitutions/12-yaml-merge.input.yaml | 10 ++ .../fixtures/substitutions/inc2.yaml | 6 + .../fixtures/substitutions/inc3.yaml | 3 + tests/unit_tests/test_substitutions.py | 68 ++++++- tests/unit_tests/test_yaml_util.py | 168 +++++++++++++++++- 14 files changed, 562 insertions(+), 53 deletions(-) create mode 100644 esphome/expression.py create mode 100644 tests/unit_tests/fixtures/substitutions/11-include_path.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/11-include_path.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/12-yaml-merge.approved.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/12-yaml-merge.input.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/inc2.yaml create mode 100644 tests/unit_tests/fixtures/substitutions/inc3.yaml diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index 1a6df84fe0..04db690c6f 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -6,7 +6,12 @@ from pathlib import Path from typing import Any from esphome import git, yaml_util -from esphome.components.substitutions import ContextVars, push_context, substitute +from esphome.components.substitutions import ( + ContextVars, + push_context, + resolve_include, + substitute, +) from esphome.components.substitutions.jinja import has_jinja from esphome.config_helpers import Remove, merge_config import esphome.config_validation as cv @@ -31,6 +36,8 @@ from esphome.core import EsphomeError _LOGGER = logging.getLogger(__name__) DOMAIN = CONF_PACKAGES +# Guard against infinite include chains (e.g. A includes B includes A). +MAX_INCLUDE_DEPTH = 20 def is_remote_package(package_config: dict) -> bool: @@ -59,8 +66,8 @@ def valid_package_contents(package_config: dict) -> dict: for k, v in package_config.items(): if not isinstance(k, str): raise cv.Invalid("Package content keys must be strings") - if isinstance(v, (dict, list, Remove)): - continue # e.g. script: [], psram: !remove, logger: {level: debug} + if isinstance(v, (dict, list, Remove, yaml_util.IncludeFile)): + continue # e.g. script: [], psram: !remove, logger: {level: debug}, switch: !include switches.yaml if v is None: continue # e.g. web_server: if isinstance(v, str) and has_jinja(v): @@ -160,6 +167,7 @@ REMOTE_PACKAGE_SCHEMA = cv.All( PACKAGE_SCHEMA = cv.Any( # A package definition is either: validate_source_shorthand, # A git URL shorthand string that expands to a remote package schema, or REMOTE_PACKAGE_SCHEMA, # a valid remote package schema, or + yaml_util.IncludeFile, # isinstance check — passes IncludeFile objects through unchanged, or: valid_package_contents, # Something that at least looks like an actual package, e.g. {wifi:{ssid: xxx}} # which will have to be fully validated later as per each component's schema. ) @@ -396,16 +404,49 @@ class _PackageProcessor: self.skip_update = skip_update def resolve_package( - self, package_config: dict | str, context_vars: ContextVars | None + self, + package_config: dict | str | yaml_util.IncludeFile, + context_vars: ContextVars | None, ) -> dict: - """Substitute variables in the definition and fetch remote packages. + """Resolve a package definition to a concrete ``dict`` and fetch remote packages. - The input may be a ``str`` (git shorthand or Jinja expression) or a - ``dict`` (remote or local package). After ``PACKAGE_SCHEMA`` validation - the result is always a ``dict``. + The input may be a ``str`` (git shorthand or Jinja expression), a + ``dict`` (remote or local package), or an ``IncludeFile`` whose filename + may itself contain substitution expressions. + + The loop handles the case where loading an ``IncludeFile`` yields another + ``IncludeFile`` (e.g. a chain of deferred includes). Each iteration: + + 1. If the current value is an ``IncludeFile``, load it — resolving any + substitutions in its filename first. + 2. Substitute variables in the resulting value (for strings and remote + package dicts). + 3. Validate against ``PACKAGE_SCHEMA``. If the result is a ``dict``, + the loop exits; otherwise another iteration is needed. + + Raises ``cv.Invalid`` if the chain has not resolved to a ``dict`` after + ``MAX_INCLUDE_DEPTH`` iterations. """ - package_config = _substitute_package_definition(package_config, context_vars) - package_config = PACKAGE_SCHEMA(package_config) + for _ in range(MAX_INCLUDE_DEPTH): + if isinstance(package_config, yaml_util.IncludeFile): + package_config, _ = resolve_include( + package_config, + [], + context_vars or ContextVars(), + strict_undefined=False, + ) + + package_config = _substitute_package_definition( + package_config, context_vars + ) + package_config = PACKAGE_SCHEMA(package_config) + if isinstance(package_config, dict): + break + else: + raise cv.Invalid( + f"Maximum include nesting depth ({MAX_INCLUDE_DEPTH}) exceeded" + ) + if is_remote_package(package_config): package_config = _process_remote_package(package_config, self.skip_update) return package_config diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index aab1712b65..c0bd9d7be9 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -2,6 +2,7 @@ from collections import ChainMap import logging from typing import Any +import esphome from esphome import core from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered import esphome.config_validation as cv @@ -12,6 +13,7 @@ from esphome.yaml_util import ( ConfigContext, ESPHomeDataBase, ESPLiteralValue, + IncludeFile, make_data_base, ) @@ -291,6 +293,59 @@ def push_context( return parent_context +def resolve_include( + include: IncludeFile, + path: list[int | str], + context_vars: ContextVars, + strict_undefined: bool = True, + errors: ErrList | None = None, +) -> tuple[Any, str]: + """Resolve an include, substituting the filename if needed. + + Returns the loaded content and the resolved filename. + + Note: no path-traversal validation is performed on the resolved filename. + A substitution that resolves to an absolute path will bypass the parent + directory (Path.__truediv__ ignores the left operand for absolute paths). + ESPHome's trust model assumes the config author controls all substitution + values (including command-line substitutions), so path restrictions are + an explicit non-goal here. + """ + original = str(include.file) + filename = str( + _expand_substitutions( + original, path + ["file"], context_vars, strict_undefined, errors + ) + ) + if filename != original: + include = IncludeFile( + include.parent_file, filename, include.vars, include.yaml_loader + ) + try: + return include.load(), filename + except esphome.core.EsphomeError as err: + raise cv.Invalid( + f"Error including file '{filename}': {err}", + path + [f"<{filename}>"], + ) from err + + +def _substitute_include( + include: IncludeFile, + path: list[int | str], + context_vars: ContextVars, + strict_undefined: bool, + errors: ErrList | None, +) -> Any: + """Resolve an include and substitute its content.""" + content, filename = resolve_include( + include, path, context_vars, strict_undefined, errors + ) + return substitute( + content, path + [f"<{filename}>"], context_vars, strict_undefined, errors + ) + + def substitute( item: Any, path: SubstitutionPath, @@ -333,6 +388,9 @@ def substitute( if item.value != value: result = type(item)(value) + elif isinstance(item, IncludeFile): + result = _substitute_include(item, path, context_vars, strict_undefined, errors) + if isinstance(item, ESPHomeDataBase): result = make_data_base(result, item) return result diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py index 37e9fa4d2d..36a7425a69 100644 --- a/esphome/components/substitutions/jinja.py +++ b/esphome/components/substitutions/jinja.py @@ -2,7 +2,6 @@ from ast import literal_eval from collections.abc import Iterator, Mapping from itertools import chain, islice import math -import re from types import GeneratorType from typing import Any @@ -10,6 +9,9 @@ import jinja2 as jinja from jinja2.nativetypes import NativeCodeGenerator, NativeTemplate from jinja2.runtime import missing as Missing +# Re-exported for backward compatibility — consumers import has_jinja from here +from esphome.expression import has_jinja # noqa: F401 # pylint: disable=unused-import + TemplateError = jinja.TemplateError TemplateSyntaxError = jinja.TemplateSyntaxError TemplateRuntimeError = jinja.TemplateRuntimeError @@ -20,18 +22,6 @@ Undefined = jinja.Undefined Resolver = ".resolver" -DETECT_JINJA = r"(\$\{)" -detect_jinja_re = re.compile( - r"<%.+?%>" # Block form expression: <% ... %> - r"|\$\{[^}]+\}", # Braced form expression: ${ ... } - flags=re.MULTILINE, -) - - -def has_jinja(st: str) -> bool: - return detect_jinja_re.search(st) is not None - - # SAFE_GLOBALS defines a allowlist of built-in functions or modules that are considered safe to expose # in Jinja templates or other sandboxed evaluation contexts. Only functions that do not allow # arbitrary code execution, file access, or other security risks are included. diff --git a/esphome/config_validation.py b/esphome/config_validation.py index b0bd9e6231..31cfb41a6d 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -75,7 +75,6 @@ from esphome.const import ( SCHEDULER_DONT_RUN, TYPE_GIT, TYPE_LOCAL, - VALID_SUBSTITUTIONS_CHARACTERS, Framework, __version__ as ESPHOME_VERSION, ) @@ -90,6 +89,7 @@ from esphome.core import ( TimePeriodNanoseconds, TimePeriodSeconds, ) +from esphome.expression import SUBSTITUTION_VARIABLE_PROG as VARIABLE_PROG from esphome.helpers import add_class_to_obj, docs_url, list_starts_with from esphome.schema_extractors import ( SCHEMA_EXTRACT, @@ -104,11 +104,6 @@ from esphome.yaml_util import make_data_base _LOGGER = logging.getLogger(__name__) -# pylint: disable=consider-using-f-string -VARIABLE_PROG = re.compile( - f"\\$([{VALID_SUBSTITUTIONS_CHARACTERS}]+|\\{{[{VALID_SUBSTITUTIONS_CHARACTERS}]*\\}})" -) - # pylint: disable=invalid-name Schema = _Schema diff --git a/esphome/expression.py b/esphome/expression.py new file mode 100644 index 0000000000..d425d822a4 --- /dev/null +++ b/esphome/expression.py @@ -0,0 +1,25 @@ +"""Helpers for detecting substitution variables and Jinja expressions.""" + +import re + +from esphome.const import VALID_SUBSTITUTIONS_CHARACTERS + +SUBSTITUTION_VARIABLE_PROG = re.compile( + rf"\$([{VALID_SUBSTITUTIONS_CHARACTERS}]+|\{{[{VALID_SUBSTITUTIONS_CHARACTERS}]*\}})" +) + +_JINJA_RE = re.compile( + r"<%.+?%>" # Block: <% ... %> + r"|\$\{[^}]+\}", # Braced: ${ ... } + flags=re.MULTILINE, +) + + +def has_jinja(value: str) -> bool: + """Check if a string contains Jinja expressions.""" + return _JINJA_RE.search(value) is not None + + +def has_substitution_or_expression(value: str) -> bool: + """Check if a string contains substitution variables ($name, ${name}) or Jinja expressions.""" + return SUBSTITUTION_VARIABLE_PROG.search(value) is not None or has_jinja(value) diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index a24c1ebccb..c621428196 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -33,6 +33,7 @@ from esphome.core import ( MACAddress, TimePeriod, ) +from esphome.expression import has_substitution_or_expression from esphome.helpers import add_class_to_obj from esphome.util import OrderedDict, filter_yaml_files @@ -110,24 +111,6 @@ def make_data_base( return value -class ConfigContext: - """This is a mixin class that holds substitution vars that should be applied - to the tagged node and its children. During configuration loading, context vars can - be added to nodes using `add_context` function, which applies the mixin storing - the captured values and unevaluated expressions. - The substitution pass then recreates the effective context by merging the context vars - from this node and parent nodes. - """ - - @property - def vars(self) -> dict[str, Any]: - return self._context_vars - - def set_context(self, vars: dict[str, Any]) -> None: - # pylint: disable=attribute-defined-outside-init - self._context_vars = vars - - def add_context(value: Any, context_vars: dict[str, Any] | None) -> Any: """Tags a list/string/dict value with context vars that must be applied to it and its children during the substitution pass. If no vars are given, no tagging is done. @@ -151,6 +134,94 @@ def add_context(value: Any, context_vars: dict[str, Any] | None) -> Any: return value +class ConfigContext: + """This is a mixin class that holds substitution vars that should be applied + to the tagged node and its children. During configuration loading, context vars can + be added to nodes using `add_context` function, which applies the mixin storing + the captured values and unevaluated expressions. + The substitution pass then recreates the effective context by merging the context vars + from this node and parent nodes. + """ + + @property + def vars(self) -> dict[str, Any]: + return self._context_vars + + def set_context(self, vars: dict[str, Any]) -> None: + # pylint: disable=attribute-defined-outside-init + self._context_vars = vars + + def copy_context_to_children(self) -> None: + """Propagate context to children. + + isinstance(self, dict/list) works because ConfigContext is dynamically + mixed into dict/list subclasses via add_class_to_obj in add_context(). + """ + if isinstance(self, dict): + # pylint: disable=no-member + tagged = { + add_context(k, self.vars): add_context(v, self.vars) + for k, v in self.items() + } + self.clear() + self.update(tagged) + elif isinstance(self, list): + for i, item in enumerate(self): + # pylint: disable=unsupported-assignment-operation + self[i] = add_context(item, self.vars) + + +_UNSET = object() + + +class IncludeFile: + """Deferred !include that is resolved during the substitution pass. + + Created during YAML parsing instead of loading the file immediately, + allowing substitution variables to appear in the filename path + (e.g. ``!include device-${platform}.yaml``). The actual file is + loaded on the first call to ``load()``, and the result is cached. + """ + + def __init__( + self, + parent_file: Path, + file: Path | str, + vars: dict[str, Any] | None, + yaml_loader: Callable[[Path], Any], + ) -> None: + self.parent_file = parent_file + self.file = Path(file) + self.vars = vars + self.yaml_loader = yaml_loader + self._content: Any = _UNSET + + def __repr__(self) -> str: + return f"IncludeFile({self.file.as_posix()})" + + def load(self) -> Any: + """Load and cache the included file content. + + Note: returns the cached mutable object on subsequent calls. + Callers that need to modify the result should copy it first. + """ + if self._content is not _UNSET: + return self._content + if self.has_unresolved_expressions(): + from esphome.config_validation import Invalid + + raise Invalid( + f"Cannot load include with unresolved substitutions: {self.file}" + ) + self._content = self.yaml_loader(Path(self.parent_file.parent / self.file)) + self._content = add_context(self._content, self.vars) + return self._content + + def has_unresolved_expressions(self) -> bool: + """Check if the filename contains substitution variables or Jinja expressions.""" + return has_substitution_or_expression(str(self.file)) + + def _add_data_ref(fn): @functools.wraps(fn) def wrapped(loader, node): @@ -170,6 +241,36 @@ def _add_data_ref(fn): return wrapped +_MAX_MERGE_INCLUDE_DEPTH = 10 + + +def _resolve_merge_include(value: Any, node: yaml.Node, value_node: yaml.Node) -> Any: + """Resolve an IncludeFile (and chains) and propagate context for merge key handling.""" + for _ in range(_MAX_MERGE_INCLUDE_DEPTH): + if not isinstance(value, IncludeFile): + break + if value.has_unresolved_expressions(): + raise yaml.constructor.ConstructorError( + "While constructing a mapping", + node.start_mark, + "Substitution in include filename with merge keys is not supported yet.", + value_node.start_mark, + ) + value = value.load() + else: + raise yaml.constructor.ConstructorError( + "While constructing a mapping", + node.start_mark, + f"Maximum include chain depth ({_MAX_MERGE_INCLUDE_DEPTH}) exceeded in merge key", + value_node.start_mark, + ) + if isinstance(value, ConfigContext): + # Since the parent dict/list will disappear, propagate + # context to children now to retain context vars + value.copy_context_to_children() + return value + + class ESPHomeLoaderMixin: """Loader class that keeps track of line numbers.""" @@ -261,6 +362,9 @@ class ESPHomeLoaderMixin: # This is a merge key, resolve value and add to merge_pairs value = self.construct_object(value_node) + + value = _resolve_merge_include(value, node, value_node) + if isinstance(value, dict): # base case, copy directly to merge_pairs # direct merge, like "<<: {some_key: some_value}" @@ -268,6 +372,7 @@ class ESPHomeLoaderMixin: elif isinstance(value, list): # sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]" for item in value: + item = _resolve_merge_include(item, node, value_node) if not isinstance(item, dict): raise yaml.constructor.ConstructorError( "While constructing a mapping", @@ -362,8 +467,11 @@ class ESPHomeLoaderMixin: else: file, vars = node.value, None - result = self.yaml_loader(self._rel_path(file)) - return add_context(result, vars) + return IncludeFile(self.name, file, vars, self.yaml_loader) + + # Directory includes (!include_dir_*) load eagerly during YAML parsing + # because their paths are directory names, not individual files, and + # substitutions in directory paths are not supported. @_add_data_ref def construct_include_dir_list(self, node: yaml.Node) -> list[dict[str, Any]]: diff --git a/tests/unit_tests/fixtures/substitutions/11-include_path.approved.yaml b/tests/unit_tests/fixtures/substitutions/11-include_path.approved.yaml new file mode 100644 index 0000000000..d758a832a4 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/11-include_path.approved.yaml @@ -0,0 +1,15 @@ +values: + - var1: 4 + - a: 5 + - b: 6 + - c: The value of C is 7 + - This value comes from inc2.yaml. x is 3, y is 4 + - From main config, x is 3, y is 2 + - $a $b $c are out of scope here + - keys_in_inc3: + x: 3 + y: 2 +substitutions: + x: 3 + y: 2 + include_file: inc1 diff --git a/tests/unit_tests/fixtures/substitutions/11-include_path.input.yaml b/tests/unit_tests/fixtures/substitutions/11-include_path.input.yaml new file mode 100644 index 0000000000..78b1cb4fb9 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/11-include_path.input.yaml @@ -0,0 +1,21 @@ +substitutions: + include_file: inc1 + x: 3 # override x from inc2.yaml + +packages: + my_package: !include + file: ${include_file + ".yaml"} # includes inc1.yaml + vars: + var1: 4 + a: ${x+2} + b: ${a+1} + c: 7 + other_package: !include + file: inc${1+1}.yaml # includes inc2.yaml + vars: + y: 4 + +values: + - From main config, x is $x, y is $y + - $a $b $c are out of scope here + - !include ${"inc" + "3.yaml"} # includes inc3.yaml here (not a package) diff --git a/tests/unit_tests/fixtures/substitutions/12-yaml-merge.approved.yaml b/tests/unit_tests/fixtures/substitutions/12-yaml-merge.approved.yaml new file mode 100644 index 0000000000..02d8512498 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/12-yaml-merge.approved.yaml @@ -0,0 +1,9 @@ +substitutions: + x: 7 +test_list: + - content: + before: Content before + after: Content after + keys_in_inc3: + x: 7 + y: 8 diff --git a/tests/unit_tests/fixtures/substitutions/12-yaml-merge.input.yaml b/tests/unit_tests/fixtures/substitutions/12-yaml-merge.input.yaml new file mode 100644 index 0000000000..a03e66e393 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/12-yaml-merge.input.yaml @@ -0,0 +1,10 @@ +substitutions: + x: 7 +test_list: + - content: + before: Content before + <<: !include + file: inc3.yaml + vars: + y: 8 + after: Content after diff --git a/tests/unit_tests/fixtures/substitutions/inc2.yaml b/tests/unit_tests/fixtures/substitutions/inc2.yaml new file mode 100644 index 0000000000..29a1833efc --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/inc2.yaml @@ -0,0 +1,6 @@ +substitutions: + x: 1 + y: 2 + +values: + - This value comes from inc2.yaml. x is $x, y is $y diff --git a/tests/unit_tests/fixtures/substitutions/inc3.yaml b/tests/unit_tests/fixtures/substitutions/inc3.yaml new file mode 100644 index 0000000000..03d459dc97 --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/inc3.yaml @@ -0,0 +1,3 @@ +keys_in_inc3: + x: ${x} + y: ${y} diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py index c7b0bbcf7c..01c669e542 100644 --- a/tests/unit_tests/test_substitutions.py +++ b/tests/unit_tests/test_substitutions.py @@ -8,12 +8,17 @@ import pytest from esphome import config as config_module, yaml_util from esphome.components import substitutions -from esphome.components.packages import do_packages_pass, merge_packages +from esphome.components.packages import ( + MAX_INCLUDE_DEPTH, + _PackageProcessor, + do_packages_pass, + merge_packages, +) from esphome.config import resolve_extend_remove from esphome.config_helpers import Extend, merge_config import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS -from esphome.core import CORE, Lambda +from esphome.core import CORE, EsphomeError, Lambda from esphome.util import OrderedDict _LOGGER = logging.getLogger(__name__) @@ -630,3 +635,62 @@ def test_do_substitution_pass_substitutions_must_be_mapping_from_config() -> Non cv.Invalid, match="Substitutions must be a key to value mapping" ): substitutions.do_substitution_pass(config) + + +# ── IncludeFile / package loading tests ──────────────────────────────────── + + +def test_resolve_package_max_depth_exceeded(tmp_path: Path) -> None: + """A yaml_loader that always returns another IncludeFile triggers the depth guard.""" + parent = tmp_path / "main.yaml" + parent.write_text("") + + # Each call to the loader returns a fresh IncludeFile pointing at itself, + # so PACKAGE_SCHEMA always sees an IncludeFile and never a dict. + def always_returns_include(path: Path) -> yaml_util.IncludeFile: + return yaml_util.IncludeFile(parent, path.name, None, always_returns_include) + + package_config = yaml_util.IncludeFile( + parent, "test.yaml", None, always_returns_include + ) + processor = _PackageProcessor({}, None, False) + with pytest.raises( + cv.Invalid, + match=f"Maximum include nesting depth \\({MAX_INCLUDE_DEPTH}\\) exceeded", + ): + processor.resolve_package(package_config, substitutions.ContextVars()) + + +def test_include_filename_substitution_undefined_var(tmp_path: Path) -> None: + """!include with an undefined substitution variable raises cv.Invalid. + + The error message must reference the unresolved filename template so the + user knows which include failed, rather than seeing a bare file-not-found. + """ + main_file = tmp_path / "main.yaml" + main_file.write_text("result: !include ${undefined_var}.yaml\n") + + config = yaml_util.load_yaml(main_file) + with pytest.raises(cv.Invalid, match=r"\$\{undefined_var\}"): + substitutions.do_substitution_pass(config) + + +def test_resolve_package_undefined_var_in_include_filename(tmp_path: Path) -> None: + """An undefined substitution in a package include filename raises cv.Invalid. + + Previously this would raise an unhandled UndefinedError. With + strict_undefined=False, the unresolved filename passes through to + file loading which produces a clean cv.Invalid error. + """ + parent = tmp_path / "main.yaml" + parent.write_text("") + + def loader(path: Path): + raise EsphomeError(f"Error reading file {path}: No such file") + + package_config = yaml_util.IncludeFile( + parent, "${undefined_var}.yaml", None, loader + ) + processor = _PackageProcessor({}, None, False) + with pytest.raises(cv.Invalid, match="unresolved substitutions"): + processor.resolve_package(package_config, substitutions.ContextVars()) diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index 0342d12540..0bd7c9453b 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -1,3 +1,4 @@ +import io from pathlib import Path import shutil from unittest.mock import patch @@ -7,6 +8,7 @@ import pytest from esphome import core, yaml_util from esphome.components import substitutions from esphome.config_helpers import Extend, Remove +import esphome.config_validation as cv from esphome.core import EsphomeError from esphome.util import OrderedDict @@ -74,7 +76,9 @@ def test_parsing_with_custom_loader(fixture_path): loader_calls.append(fname) with yaml_file.open(encoding="utf-8") as f_handle: - yaml_util.parse_yaml(yaml_file, f_handle, custom_loader) + config = yaml_util.parse_yaml(yaml_file, f_handle, custom_loader) + # substitute config to expand includes: + substitutions.substitute(config, [], substitutions.ContextVars(), False) assert len(loader_calls) == 3 assert loader_calls[0].parts[-2:] == ("includes", "included.yaml") @@ -348,7 +352,9 @@ def test_track_yaml_loads_records_includes(tmp_path: Path) -> None: main.write_text("child: !include included.yaml\n") with yaml_util.track_yaml_loads() as loaded: - yaml_util.load_yaml(main) + result = yaml_util.load_yaml(main) + # !include is deferred; resolve it to trigger the nested load + result["child"].load() resolved = [p.name for p in loaded] assert "main.yaml" in resolved @@ -500,3 +506,161 @@ def test_represent_extend() -> None: def test_represent_remove() -> None: """Test that Remove objects are dumped as plain !remove scalars.""" assert yaml_util.dump({"key": Remove("my_id")}) == "key: !remove 'my_id'\n" + + +# ── IncludeFile unit tests ────────────────────────────────────────────────── + + +def test_include_file_repr(tmp_path: Path) -> None: + """repr() includes the filename so it appears usefully in error messages.""" + parent = tmp_path / "main.yaml" + include = yaml_util.IncludeFile(parent, "some/nested.yaml", None, lambda _: {}) + assert repr(include) == "IncludeFile(some/nested.yaml)" + + +def test_include_file_load_caches_result(tmp_path: Path) -> None: + """load() invokes the yaml_loader only once; subsequent calls return the cached object.""" + parent = tmp_path / "main.yaml" + content = {"key": "value"} + call_count = 0 + + def counting_loader(_): + nonlocal call_count + call_count += 1 + return content + + include = yaml_util.IncludeFile(parent, "child.yaml", None, counting_loader) + first = include.load() + second = include.load() + + assert call_count == 1 + assert first is second + + +def test_include_file_load_caches_none_result(tmp_path: Path) -> None: + """load() caches None content (empty YAML files) and does not re-invoke the loader.""" + parent = tmp_path / "main.yaml" + call_count = 0 + + def counting_loader(_): + nonlocal call_count + call_count += 1 + + include = yaml_util.IncludeFile(parent, "empty.yaml", None, counting_loader) + first = include.load() + second = include.load() + + assert call_count == 1 + assert first is None + assert second is None + + +def test_include_file_load_raises_on_unresolved_expressions(tmp_path: Path) -> None: + """load() raises if the filename contains unresolved substitutions or expressions.""" + parent = tmp_path / "main.yaml" + include = yaml_util.IncludeFile(parent, "${undefined_var}.yaml", None, lambda _: {}) + with pytest.raises(cv.Invalid, match="unresolved"): + include.load() + + +@pytest.mark.parametrize( + ("filename", "expected"), + [ + ("device-${platform}.yaml", True), + ("$platform.yaml", True), + ("${a + b}.yaml", True), # Jinja expression + ("device.yaml", False), + ("path/to/device.yaml", False), + ("my$file.yaml", True), # $file is a valid substitution + ("price-100$.yaml", False), # $ at end, not followed by valid substitution + ], +) +def test_include_file_has_unresolved_expressions( + tmp_path: Path, filename: str, expected: bool +) -> None: + """has_unresolved_expressions() detects substitution patterns in the filename.""" + parent = tmp_path / "main.yaml" + include = yaml_util.IncludeFile(parent, filename, None, lambda _: {}) + assert include.has_unresolved_expressions() == expected + + +def test_include_in_list_context() -> None: + """!include of a file returning a list is handled correctly, + including when that list itself contains a nested IncludeFile.""" + parent = Path("/fake/main.yaml") + + # The nested IncludeFile resolves to a plain string value + inner = yaml_util.IncludeFile(parent, "inner.yaml", None, lambda _: "gamma") + + # The outer IncludeFile returns a list whose last element is itself an IncludeFile, + # exercising the substitution pass's ability to recurse into loaded content. + outer = yaml_util.IncludeFile( + parent, "items.yaml", None, lambda _: ["alpha", "beta", inner] + ) + + config = OrderedDict({"values": outer}) + config = substitutions.do_substitution_pass(config) + + assert config["values"] == ["alpha", "beta", "gamma"] + + +def test_include_plain_filename_loads_after_deferred_refactor() -> None: + """!include with a plain filename (no $ expressions) still loads correctly. + + Regression guard: the deferred-loading refactor must not break the simple case. + """ + parent = Path("/fake/main.yaml") + include = yaml_util.IncludeFile( + parent, "child.yaml", None, lambda _: {"answer": 42} + ) + + config = OrderedDict({"result": include}) + config = substitutions.do_substitution_pass(config) + + assert config["result"]["answer"] == 42 + + +def test_yaml_merge_include_with_filename_substitution_raises() -> None: + """<<: !include ${expr} raises a clear error — substitutions in merge-key filenames + are not yet supported, and the error message must say so.""" + yaml_text = "base:\n existing: value\n <<: !include ${filename}.yaml\n" + with pytest.raises(EsphomeError, match="not supported yet"): + yaml_util.parse_yaml( + Path("/fake/main.yaml"), io.StringIO(yaml_text), lambda _: {} + ) + + +def test_yaml_merge_list_include_with_filename_substitution_raises() -> None: + """Substitutions in include filenames within merge-key lists raise a clear error.""" + yaml_text = "base:\n existing: value\n <<:\n - !include ${filename}.yaml\n" + with pytest.raises(EsphomeError, match="not supported yet"): + yaml_util.parse_yaml( + Path("/fake/main.yaml"), io.StringIO(yaml_text), lambda _: {} + ) + + +def test_yaml_merge_chain_include_resolves() -> None: + """Chained includes in merge keys resolve through multiple IncludeFile layers.""" + parent = Path("/fake/main.yaml") + + inner = yaml_util.IncludeFile(parent, "inner.yaml", None, lambda _: {"x": 1}) + outer = yaml_util.IncludeFile(parent, "outer.yaml", None, lambda _: inner) + + yaml_text = "base:\n existing: value\n <<: !include outer.yaml\n" + config = yaml_util.parse_yaml(parent, io.StringIO(yaml_text), lambda _: outer) + config = substitutions.do_substitution_pass(config) + + assert config["base"]["x"] == 1 + assert config["base"]["existing"] == "value" + + +def test_yaml_merge_chain_include_depth_exceeded() -> None: + """Chain includes in merge keys exceeding depth limit raise a clear error.""" + parent = Path("/fake/main.yaml") + + def self_referencing_loader(path: Path) -> yaml_util.IncludeFile: + return yaml_util.IncludeFile(parent, path.name, None, self_referencing_loader) + + yaml_text = "base:\n <<: !include loop.yaml\n" + with pytest.raises(EsphomeError, match="Maximum include chain depth"): + yaml_util.parse_yaml(parent, io.StringIO(yaml_text), self_referencing_loader)