mirror of
https://github.com/esphome/esphome.git
synced 2026-05-18 09:43:53 +08:00
[substitutions] !include ${filename}, Substitutions in include filename paths (package refactor part 5) (#12213)
Co-authored-by: J. Nick Koston <nick@home-assistant.io> Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
@@ -6,7 +6,12 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from esphome import git, yaml_util
|
||||
from esphome.components.substitutions import ContextVars, push_context, substitute
|
||||
from esphome.components.substitutions import (
|
||||
ContextVars,
|
||||
push_context,
|
||||
resolve_include,
|
||||
substitute,
|
||||
)
|
||||
from esphome.components.substitutions.jinja import has_jinja
|
||||
from esphome.config_helpers import Remove, merge_config
|
||||
import esphome.config_validation as cv
|
||||
@@ -31,6 +36,8 @@ from esphome.core import EsphomeError
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN = CONF_PACKAGES
|
||||
# Guard against infinite include chains (e.g. A includes B includes A).
|
||||
MAX_INCLUDE_DEPTH = 20
|
||||
|
||||
|
||||
def is_remote_package(package_config: dict) -> bool:
|
||||
@@ -59,8 +66,8 @@ def valid_package_contents(package_config: dict) -> dict:
|
||||
for k, v in package_config.items():
|
||||
if not isinstance(k, str):
|
||||
raise cv.Invalid("Package content keys must be strings")
|
||||
if isinstance(v, (dict, list, Remove)):
|
||||
continue # e.g. script: [], psram: !remove, logger: {level: debug}
|
||||
if isinstance(v, (dict, list, Remove, yaml_util.IncludeFile)):
|
||||
continue # e.g. script: [], psram: !remove, logger: {level: debug}, switch: !include switches.yaml
|
||||
if v is None:
|
||||
continue # e.g. web_server:
|
||||
if isinstance(v, str) and has_jinja(v):
|
||||
@@ -160,6 +167,7 @@ REMOTE_PACKAGE_SCHEMA = cv.All(
|
||||
PACKAGE_SCHEMA = cv.Any( # A package definition is either:
|
||||
validate_source_shorthand, # A git URL shorthand string that expands to a remote package schema, or
|
||||
REMOTE_PACKAGE_SCHEMA, # a valid remote package schema, or
|
||||
yaml_util.IncludeFile, # isinstance check — passes IncludeFile objects through unchanged, or:
|
||||
valid_package_contents, # Something that at least looks like an actual package, e.g. {wifi:{ssid: xxx}}
|
||||
# which will have to be fully validated later as per each component's schema.
|
||||
)
|
||||
@@ -396,16 +404,49 @@ class _PackageProcessor:
|
||||
self.skip_update = skip_update
|
||||
|
||||
def resolve_package(
|
||||
self, package_config: dict | str, context_vars: ContextVars | None
|
||||
self,
|
||||
package_config: dict | str | yaml_util.IncludeFile,
|
||||
context_vars: ContextVars | None,
|
||||
) -> dict:
|
||||
"""Substitute variables in the definition and fetch remote packages.
|
||||
"""Resolve a package definition to a concrete ``dict`` and fetch remote packages.
|
||||
|
||||
The input may be a ``str`` (git shorthand or Jinja expression) or a
|
||||
``dict`` (remote or local package). After ``PACKAGE_SCHEMA`` validation
|
||||
the result is always a ``dict``.
|
||||
The input may be a ``str`` (git shorthand or Jinja expression), a
|
||||
``dict`` (remote or local package), or an ``IncludeFile`` whose filename
|
||||
may itself contain substitution expressions.
|
||||
|
||||
The loop handles the case where loading an ``IncludeFile`` yields another
|
||||
``IncludeFile`` (e.g. a chain of deferred includes). Each iteration:
|
||||
|
||||
1. If the current value is an ``IncludeFile``, load it — resolving any
|
||||
substitutions in its filename first.
|
||||
2. Substitute variables in the resulting value (for strings and remote
|
||||
package dicts).
|
||||
3. Validate against ``PACKAGE_SCHEMA``. If the result is a ``dict``,
|
||||
the loop exits; otherwise another iteration is needed.
|
||||
|
||||
Raises ``cv.Invalid`` if the chain has not resolved to a ``dict`` after
|
||||
``MAX_INCLUDE_DEPTH`` iterations.
|
||||
"""
|
||||
package_config = _substitute_package_definition(package_config, context_vars)
|
||||
package_config = PACKAGE_SCHEMA(package_config)
|
||||
for _ in range(MAX_INCLUDE_DEPTH):
|
||||
if isinstance(package_config, yaml_util.IncludeFile):
|
||||
package_config, _ = resolve_include(
|
||||
package_config,
|
||||
[],
|
||||
context_vars or ContextVars(),
|
||||
strict_undefined=False,
|
||||
)
|
||||
|
||||
package_config = _substitute_package_definition(
|
||||
package_config, context_vars
|
||||
)
|
||||
package_config = PACKAGE_SCHEMA(package_config)
|
||||
if isinstance(package_config, dict):
|
||||
break
|
||||
else:
|
||||
raise cv.Invalid(
|
||||
f"Maximum include nesting depth ({MAX_INCLUDE_DEPTH}) exceeded"
|
||||
)
|
||||
|
||||
if is_remote_package(package_config):
|
||||
package_config = _process_remote_package(package_config, self.skip_update)
|
||||
return package_config
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections import ChainMap
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import esphome
|
||||
from esphome import core
|
||||
from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered
|
||||
import esphome.config_validation as cv
|
||||
@@ -12,6 +13,7 @@ from esphome.yaml_util import (
|
||||
ConfigContext,
|
||||
ESPHomeDataBase,
|
||||
ESPLiteralValue,
|
||||
IncludeFile,
|
||||
make_data_base,
|
||||
)
|
||||
|
||||
@@ -291,6 +293,59 @@ def push_context(
|
||||
return parent_context
|
||||
|
||||
|
||||
def resolve_include(
|
||||
include: IncludeFile,
|
||||
path: list[int | str],
|
||||
context_vars: ContextVars,
|
||||
strict_undefined: bool = True,
|
||||
errors: ErrList | None = None,
|
||||
) -> tuple[Any, str]:
|
||||
"""Resolve an include, substituting the filename if needed.
|
||||
|
||||
Returns the loaded content and the resolved filename.
|
||||
|
||||
Note: no path-traversal validation is performed on the resolved filename.
|
||||
A substitution that resolves to an absolute path will bypass the parent
|
||||
directory (Path.__truediv__ ignores the left operand for absolute paths).
|
||||
ESPHome's trust model assumes the config author controls all substitution
|
||||
values (including command-line substitutions), so path restrictions are
|
||||
an explicit non-goal here.
|
||||
"""
|
||||
original = str(include.file)
|
||||
filename = str(
|
||||
_expand_substitutions(
|
||||
original, path + ["file"], context_vars, strict_undefined, errors
|
||||
)
|
||||
)
|
||||
if filename != original:
|
||||
include = IncludeFile(
|
||||
include.parent_file, filename, include.vars, include.yaml_loader
|
||||
)
|
||||
try:
|
||||
return include.load(), filename
|
||||
except esphome.core.EsphomeError as err:
|
||||
raise cv.Invalid(
|
||||
f"Error including file '{filename}': {err}",
|
||||
path + [f"<{filename}>"],
|
||||
) from err
|
||||
|
||||
|
||||
def _substitute_include(
|
||||
include: IncludeFile,
|
||||
path: list[int | str],
|
||||
context_vars: ContextVars,
|
||||
strict_undefined: bool,
|
||||
errors: ErrList | None,
|
||||
) -> Any:
|
||||
"""Resolve an include and substitute its content."""
|
||||
content, filename = resolve_include(
|
||||
include, path, context_vars, strict_undefined, errors
|
||||
)
|
||||
return substitute(
|
||||
content, path + [f"<{filename}>"], context_vars, strict_undefined, errors
|
||||
)
|
||||
|
||||
|
||||
def substitute(
|
||||
item: Any,
|
||||
path: SubstitutionPath,
|
||||
@@ -333,6 +388,9 @@ def substitute(
|
||||
if item.value != value:
|
||||
result = type(item)(value)
|
||||
|
||||
elif isinstance(item, IncludeFile):
|
||||
result = _substitute_include(item, path, context_vars, strict_undefined, errors)
|
||||
|
||||
if isinstance(item, ESPHomeDataBase):
|
||||
result = make_data_base(result, item)
|
||||
return result
|
||||
|
||||
@@ -2,7 +2,6 @@ from ast import literal_eval
|
||||
from collections.abc import Iterator, Mapping
|
||||
from itertools import chain, islice
|
||||
import math
|
||||
import re
|
||||
from types import GeneratorType
|
||||
from typing import Any
|
||||
|
||||
@@ -10,6 +9,9 @@ import jinja2 as jinja
|
||||
from jinja2.nativetypes import NativeCodeGenerator, NativeTemplate
|
||||
from jinja2.runtime import missing as Missing
|
||||
|
||||
# Re-exported for backward compatibility — consumers import has_jinja from here
|
||||
from esphome.expression import has_jinja # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
TemplateError = jinja.TemplateError
|
||||
TemplateSyntaxError = jinja.TemplateSyntaxError
|
||||
TemplateRuntimeError = jinja.TemplateRuntimeError
|
||||
@@ -20,18 +22,6 @@ Undefined = jinja.Undefined
|
||||
Resolver = ".resolver"
|
||||
|
||||
|
||||
DETECT_JINJA = r"(\$\{)"
|
||||
detect_jinja_re = re.compile(
|
||||
r"<%.+?%>" # Block form expression: <% ... %>
|
||||
r"|\$\{[^}]+\}", # Braced form expression: ${ ... }
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def has_jinja(st: str) -> bool:
|
||||
return detect_jinja_re.search(st) is not None
|
||||
|
||||
|
||||
# SAFE_GLOBALS defines a allowlist of built-in functions or modules that are considered safe to expose
|
||||
# in Jinja templates or other sandboxed evaluation contexts. Only functions that do not allow
|
||||
# arbitrary code execution, file access, or other security risks are included.
|
||||
|
||||
@@ -75,7 +75,6 @@ from esphome.const import (
|
||||
SCHEDULER_DONT_RUN,
|
||||
TYPE_GIT,
|
||||
TYPE_LOCAL,
|
||||
VALID_SUBSTITUTIONS_CHARACTERS,
|
||||
Framework,
|
||||
__version__ as ESPHOME_VERSION,
|
||||
)
|
||||
@@ -90,6 +89,7 @@ from esphome.core import (
|
||||
TimePeriodNanoseconds,
|
||||
TimePeriodSeconds,
|
||||
)
|
||||
from esphome.expression import SUBSTITUTION_VARIABLE_PROG as VARIABLE_PROG
|
||||
from esphome.helpers import add_class_to_obj, docs_url, list_starts_with
|
||||
from esphome.schema_extractors import (
|
||||
SCHEMA_EXTRACT,
|
||||
@@ -104,11 +104,6 @@ from esphome.yaml_util import make_data_base
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# pylint: disable=consider-using-f-string
|
||||
VARIABLE_PROG = re.compile(
|
||||
f"\\$([{VALID_SUBSTITUTIONS_CHARACTERS}]+|\\{{[{VALID_SUBSTITUTIONS_CHARACTERS}]*\\}})"
|
||||
)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
Schema = _Schema
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Helpers for detecting substitution variables and Jinja expressions."""
|
||||
|
||||
import re
|
||||
|
||||
from esphome.const import VALID_SUBSTITUTIONS_CHARACTERS
|
||||
|
||||
SUBSTITUTION_VARIABLE_PROG = re.compile(
|
||||
rf"\$([{VALID_SUBSTITUTIONS_CHARACTERS}]+|\{{[{VALID_SUBSTITUTIONS_CHARACTERS}]*\}})"
|
||||
)
|
||||
|
||||
_JINJA_RE = re.compile(
|
||||
r"<%.+?%>" # Block: <% ... %>
|
||||
r"|\$\{[^}]+\}", # Braced: ${ ... }
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def has_jinja(value: str) -> bool:
|
||||
"""Check if a string contains Jinja expressions."""
|
||||
return _JINJA_RE.search(value) is not None
|
||||
|
||||
|
||||
def has_substitution_or_expression(value: str) -> bool:
|
||||
"""Check if a string contains substitution variables ($name, ${name}) or Jinja expressions."""
|
||||
return SUBSTITUTION_VARIABLE_PROG.search(value) is not None or has_jinja(value)
|
||||
+128
-20
@@ -33,6 +33,7 @@ from esphome.core import (
|
||||
MACAddress,
|
||||
TimePeriod,
|
||||
)
|
||||
from esphome.expression import has_substitution_or_expression
|
||||
from esphome.helpers import add_class_to_obj
|
||||
from esphome.util import OrderedDict, filter_yaml_files
|
||||
|
||||
@@ -110,24 +111,6 @@ def make_data_base(
|
||||
return value
|
||||
|
||||
|
||||
class ConfigContext:
|
||||
"""This is a mixin class that holds substitution vars that should be applied
|
||||
to the tagged node and its children. During configuration loading, context vars can
|
||||
be added to nodes using `add_context` function, which applies the mixin storing
|
||||
the captured values and unevaluated expressions.
|
||||
The substitution pass then recreates the effective context by merging the context vars
|
||||
from this node and parent nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
def vars(self) -> dict[str, Any]:
|
||||
return self._context_vars
|
||||
|
||||
def set_context(self, vars: dict[str, Any]) -> None:
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._context_vars = vars
|
||||
|
||||
|
||||
def add_context(value: Any, context_vars: dict[str, Any] | None) -> Any:
|
||||
"""Tags a list/string/dict value with context vars that must be applied to it and its children
|
||||
during the substitution pass. If no vars are given, no tagging is done.
|
||||
@@ -151,6 +134,94 @@ def add_context(value: Any, context_vars: dict[str, Any] | None) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
class ConfigContext:
|
||||
"""This is a mixin class that holds substitution vars that should be applied
|
||||
to the tagged node and its children. During configuration loading, context vars can
|
||||
be added to nodes using `add_context` function, which applies the mixin storing
|
||||
the captured values and unevaluated expressions.
|
||||
The substitution pass then recreates the effective context by merging the context vars
|
||||
from this node and parent nodes.
|
||||
"""
|
||||
|
||||
@property
|
||||
def vars(self) -> dict[str, Any]:
|
||||
return self._context_vars
|
||||
|
||||
def set_context(self, vars: dict[str, Any]) -> None:
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._context_vars = vars
|
||||
|
||||
def copy_context_to_children(self) -> None:
|
||||
"""Propagate context to children.
|
||||
|
||||
isinstance(self, dict/list) works because ConfigContext is dynamically
|
||||
mixed into dict/list subclasses via add_class_to_obj in add_context().
|
||||
"""
|
||||
if isinstance(self, dict):
|
||||
# pylint: disable=no-member
|
||||
tagged = {
|
||||
add_context(k, self.vars): add_context(v, self.vars)
|
||||
for k, v in self.items()
|
||||
}
|
||||
self.clear()
|
||||
self.update(tagged)
|
||||
elif isinstance(self, list):
|
||||
for i, item in enumerate(self):
|
||||
# pylint: disable=unsupported-assignment-operation
|
||||
self[i] = add_context(item, self.vars)
|
||||
|
||||
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class IncludeFile:
|
||||
"""Deferred !include that is resolved during the substitution pass.
|
||||
|
||||
Created during YAML parsing instead of loading the file immediately,
|
||||
allowing substitution variables to appear in the filename path
|
||||
(e.g. ``!include device-${platform}.yaml``). The actual file is
|
||||
loaded on the first call to ``load()``, and the result is cached.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_file: Path,
|
||||
file: Path | str,
|
||||
vars: dict[str, Any] | None,
|
||||
yaml_loader: Callable[[Path], Any],
|
||||
) -> None:
|
||||
self.parent_file = parent_file
|
||||
self.file = Path(file)
|
||||
self.vars = vars
|
||||
self.yaml_loader = yaml_loader
|
||||
self._content: Any = _UNSET
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IncludeFile({self.file.as_posix()})"
|
||||
|
||||
def load(self) -> Any:
|
||||
"""Load and cache the included file content.
|
||||
|
||||
Note: returns the cached mutable object on subsequent calls.
|
||||
Callers that need to modify the result should copy it first.
|
||||
"""
|
||||
if self._content is not _UNSET:
|
||||
return self._content
|
||||
if self.has_unresolved_expressions():
|
||||
from esphome.config_validation import Invalid
|
||||
|
||||
raise Invalid(
|
||||
f"Cannot load include with unresolved substitutions: {self.file}"
|
||||
)
|
||||
self._content = self.yaml_loader(Path(self.parent_file.parent / self.file))
|
||||
self._content = add_context(self._content, self.vars)
|
||||
return self._content
|
||||
|
||||
def has_unresolved_expressions(self) -> bool:
|
||||
"""Check if the filename contains substitution variables or Jinja expressions."""
|
||||
return has_substitution_or_expression(str(self.file))
|
||||
|
||||
|
||||
def _add_data_ref(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapped(loader, node):
|
||||
@@ -170,6 +241,36 @@ def _add_data_ref(fn):
|
||||
return wrapped
|
||||
|
||||
|
||||
_MAX_MERGE_INCLUDE_DEPTH = 10
|
||||
|
||||
|
||||
def _resolve_merge_include(value: Any, node: yaml.Node, value_node: yaml.Node) -> Any:
|
||||
"""Resolve an IncludeFile (and chains) and propagate context for merge key handling."""
|
||||
for _ in range(_MAX_MERGE_INCLUDE_DEPTH):
|
||||
if not isinstance(value, IncludeFile):
|
||||
break
|
||||
if value.has_unresolved_expressions():
|
||||
raise yaml.constructor.ConstructorError(
|
||||
"While constructing a mapping",
|
||||
node.start_mark,
|
||||
"Substitution in include filename with merge keys is not supported yet.",
|
||||
value_node.start_mark,
|
||||
)
|
||||
value = value.load()
|
||||
else:
|
||||
raise yaml.constructor.ConstructorError(
|
||||
"While constructing a mapping",
|
||||
node.start_mark,
|
||||
f"Maximum include chain depth ({_MAX_MERGE_INCLUDE_DEPTH}) exceeded in merge key",
|
||||
value_node.start_mark,
|
||||
)
|
||||
if isinstance(value, ConfigContext):
|
||||
# Since the parent dict/list will disappear, propagate
|
||||
# context to children now to retain context vars
|
||||
value.copy_context_to_children()
|
||||
return value
|
||||
|
||||
|
||||
class ESPHomeLoaderMixin:
|
||||
"""Loader class that keeps track of line numbers."""
|
||||
|
||||
@@ -261,6 +362,9 @@ class ESPHomeLoaderMixin:
|
||||
|
||||
# This is a merge key, resolve value and add to merge_pairs
|
||||
value = self.construct_object(value_node)
|
||||
|
||||
value = _resolve_merge_include(value, node, value_node)
|
||||
|
||||
if isinstance(value, dict):
|
||||
# base case, copy directly to merge_pairs
|
||||
# direct merge, like "<<: {some_key: some_value}"
|
||||
@@ -268,6 +372,7 @@ class ESPHomeLoaderMixin:
|
||||
elif isinstance(value, list):
|
||||
# sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]"
|
||||
for item in value:
|
||||
item = _resolve_merge_include(item, node, value_node)
|
||||
if not isinstance(item, dict):
|
||||
raise yaml.constructor.ConstructorError(
|
||||
"While constructing a mapping",
|
||||
@@ -362,8 +467,11 @@ class ESPHomeLoaderMixin:
|
||||
else:
|
||||
file, vars = node.value, None
|
||||
|
||||
result = self.yaml_loader(self._rel_path(file))
|
||||
return add_context(result, vars)
|
||||
return IncludeFile(self.name, file, vars, self.yaml_loader)
|
||||
|
||||
# Directory includes (!include_dir_*) load eagerly during YAML parsing
|
||||
# because their paths are directory names, not individual files, and
|
||||
# substitutions in directory paths are not supported.
|
||||
|
||||
@_add_data_ref
|
||||
def construct_include_dir_list(self, node: yaml.Node) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
values:
|
||||
- var1: 4
|
||||
- a: 5
|
||||
- b: 6
|
||||
- c: The value of C is 7
|
||||
- This value comes from inc2.yaml. x is 3, y is 4
|
||||
- From main config, x is 3, y is 2
|
||||
- $a $b $c are out of scope here
|
||||
- keys_in_inc3:
|
||||
x: 3
|
||||
y: 2
|
||||
substitutions:
|
||||
x: 3
|
||||
y: 2
|
||||
include_file: inc1
|
||||
@@ -0,0 +1,21 @@
|
||||
substitutions:
|
||||
include_file: inc1
|
||||
x: 3 # override x from inc2.yaml
|
||||
|
||||
packages:
|
||||
my_package: !include
|
||||
file: ${include_file + ".yaml"} # includes inc1.yaml
|
||||
vars:
|
||||
var1: 4
|
||||
a: ${x+2}
|
||||
b: ${a+1}
|
||||
c: 7
|
||||
other_package: !include
|
||||
file: inc${1+1}.yaml # includes inc2.yaml
|
||||
vars:
|
||||
y: 4
|
||||
|
||||
values:
|
||||
- From main config, x is $x, y is $y
|
||||
- $a $b $c are out of scope here
|
||||
- !include ${"inc" + "3.yaml"} # includes inc3.yaml here (not a package)
|
||||
@@ -0,0 +1,9 @@
|
||||
substitutions:
|
||||
x: 7
|
||||
test_list:
|
||||
- content:
|
||||
before: Content before
|
||||
after: Content after
|
||||
keys_in_inc3:
|
||||
x: 7
|
||||
y: 8
|
||||
@@ -0,0 +1,10 @@
|
||||
substitutions:
|
||||
x: 7
|
||||
test_list:
|
||||
- content:
|
||||
before: Content before
|
||||
<<: !include
|
||||
file: inc3.yaml
|
||||
vars:
|
||||
y: 8
|
||||
after: Content after
|
||||
@@ -0,0 +1,6 @@
|
||||
substitutions:
|
||||
x: 1
|
||||
y: 2
|
||||
|
||||
values:
|
||||
- This value comes from inc2.yaml. x is $x, y is $y
|
||||
@@ -0,0 +1,3 @@
|
||||
keys_in_inc3:
|
||||
x: ${x}
|
||||
y: ${y}
|
||||
@@ -8,12 +8,17 @@ import pytest
|
||||
|
||||
from esphome import config as config_module, yaml_util
|
||||
from esphome.components import substitutions
|
||||
from esphome.components.packages import do_packages_pass, merge_packages
|
||||
from esphome.components.packages import (
|
||||
MAX_INCLUDE_DEPTH,
|
||||
_PackageProcessor,
|
||||
do_packages_pass,
|
||||
merge_packages,
|
||||
)
|
||||
from esphome.config import resolve_extend_remove
|
||||
from esphome.config_helpers import Extend, merge_config
|
||||
import esphome.config_validation as cv
|
||||
from esphome.const import CONF_SUBSTITUTIONS
|
||||
from esphome.core import CORE, Lambda
|
||||
from esphome.core import CORE, EsphomeError, Lambda
|
||||
from esphome.util import OrderedDict
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -630,3 +635,62 @@ def test_do_substitution_pass_substitutions_must_be_mapping_from_config() -> Non
|
||||
cv.Invalid, match="Substitutions must be a key to value mapping"
|
||||
):
|
||||
substitutions.do_substitution_pass(config)
|
||||
|
||||
|
||||
# ── IncludeFile / package loading tests ────────────────────────────────────
|
||||
|
||||
|
||||
def test_resolve_package_max_depth_exceeded(tmp_path: Path) -> None:
|
||||
"""A yaml_loader that always returns another IncludeFile triggers the depth guard."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
parent.write_text("")
|
||||
|
||||
# Each call to the loader returns a fresh IncludeFile pointing at itself,
|
||||
# so PACKAGE_SCHEMA always sees an IncludeFile and never a dict.
|
||||
def always_returns_include(path: Path) -> yaml_util.IncludeFile:
|
||||
return yaml_util.IncludeFile(parent, path.name, None, always_returns_include)
|
||||
|
||||
package_config = yaml_util.IncludeFile(
|
||||
parent, "test.yaml", None, always_returns_include
|
||||
)
|
||||
processor = _PackageProcessor({}, None, False)
|
||||
with pytest.raises(
|
||||
cv.Invalid,
|
||||
match=f"Maximum include nesting depth \\({MAX_INCLUDE_DEPTH}\\) exceeded",
|
||||
):
|
||||
processor.resolve_package(package_config, substitutions.ContextVars())
|
||||
|
||||
|
||||
def test_include_filename_substitution_undefined_var(tmp_path: Path) -> None:
|
||||
"""!include with an undefined substitution variable raises cv.Invalid.
|
||||
|
||||
The error message must reference the unresolved filename template so the
|
||||
user knows which include failed, rather than seeing a bare file-not-found.
|
||||
"""
|
||||
main_file = tmp_path / "main.yaml"
|
||||
main_file.write_text("result: !include ${undefined_var}.yaml\n")
|
||||
|
||||
config = yaml_util.load_yaml(main_file)
|
||||
with pytest.raises(cv.Invalid, match=r"\$\{undefined_var\}"):
|
||||
substitutions.do_substitution_pass(config)
|
||||
|
||||
|
||||
def test_resolve_package_undefined_var_in_include_filename(tmp_path: Path) -> None:
|
||||
"""An undefined substitution in a package include filename raises cv.Invalid.
|
||||
|
||||
Previously this would raise an unhandled UndefinedError. With
|
||||
strict_undefined=False, the unresolved filename passes through to
|
||||
file loading which produces a clean cv.Invalid error.
|
||||
"""
|
||||
parent = tmp_path / "main.yaml"
|
||||
parent.write_text("")
|
||||
|
||||
def loader(path: Path):
|
||||
raise EsphomeError(f"Error reading file {path}: No such file")
|
||||
|
||||
package_config = yaml_util.IncludeFile(
|
||||
parent, "${undefined_var}.yaml", None, loader
|
||||
)
|
||||
processor = _PackageProcessor({}, None, False)
|
||||
with pytest.raises(cv.Invalid, match="unresolved substitutions"):
|
||||
processor.resolve_package(package_config, substitutions.ContextVars())
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from unittest.mock import patch
|
||||
@@ -7,6 +8,7 @@ import pytest
|
||||
from esphome import core, yaml_util
|
||||
from esphome.components import substitutions
|
||||
from esphome.config_helpers import Extend, Remove
|
||||
import esphome.config_validation as cv
|
||||
from esphome.core import EsphomeError
|
||||
from esphome.util import OrderedDict
|
||||
|
||||
@@ -74,7 +76,9 @@ def test_parsing_with_custom_loader(fixture_path):
|
||||
loader_calls.append(fname)
|
||||
|
||||
with yaml_file.open(encoding="utf-8") as f_handle:
|
||||
yaml_util.parse_yaml(yaml_file, f_handle, custom_loader)
|
||||
config = yaml_util.parse_yaml(yaml_file, f_handle, custom_loader)
|
||||
# substitute config to expand includes:
|
||||
substitutions.substitute(config, [], substitutions.ContextVars(), False)
|
||||
|
||||
assert len(loader_calls) == 3
|
||||
assert loader_calls[0].parts[-2:] == ("includes", "included.yaml")
|
||||
@@ -348,7 +352,9 @@ def test_track_yaml_loads_records_includes(tmp_path: Path) -> None:
|
||||
main.write_text("child: !include included.yaml\n")
|
||||
|
||||
with yaml_util.track_yaml_loads() as loaded:
|
||||
yaml_util.load_yaml(main)
|
||||
result = yaml_util.load_yaml(main)
|
||||
# !include is deferred; resolve it to trigger the nested load
|
||||
result["child"].load()
|
||||
|
||||
resolved = [p.name for p in loaded]
|
||||
assert "main.yaml" in resolved
|
||||
@@ -500,3 +506,161 @@ def test_represent_extend() -> None:
|
||||
def test_represent_remove() -> None:
|
||||
"""Test that Remove objects are dumped as plain !remove scalars."""
|
||||
assert yaml_util.dump({"key": Remove("my_id")}) == "key: !remove 'my_id'\n"
|
||||
|
||||
|
||||
# ── IncludeFile unit tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_include_file_repr(tmp_path: Path) -> None:
|
||||
"""repr() includes the filename so it appears usefully in error messages."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
include = yaml_util.IncludeFile(parent, "some/nested.yaml", None, lambda _: {})
|
||||
assert repr(include) == "IncludeFile(some/nested.yaml)"
|
||||
|
||||
|
||||
def test_include_file_load_caches_result(tmp_path: Path) -> None:
|
||||
"""load() invokes the yaml_loader only once; subsequent calls return the cached object."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
content = {"key": "value"}
|
||||
call_count = 0
|
||||
|
||||
def counting_loader(_):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return content
|
||||
|
||||
include = yaml_util.IncludeFile(parent, "child.yaml", None, counting_loader)
|
||||
first = include.load()
|
||||
second = include.load()
|
||||
|
||||
assert call_count == 1
|
||||
assert first is second
|
||||
|
||||
|
||||
def test_include_file_load_caches_none_result(tmp_path: Path) -> None:
|
||||
"""load() caches None content (empty YAML files) and does not re-invoke the loader."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
call_count = 0
|
||||
|
||||
def counting_loader(_):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
||||
include = yaml_util.IncludeFile(parent, "empty.yaml", None, counting_loader)
|
||||
first = include.load()
|
||||
second = include.load()
|
||||
|
||||
assert call_count == 1
|
||||
assert first is None
|
||||
assert second is None
|
||||
|
||||
|
||||
def test_include_file_load_raises_on_unresolved_expressions(tmp_path: Path) -> None:
|
||||
"""load() raises if the filename contains unresolved substitutions or expressions."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
include = yaml_util.IncludeFile(parent, "${undefined_var}.yaml", None, lambda _: {})
|
||||
with pytest.raises(cv.Invalid, match="unresolved"):
|
||||
include.load()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "expected"),
|
||||
[
|
||||
("device-${platform}.yaml", True),
|
||||
("$platform.yaml", True),
|
||||
("${a + b}.yaml", True), # Jinja expression
|
||||
("device.yaml", False),
|
||||
("path/to/device.yaml", False),
|
||||
("my$file.yaml", True), # $file is a valid substitution
|
||||
("price-100$.yaml", False), # $ at end, not followed by valid substitution
|
||||
],
|
||||
)
|
||||
def test_include_file_has_unresolved_expressions(
|
||||
tmp_path: Path, filename: str, expected: bool
|
||||
) -> None:
|
||||
"""has_unresolved_expressions() detects substitution patterns in the filename."""
|
||||
parent = tmp_path / "main.yaml"
|
||||
include = yaml_util.IncludeFile(parent, filename, None, lambda _: {})
|
||||
assert include.has_unresolved_expressions() == expected
|
||||
|
||||
|
||||
def test_include_in_list_context() -> None:
|
||||
"""!include of a file returning a list is handled correctly,
|
||||
including when that list itself contains a nested IncludeFile."""
|
||||
parent = Path("/fake/main.yaml")
|
||||
|
||||
# The nested IncludeFile resolves to a plain string value
|
||||
inner = yaml_util.IncludeFile(parent, "inner.yaml", None, lambda _: "gamma")
|
||||
|
||||
# The outer IncludeFile returns a list whose last element is itself an IncludeFile,
|
||||
# exercising the substitution pass's ability to recurse into loaded content.
|
||||
outer = yaml_util.IncludeFile(
|
||||
parent, "items.yaml", None, lambda _: ["alpha", "beta", inner]
|
||||
)
|
||||
|
||||
config = OrderedDict({"values": outer})
|
||||
config = substitutions.do_substitution_pass(config)
|
||||
|
||||
assert config["values"] == ["alpha", "beta", "gamma"]
|
||||
|
||||
|
||||
def test_include_plain_filename_loads_after_deferred_refactor() -> None:
|
||||
"""!include with a plain filename (no $ expressions) still loads correctly.
|
||||
|
||||
Regression guard: the deferred-loading refactor must not break the simple case.
|
||||
"""
|
||||
parent = Path("/fake/main.yaml")
|
||||
include = yaml_util.IncludeFile(
|
||||
parent, "child.yaml", None, lambda _: {"answer": 42}
|
||||
)
|
||||
|
||||
config = OrderedDict({"result": include})
|
||||
config = substitutions.do_substitution_pass(config)
|
||||
|
||||
assert config["result"]["answer"] == 42
|
||||
|
||||
|
||||
def test_yaml_merge_include_with_filename_substitution_raises() -> None:
|
||||
"""<<: !include ${expr} raises a clear error — substitutions in merge-key filenames
|
||||
are not yet supported, and the error message must say so."""
|
||||
yaml_text = "base:\n existing: value\n <<: !include ${filename}.yaml\n"
|
||||
with pytest.raises(EsphomeError, match="not supported yet"):
|
||||
yaml_util.parse_yaml(
|
||||
Path("/fake/main.yaml"), io.StringIO(yaml_text), lambda _: {}
|
||||
)
|
||||
|
||||
|
||||
def test_yaml_merge_list_include_with_filename_substitution_raises() -> None:
|
||||
"""Substitutions in include filenames within merge-key lists raise a clear error."""
|
||||
yaml_text = "base:\n existing: value\n <<:\n - !include ${filename}.yaml\n"
|
||||
with pytest.raises(EsphomeError, match="not supported yet"):
|
||||
yaml_util.parse_yaml(
|
||||
Path("/fake/main.yaml"), io.StringIO(yaml_text), lambda _: {}
|
||||
)
|
||||
|
||||
|
||||
def test_yaml_merge_chain_include_resolves() -> None:
|
||||
"""Chained includes in merge keys resolve through multiple IncludeFile layers."""
|
||||
parent = Path("/fake/main.yaml")
|
||||
|
||||
inner = yaml_util.IncludeFile(parent, "inner.yaml", None, lambda _: {"x": 1})
|
||||
outer = yaml_util.IncludeFile(parent, "outer.yaml", None, lambda _: inner)
|
||||
|
||||
yaml_text = "base:\n existing: value\n <<: !include outer.yaml\n"
|
||||
config = yaml_util.parse_yaml(parent, io.StringIO(yaml_text), lambda _: outer)
|
||||
config = substitutions.do_substitution_pass(config)
|
||||
|
||||
assert config["base"]["x"] == 1
|
||||
assert config["base"]["existing"] == "value"
|
||||
|
||||
|
||||
def test_yaml_merge_chain_include_depth_exceeded() -> None:
|
||||
"""Chain includes in merge keys exceeding depth limit raise a clear error."""
|
||||
parent = Path("/fake/main.yaml")
|
||||
|
||||
def self_referencing_loader(path: Path) -> yaml_util.IncludeFile:
|
||||
return yaml_util.IncludeFile(parent, path.name, None, self_referencing_loader)
|
||||
|
||||
yaml_text = "base:\n <<: !include loop.yaml\n"
|
||||
with pytest.raises(EsphomeError, match="Maximum include chain depth"):
|
||||
yaml_util.parse_yaml(parent, io.StringIO(yaml_text), self_referencing_loader)
|
||||
|
||||
Reference in New Issue
Block a user