diff --git a/esphome/components/packages/__init__.py b/esphome/components/packages/__init__.py index 6d353ccf11..793cb946dd 100644 --- a/esphome/components/packages/__init__.py +++ b/esphome/components/packages/__init__.py @@ -226,7 +226,7 @@ def _process_remote_package(config: dict, skip_update: bool = False) -> dict: raise cv.Invalid( f"Current ESPHome Version is too old to use this package: {ESPHOME_VERSION} < {min_version}" ) - new_yaml = yaml_util.substitute_vars(new_yaml, vars) + new_yaml = yaml_util.add_context(new_yaml, vars or None) packages[f"{filename}{idx}"] = new_yaml except EsphomeError as e: raise cv.Invalid( @@ -296,6 +296,13 @@ def do_packages_pass(config: dict, skip_update: bool = False) -> dict: def process_package_callback(package_config: dict) -> dict: """This will be called for each package found in the config.""" + if isinstance(package_config, yaml_util.ConfigContext): + context_vars = package_config.vars + if CONF_PACKAGES in package_config or CONF_URL in package_config: + # Remote package definition: eagerly resolve before PACKAGE_SCHEMA validation. + from esphome.components.substitutions import substitute_context_vars + + substitute_context_vars(package_config, context_vars) package_config = PACKAGE_SCHEMA(package_config) if isinstance(package_config, str): return package_config # Jinja string, skip processing diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index 7e15f714f7..ecee816ce9 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -1,31 +1,50 @@ +from collections import ChainMap import logging -from re import Match from typing import Any from esphome import core from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS -from esphome.yaml_util import ESPHomeDataBase, ESPLiteralValue, make_data_base +from esphome.types import ConfigType +from esphome.util import OrderedDict +from esphome.yaml_util import ( + ConfigContext, + ESPHomeDataBase, + ESPLiteralValue, + make_data_base, +) -from .jinja import Jinja, JinjaError, JinjaStr, has_jinja +from .jinja import Jinja, JinjaError, Missing, Resolver, UndefinedError, has_jinja CODEOWNERS = ["@esphome/core"] _LOGGER = logging.getLogger(__name__) +ContextVars = ChainMap[str, Any] +SubstitutionPath = list[int | str] +ErrList = list[tuple[UndefinedError, SubstitutionPath, Any]] +# Module-level instance is safe: context_vars is passed per-call, and context_trace +# is stack-saved/restored within expand(). Not thread-safe — only use from one thread. +jinja = Jinja() -def validate_substitution_key(value): + +def validate_substitution_key(value: Any) -> str: + """Validate and normalize a substitution key, stripping a leading ``$`` if present.""" value = cv.string(value) if not value: raise cv.Invalid("Substitution key must not be empty") if value[0] == "$": value = value[1:] + if not value: + raise cv.Invalid("Substitution key must not be empty") if value[0].isdigit(): raise cv.Invalid("First character in substitutions cannot be a digit.") for char in value: if char not in VALID_SUBSTITUTIONS_CHARACTERS: raise cv.Invalid( - f"Substitution must only consist of upper/lowercase characters, the underscore and numbers. The character '{char}' cannot be used" + f"Substitution must only consist of upper/lowercase characters," + f" the underscore and numbers." + f" The character '{char}' cannot be used" ) return value @@ -37,8 +56,8 @@ CONFIG_SCHEMA = cv.Schema( ) -async def to_code(config): - pass +async def to_code(config: ConfigType) -> None: + """No runtime code generation needed — substitutions are resolved at config time.""" def _restore_data_base(value: Any, orig_value: ESPHomeDataBase) -> ESPHomeDataBase: @@ -62,91 +81,122 @@ def _restore_data_base(value: Any, orig_value: ESPHomeDataBase) -> ESPHomeDataBa return value -def _expand_jinja( - value: str | JinjaStr, - orig_value: str | JinjaStr, - path, - jinja: Jinja, - ignore_missing: bool, -) -> Any: - if has_jinja(value): - # If the original value passed in to this function is a JinjaStr, it means it contains an unresolved - # Jinja expression from a previous pass. - if isinstance(orig_value, JinjaStr): - # Rebuild the JinjaStr in case it was lost while replacing substitutions. - value = JinjaStr(value, orig_value.upvalues) - try: - # Invoke the jinja engine to evaluate the expression. - value, err = jinja.expand(value) - if err is not None and not ignore_missing and "password" not in path: - _LOGGER.warning( - "Found '%s' (see %s) which looks like an expression," - " but could not resolve all the variables: %s", - value, - "->".join(str(x) for x in path), - err.message, - ) - except JinjaError as err: - raise cv.Invalid( - f"{err.error_name()} Error evaluating jinja expression '{value}': {str(err.parent())}." - f"\nEvaluation stack: (most recent evaluation last)\n{err.stack_trace_str()}" - f"\nRelevant context:\n{err.context_trace_str()}" - f"\nSee {'->'.join(str(x) for x in path)}", - path, - ) - # If the original, unexpanded string, contained document metadata (ESPHomeDatabase), - # assign this same document metadata to the resulting value. - if isinstance(orig_value, ESPHomeDataBase): - value = _restore_data_base(value, orig_value) +def _try_substitute(value: Any, context: ContextVars) -> Any: + """Substitute variables in value, returning the result or the original if unchanged.""" + result = _substitute_item(value, [], context, strict_undefined=True) + return result if result is not None else value - return value + +def _resolve_var(name: str, context_vars: ContextVars) -> Any: + """Look up a substitution variable, falling back to the resolver callback.""" + sub = context_vars.get(name, Missing) + if sub is Missing: + resolver = context_vars.get(Resolver) + if resolver: + sub = resolver(name) + return sub + + +def _handle_undefined( + err: UndefinedError, + path: SubstitutionPath, + value: Any, + strict_undefined: bool, + errors: ErrList | None, +) -> None: + """Handle an undefined variable. + + In strict mode, raises immediately. Otherwise, appends to the errors + list for deferred warning at the end of the substitution pass. + """ + if strict_undefined: + raise err + if errors is not None: + errors.append((err, path, value)) def _expand_substitutions( - substitutions: dict, value: str, path, jinja: Jinja, ignore_missing: bool + value: str, + path: SubstitutionPath, + context_vars: ContextVars, + strict_undefined: bool, + errors: ErrList | None, ) -> Any: + """Expand ``$var``, ``${var}``, and Jinja expressions in a string. + + Works in two phases: + + 1. **Simple substitution** — scan for ``$name`` / ``${name}`` tokens + and replace them with the value from *context_vars*. If the token + spans the entire string, return the raw value (preserving type). + 2. **Jinja evaluation** — if the result still contains Jinja syntax + (e.g. ``${a * b}``), render it through the Jinja engine with the + full *context_vars* as template variables. + + Returns the expanded value (may be a non-string type) or the + original *value* unchanged if there is nothing to substitute. + """ if "$" not in value: return value orig_value = value - i = 0 - while True: - m: Match[str] = cv.VARIABLE_PROG.search(value, i) - if not m: - # No more variable substitutions found. See if the remainder looks like a jinja template - value = _expand_jinja(value, orig_value, path, jinja, ignore_missing) - break - - i, j = m.span(0) + # Phase 1: Replace $var and ${var} references + search_pos = 0 + while (m := cv.VARIABLE_PROG.search(value, search_pos)) is not None: + match_start, match_end = m.span(0) name: str = m.group(1) if name.startswith("{") and name.endswith("}"): name = name[1:-1] - if name not in substitutions: - if not ignore_missing and "password" not in path: - _LOGGER.warning( - "Found '%s' (see %s) which looks like a substitution, but '%s' was " - "not declared", - orig_value, - "->".join(str(x) for x in path), - name, - ) - i = j + sub = _resolve_var(name, context_vars) + if sub is Missing: + _handle_undefined( + err=UndefinedError(f"'{name}' is undefined"), + path=path, + value=value, + strict_undefined=strict_undefined, + errors=errors, + ) + search_pos = match_end continue - sub: Any = substitutions[name] - - if i == 0 and j == len(value): - # The variable spans the whole expression, e.g., "${varName}". Return its resolved value directly - # to conserve its type. + if match_start == 0 and match_end == len(value): + # The variable spans the whole expression, e.g., "${varName}". + # Return its resolved value directly to conserve its type. value = sub break - tail = value[j:] - value = value[:i] + str(sub) - i = len(value) + tail = value[match_end:] + value = value[:match_start] + str(sub) + search_pos = len(value) value += tail + # Phase 2: Evaluate any remaining jinja expressions (e.g., "${a * b}") + if isinstance(value, str) and has_jinja(value): + try: + value = jinja.expand(value, context_vars) + except UndefinedError as err: + _handle_undefined( + err=err, + path=path, + value=value, + strict_undefined=strict_undefined, + errors=errors, + ) + except JinjaError as err: + raise cv.Invalid( + f"{err.error_name()} Error evaluating jinja expression" + f" '{value}': {str(err.parent())}." + f"\nEvaluation stack: (most recent evaluation last)" + f"\n{err.stack_trace_str()}" + f"\nRelevant context:\n{err.context_trace_str()}" + f"\nSee {'->'.join(str(x) for x in path)}", + path, + ) + else: + if isinstance(orig_value, ESPHomeDataBase): + value = _restore_data_base(value, orig_value) + # orig_value can also already be a lambda with esp_range info, and only # a plain string is sent in orig_value if isinstance(orig_value, ESPHomeDataBase): @@ -157,83 +207,221 @@ def _expand_substitutions( return value +def _push_context( + local_vars: dict[str, Any], + parent_context: ContextVars, + errors: ErrList | None = None, +) -> tuple[ContextVars, dict[str, Any]]: + """Resolve local_vars and layer them on top of parent_context. + + Returns ``(child_context, resolved_vars)`` where *child_context* is a + new :class:`ChainMap` whose front map is *resolved_vars* (an + :class:`OrderedDict` of successfully-resolved variables). + + Variables may reference each other (e.g. ``b: ${a + 1}``). + Dependencies are resolved recursively via a *resolver* callback + that Jinja invokes on cache-miss. If vars are already in + dependency order, the loop iterates exactly once per variable. + + The ChainMap stack used during resolution is:: + + resolver_context → resolved_vars → parent maps … + ↑ ↑ + holds Resolver filled as vars + callback are resolved + """ + # Vars still waiting to be resolved — popped one-by-one by resolve(). + unresolved_vars = local_vars.copy() + # Accumulates resolved values in dependency order; becomes the front + # map of the returned child context so later lookups find them first. + resolved_vars = OrderedDict() + # The context callees will search: resolved_vars (initially empty) + # shadowing whatever the parent already provides. + context_vars = parent_context.new_child(resolved_vars) + + # Vars that failed resolution (missing or circular references). + # Maps name → (original_value, cause_error) for deferred warnings. + unresolvables: dict[str, tuple[Any, UndefinedError]] = {} + + # One extra child layer so the Resolver callback lives in its own + # map and doesn't pollute resolved_vars. + resolver_context = context_vars.new_child() + + def resolve(key: str) -> Any: + """Resolve a variable, recursively resolving any dependencies it references.""" + value = unresolved_vars.pop(key, Missing) + if value is Missing: + return Missing + try: + value = _try_substitute(value, resolver_context) + except UndefinedError as err: + unresolvables[key] = (value, err) + return Missing + resolved_vars[key] = value + return value + + # Set up the resolver for use during substitution + resolver_context[Resolver] = resolve + + # Resolve all variables, recursively resolving dependencies as needed. + # Each call to resolve() resolves that variable and any variables it depends on. + while unresolved_vars: + resolve(next(iter(unresolved_vars))) + + for name, (value, cause) in unresolvables.items(): + resolved_vars[name] = value + if errors is not None: + _handle_undefined( + err=UndefinedError( + f"Could not resolve substitution variable '{name}': {cause}" + ), + path=["substitutions", name], + value=value, + strict_undefined=False, + errors=errors, + ) + + return context_vars, resolved_vars + + +def push_context( + config_node: Any, + parent_context: ContextVars, + errors: ErrList | None = None, +) -> ContextVars: + """Returns the context vars this config node must be evaluated with.""" + if isinstance(config_node, ConfigContext): + return _push_context(config_node.vars, parent_context, errors)[0] + + # This node does not define any vars itself, so just return parent context + return parent_context + + def _substitute_item( - substitutions: dict, item: Any, - path: list[int | str], - jinja: Jinja, - ignore_missing: bool, + path: SubstitutionPath, + parent_context: ContextVars, + strict_undefined: bool, + errors: ErrList | None = None, ) -> Any | None: - if isinstance(item, ESPLiteralValue): - return None # do not substitute inside literal blocks - if isinstance(item, list): - for i, it in enumerate(item): - sub = _substitute_item(substitutions, it, path + [i], jinja, ignore_missing) - if sub is not None: - item[i] = sub - elif isinstance(item, dict): - replace_keys = [] - for k, v in item.items(): - if path or k != CONF_SUBSTITUTIONS: - sub = _substitute_item( - substitutions, k, path + [k], jinja, ignore_missing - ) + """Recursively substitute variables in a config item. + + Walks dicts, lists, strings, Lambdas, Extend, and Remove nodes, + replacing variable references with values from context_vars. + Mutates containers in-place; returns a replacement value for + strings/scalars, or None if the item was unchanged. + """ + + def _walk(item: Any, path: SubstitutionPath, parent_ctx: ContextVars) -> Any | None: + if isinstance(item, ESPLiteralValue): + return None # do not substitute inside literal blocks + + ctx = push_context(item, parent_ctx, errors) + + if isinstance(item, list): + for idx, it in enumerate(item): + sub = _walk(it, path + [idx], ctx) if sub is not None: - replace_keys.append((k, sub)) - sub = _substitute_item(substitutions, v, path + [k], jinja, ignore_missing) - if sub is not None: - item[k] = sub - for old, new in replace_keys: - if str(new) == str(old): - item[new] = item[old] - else: - item[new] = merge_config(item.get(old), item.get(new)) - del item[old] - elif isinstance(item, str): - sub = _expand_substitutions(substitutions, item, path, jinja, ignore_missing) - if isinstance(sub, JinjaStr) or sub != item: - return sub - elif isinstance(item, (core.Lambda, Extend, Remove)): - sub = _expand_substitutions( - substitutions, item.value, path, jinja, ignore_missing + item[idx] = sub + elif isinstance(item, dict): + replace_keys: list[tuple[str, Any]] = [] + for k, v in item.items(): + if path or k != CONF_SUBSTITUTIONS: + sub = _walk(k, path + [k], ctx) + if sub is not None: + replace_keys.append((k, sub)) + sub = _walk(v, path + [k], ctx) + if sub is not None: + item[k] = sub + for old, new in replace_keys: + if str(new) == str(old): + item[new] = item[old] + else: + item[new] = merge_config(item.get(new), item.get(old)) + del item[old] + elif isinstance(item, str): + sub = _expand_substitutions(item, path, ctx, strict_undefined, errors) + if not isinstance(sub, str) or sub != item: + return sub + elif isinstance(item, (core.Lambda, Extend, Remove)) and item.value: + sub = _expand_substitutions(item.value, path, ctx, strict_undefined, errors) + if sub != item.value: + item.value = sub + return None + + return _walk(item, path, parent_context) + + +def substitute_context_vars(node: Any, context_vars: dict[str, Any]) -> None: + """Eagerly substitute context vars into a config node in-place. + + Undefined variables are silently ignored — this is used before + the main substitution pass when not all variables are visible yet. + """ + _substitute_item(node, [], ContextVars(context_vars), strict_undefined=False) + + +def _warn_unresolved_variables(errors: ErrList) -> None: + """Log warnings for unresolved substitution variables, skipping password fields.""" + for err, path, expression in errors: + if "password" in path: + continue + location: str = "->".join(str(x) for x in path) + if isinstance(expression, ESPHomeDataBase) and expression.esp_range is not None: + location += f" in {str(expression.esp_range.start_mark)}" + + _LOGGER.warning( + "The string '%s' looks like an expression," + " but could not resolve all the variables: %s (see %s)", + expression, + err.message, + location, ) - if sub != item: - item.value = sub - return None def do_substitution_pass( - config: dict, command_line_substitutions: dict, ignore_missing: bool = False -) -> None: - if CONF_SUBSTITUTIONS not in config and not command_line_substitutions: - return + config: OrderedDict, command_line_substitutions: dict[str, Any] | None = None +) -> OrderedDict: + """Run the substitution pass over the entire config. - # Merge substitutions in config, overriding with substitutions coming from command line: + Extracts the ``substitutions:`` block, merges in any command-line + overrides, resolves inter-variable dependencies, then walks the + config tree replacing all ``$var`` / ``${expr}`` references. + Returns the (mutated) config dict with resolved substitutions + restored at the front. + """ + # Extract substitutions from config, overriding with substitutions coming from command line: # Use merge_dicts_ordered to preserve OrderedDict type for move_to_end() - substitutions = merge_dicts_ordered( - config.get(CONF_SUBSTITUTIONS, {}), command_line_substitutions or {} - ) - with cv.prepend_path("substitutions"): + substitutions = config.pop(CONF_SUBSTITUTIONS, {}) + with cv.prepend_path(CONF_SUBSTITUTIONS): if not isinstance(substitutions, dict): raise cv.Invalid( f"Substitutions must be a key to value mapping, got {type(substitutions)}" ) + substitutions = merge_dicts_ordered( + substitutions, command_line_substitutions or {} + ) - replace_keys = [] - for key, value in substitutions.items(): + replace_keys: list[tuple[str, str]] = [] + for key in substitutions: with cv.prepend_path(key): sub = validate_substitution_key(key) if sub != key: replace_keys.append((key, sub)) - substitutions[key] = value for old, new in replace_keys: substitutions[new] = substitutions[old] del substitutions[old] - config[CONF_SUBSTITUTIONS] = substitutions - # Move substitutions to the first place to replace substitutions in them correctly - config.move_to_end(CONF_SUBSTITUTIONS, False) + errors: ErrList = [] # Collect undefined errors during substitution + parent_context, substitutions = _push_context(substitutions, ContextVars(), errors) - # Create a Jinja environment that will consider substitutions in scope: - jinja = Jinja(substitutions) - _substitute_item(substitutions, config, [], jinja, ignore_missing) + _substitute_item(config, [], parent_context, False, errors) + + if errors: + _warn_unresolved_variables(errors) + + # Restore substitutions to front of dict for readability + if substitutions: + config[CONF_SUBSTITUTIONS] = substitutions + config.move_to_end(CONF_SUBSTITUTIONS, last=False) + return config diff --git a/esphome/components/substitutions/jinja.py b/esphome/components/substitutions/jinja.py index fb9f843da2..37e9fa4d2d 100644 --- a/esphome/components/substitutions/jinja.py +++ b/esphome/components/substitutions/jinja.py @@ -1,7 +1,6 @@ from ast import literal_eval -from collections.abc import Iterator +from collections.abc import Iterator, Mapping from itertools import chain, islice -import logging import math import re from types import GeneratorType @@ -9,16 +8,17 @@ from typing import Any import jinja2 as jinja from jinja2.nativetypes import NativeCodeGenerator, NativeTemplate - -from esphome.yaml_util import ESPLiteralValue +from jinja2.runtime import missing as Missing TemplateError = jinja.TemplateError TemplateSyntaxError = jinja.TemplateSyntaxError TemplateRuntimeError = jinja.TemplateRuntimeError UndefinedError = jinja.UndefinedError Undefined = jinja.Undefined +# Sentinel key for resolver callback in ContextVars. +# Dots are invalid in substitution names so this can never collide with user keys. +Resolver = ".resolver" -_LOGGER = logging.getLogger(__name__) DETECT_JINJA = r"(\$\{)" detect_jinja_re = re.compile( @@ -52,33 +52,6 @@ SAFE_GLOBALS = { } -class JinjaStr(str): - """ - Wraps a string containing an unresolved Jinja expression, - storing the variables visible to it when it failed to resolve. - For example, an expression inside a package, `${ A * B }` may fail - to resolve at package parsing time if `A` is a local package var - but `B` is a substitution defined in the root yaml. - Therefore, we store the value of `A` as an upvalue bound - to the original string so we may be able to resolve `${ A * B }` - later in the main substitutions pass. - """ - - Undefined = object() - - def __new__(cls, value: str, upvalues=None): - if isinstance(value, JinjaStr): - base = str(value) - merged = {**value.upvalues, **(upvalues or {})} - else: - base = value - merged = dict(upvalues or {}) - obj = super().__new__(cls, base) - obj.upvalues = merged - obj.result = JinjaStr.Undefined - return obj - - class JinjaError(Exception): def __init__(self, context_trace: dict, expr: str): self.context_trace = context_trace @@ -106,9 +79,13 @@ class JinjaError(Exception): class TrackerContext(jinja.runtime.Context): def resolve_or_missing(self, key): val = super().resolve_or_missing(key) - if isinstance(val, JinjaStr): - self.environment.context_trace[key] = val - val, _ = self.environment.expand(val) + if val is Missing: + # Variable not in the template context — check if a resolver callback + # was registered (by _push_context) to lazily resolve dependencies + # between substitution variables in the same block. + resolver = super().resolve_or_missing(Resolver) + if resolver is not Missing: + val = resolver(key) self.environment.context_trace[key] = val return val @@ -160,15 +137,13 @@ def _concat_nodes_override(values: Iterator[Any]) -> Any: class Jinja(jinja.Environment): - """ - Wraps a Jinja environment - """ + """Jinja environment configured for ESPHome substitution expressions.""" # jinja environment customization overrides code_generator_class = NativeCodeGenerator concat = staticmethod(_concat_nodes_override) - def __init__(self, context_vars: dict): + def __init__(self) -> None: super().__init__( trim_blocks=True, lstrip_blocks=True, @@ -183,49 +158,25 @@ class Jinja(jinja.Environment): self.context_class = TrackerContext self.add_extension("jinja2.ext.do") self.context_trace = {} - self.context_vars = {**context_vars} - for k, v in self.context_vars.items(): - if isinstance(v, ESPLiteralValue): - continue - if isinstance(v, str) and not isinstance(v, JinjaStr) and has_jinja(v): - self.context_vars[k] = JinjaStr(v, self.context_vars) - self.globals = { - **self.globals, - **self.context_vars, - **SAFE_GLOBALS, - } + self.globals = {**self.globals, **SAFE_GLOBALS} - def expand(self, content_str: str | JinjaStr) -> Any: + def expand(self, content_str: str, context_vars: Mapping[str, Any]) -> Any: """ Renders a string that may contain Jinja expressions or statements Returns the resulting value if all variables and expressions could be resolved. - Otherwise, it returns a tagged (JinjaStr) string that captures variables - in scope (upvalues), like a closure for later evaluation. """ result = None - override_vars = {} - if isinstance(content_str, JinjaStr): - if content_str.result is not JinjaStr.Undefined: - return content_str.result, None - # If `value` is already a JinjaStr, it means we are trying to evaluate it again - # in a parent pass. - # Hopefully, all required variables are visible now. - override_vars = content_str.upvalues old_trace = self.context_trace self.context_trace = {} try: template = self.from_string(content_str) - result = template.render(override_vars) + result = template.render(context_vars) if isinstance(result, Undefined): - print("" + result) # force a UndefinedError exception - except (TemplateSyntaxError, UndefinedError) as err: - # `content_str` contains a Jinja expression that refers to a variable that is undefined - # in this scope. Perhaps it refers to a root substitution that is not visible yet. - # Therefore, return `content_str` as a JinjaStr, which contains the variables - # that are actually visible to it at this point to postpone evaluation. - return JinjaStr(content_str, {**self.context_vars, **override_vars}), err + str(result) # force a UndefinedError exception + except UndefinedError as err: + raise err except JinjaError as err: err.context_trace = {**self.context_trace, **err.context_trace} err.eval_stack.append(content_str) @@ -242,10 +193,7 @@ class Jinja(jinja.Environment): finally: self.context_trace = old_trace - if isinstance(content_str, JinjaStr): - content_str.result = result - - return result, None + return result class JinjaTemplate(NativeTemplate): diff --git a/esphome/config.py b/esphome/config.py index 6f6ad4886b..b80aaf3700 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -12,7 +12,8 @@ from typing import Any import voluptuous as vol from esphome import core, loader, pins, yaml_util -from esphome.config_helpers import Extend, Remove, merge_config, merge_dicts_ordered +from esphome.components.substitutions import do_substitution_pass +from esphome.config_helpers import Extend, Remove, merge_config import esphome.config_validation as cv from esphome.const import ( CONF_ESPHOME, @@ -974,7 +975,7 @@ class PinUseValidationCheck(ConfigValidationStep): def validate_config( config: dict[str, Any], - command_line_substitutions: dict[str, Any], + command_line_substitutions: dict[str, Any] | None, skip_external_update: bool = False, ) -> Config: result = Config() @@ -994,21 +995,15 @@ def validate_config( result.add_error(err) return result - CORE.raw_config = config - # 1. Load substitutions if CONF_SUBSTITUTIONS in config or command_line_substitutions: - from esphome.components import substitutions - - result[CONF_SUBSTITUTIONS] = merge_dicts_ordered( - config.get(CONF_SUBSTITUTIONS) or {}, command_line_substitutions - ) result.add_output_path([CONF_SUBSTITUTIONS], CONF_SUBSTITUTIONS) - try: - substitutions.do_substitution_pass(config, command_line_substitutions) - except vol.Invalid as err: - result.add_error(err) - return result + try: + config = do_substitution_pass(config, command_line_substitutions) + except vol.Invalid as err: + CORE.raw_config = config + result.add_error(err) + return result # 1.1. Merge packages if CONF_PACKAGES in config: @@ -1016,6 +1011,9 @@ def validate_config( config = merge_packages(config) + # Remove substitutions from config during validation to prevent + # re-substitution. Re-added to result at the end of this function. + substitutions = config.pop(CONF_SUBSTITUTIONS, None) CORE.raw_config = config # 1.2. Resolve !extend and !remove and check for REPLACEME @@ -1089,6 +1087,10 @@ def validate_config( result.run_validation_steps() + if substitutions is not None: + result[CONF_SUBSTITUTIONS] = substitutions + result.move_to_end(CONF_SUBSTITUTIONS, last=False) + return result diff --git a/esphome/yaml_util.py b/esphome/yaml_util.py index d0eab4e44e..e001316a22 100644 --- a/esphome/yaml_util.py +++ b/esphome/yaml_util.py @@ -325,9 +325,7 @@ class ESPHomeLoaderMixin: return val @_add_data_ref - def construct_include( - self, node: yaml.Node - ) -> dict[str, Any] | OrderedDict[str, Any]: + def construct_include(self, node: yaml.Node) -> Any: from esphome.const import CONF_VARS def extract_file_vars(node): @@ -344,9 +342,7 @@ class ESPHomeLoaderMixin: file, vars = node.value, None result = self.yaml_loader(self._rel_path(file)) - if not vars: - vars = {} - return substitute_vars(result, vars) + return add_context(result, vars) @_add_data_ref def construct_include_dir_list(self, node: yaml.Node) -> list[dict[str, Any]]: @@ -495,39 +491,6 @@ def parse_yaml( ) -def substitute_vars(config, vars): - from esphome.components import substitutions - from esphome.const import CONF_SUBSTITUTIONS - - org_subs = None - result = config - if not isinstance(config, dict): - # when the included yaml contains a list or a scalar - # wrap it into an OrderedDict because do_substitution_pass expects it - result = OrderedDict([("yaml", config)]) - elif CONF_SUBSTITUTIONS in result: - org_subs = result.pop(CONF_SUBSTITUTIONS) - - defaults = {} - if CONF_DEFAULTS in result: - defaults = result.pop(CONF_DEFAULTS) - - result[CONF_SUBSTITUTIONS] = vars - for k, v in defaults.items(): - if k not in result[CONF_SUBSTITUTIONS]: - result[CONF_SUBSTITUTIONS][k] = v - - # Ignore missing vars that refer to the top level substitutions - substitutions.do_substitution_pass(result, None, ignore_missing=True) - result.pop(CONF_SUBSTITUTIONS) - - if not isinstance(config, dict): - result = result["yaml"] # unwrap the result - elif org_subs: - result[CONF_SUBSTITUTIONS] = org_subs - return result - - def _load_yaml_internal_with_type( loader_type: type[ESPHomeLoader] | type[ESPHomePurePythonLoader], fname: Path, diff --git a/tests/component_tests/packages/test_packages.py b/tests/component_tests/packages/test_packages.py index 22fb2c4e32..60dc0dccda 100644 --- a/tests/component_tests/packages/test_packages.py +++ b/tests/component_tests/packages/test_packages.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from esphome.components.packages import CONFIG_SCHEMA, do_packages_pass, merge_packages +from esphome.components.substitutions import do_substitution_pass import esphome.config as config_module from esphome.config import resolve_extend_remove from esphome.config_helpers import Extend, Remove @@ -71,6 +72,7 @@ def fixture_basic_esphome(): def packages_pass(config): """Wrapper around packages_pass that also resolves Extend and Remove.""" config = do_packages_pass(config) + config = do_substitution_pass(config) config = merge_packages(config) resolve_extend_remove(config) return config diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml index 9ed9b99c49..87f0e3fa21 100644 --- a/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.approved.yaml @@ -38,3 +38,20 @@ test_list: - '{ 79, 82 }' - a: 15 should be 15, overridden from command line b: 20 should stay as 20, not overridden + - aa: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + bb: + - 7 + - 8 + - 9 + - aa: + x: 1 + y: 3 + z: 4 + bb: + w: 5 diff --git a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml index 64701c03dd..d70372f280 100644 --- a/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml +++ b/tests/unit_tests/fixtures/substitutions/00-simple_var.input.yaml @@ -44,3 +44,13 @@ test_list: - '{ ${position.x}, ${position.y} }' - a: ${a} should be 15, overridden from command line b: ${b} should stay as 20, not overridden + + # Test merging lists when substituted keys resolve to an existing key + - ${ "aa" }: [1, 2, 3] + ${ "a" + "a" }: [4, 5, 6] + ${ "bb" }: [7, 8, 9] + + # Test merging dicts when substituted keys resolve to an existing key + - ${ "aa" }: {"x": 1, "y": 2} + ${ "a" + "a" }: {"y": 3, "z": 4} + ${ "bb" }: {"w": 5} diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml index 1a51fc44cf..b8c76fbf52 100644 --- a/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.approved.yaml @@ -9,6 +9,11 @@ substitutions: numberOne: 1 var1: 79 double_width: 14 + double_height: 16 + y: ${x} + x: ${y} + b: 79 + c: 80 test_list: - The area is 56 - 56 @@ -27,3 +32,4 @@ test_list: - chr(97) = a - len([1,2,3]) = 3 - width = 7, double_width = 14 + - a = ${a} diff --git a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml index 4612f581b5..9593867f49 100644 --- a/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml +++ b/tests/unit_tests/fixtures/substitutions/02-expressions.input.yaml @@ -1,4 +1,7 @@ substitutions: + y: ${x} # Circular reference, expect to pass unresolved. + x: ${y} # Circular reference, expect to pass unresolved. + double_height: ${height * 2} width: 7 height: 8 enabled: true @@ -9,6 +12,8 @@ substitutions: numberOne: 1 var1: 79 double_width: ${width * 2} + c: ${b+1} + b: ${undefined_variable | default(79) } test_list: - "The area is ${width * height}" @@ -25,3 +30,4 @@ test_list: - chr(97) = ${ chr(97) } - len([1,2,3]) = ${ len([1,2,3]) } - width = ${width}, double_width = ${double_width} + - a = ${a} diff --git a/tests/unit_tests/fixtures/substitutions/07-package_merging.approved.yaml b/tests/unit_tests/fixtures/substitutions/07-package_merging.approved.yaml new file mode 100644 index 0000000000..867889b7bc --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/07-package_merging.approved.yaml @@ -0,0 +1,46 @@ +fancy_component: &id001 + - id: component9 + value: 9 +some_component: + - id: component1 + value: 1 + - id: component2 + value: 2 + - id: component3 + value: 3 + - id: component4 + value: 4 + - id: component5 + value: 79 + power: 200 + - id: component6 + value: 6 + - id: component7 + value: 7 +switch: &id002 + - platform: gpio + id: switch1 + pin: 12 + - platform: gpio + id: switch2 + pin: 13 +display: + - platform: ili9xxx + dimensions: + width: 100 + height: 480 +substitutions: + extended_component: component5 + package_options: + alternative_package: + alternative_component: + - id: component8 + value: 8 + fancy_package: + substitutions: + fancy_subst: 42 + fancy_component: *id001 + pin: 12 + some_switches: *id002 + package_selection: fancy_package + fancy_subst: 42 diff --git a/tests/unit_tests/fixtures/substitutions/07-package_merging.input.yaml b/tests/unit_tests/fixtures/substitutions/07-package_merging.input.yaml new file mode 100644 index 0000000000..cc7b841aba --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/07-package_merging.input.yaml @@ -0,0 +1,63 @@ +substitutions: + package_options: + alternative_package: + alternative_component: + - id: component8 + value: 8 + fancy_package: + substitutions: + fancy_subst: 42 + fancy_component: + - id: component9 + value: 9 + + pin: 12 + some_switches: + - platform: gpio + id: switch1 + pin: ${pin} + - platform: gpio + id: switch2 + pin: ${pin+1} + + package_selection: fancy_package + +packages: + - ${ package_options[package_selection] } + - some_component: + - id: component1 + value: 1 + - some_component: + - id: component2 + value: 2 + - switch: ${ some_switches } + - packages: + package_with_defaults: !include + file: display.yaml + vars: + native_width: 100 + high_dpi: false + my_package: + packages: + - packages: + special_package: + substitutions: + extended_component: component5 + some_component: + - id: component3 + value: 3 + some_component: + - id: component4 + value: 4 + - id: !extend ${ extended_component } + power: 200 + value: 79 + some_component: + - id: component5 + value: 5 + +some_component: + - id: component6 + value: 6 + - id: component7 + value: 7 diff --git a/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.approved.yaml b/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.approved.yaml new file mode 100644 index 0000000000..4abaf4471d --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.approved.yaml @@ -0,0 +1,5 @@ +values: + - var1: $var1 + - a: 10 + - b: B-default + - c: The value of C is 79 diff --git a/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.input.yaml b/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.input.yaml new file mode 100644 index 0000000000..91eb0e9a3f --- /dev/null +++ b/tests/unit_tests/fixtures/substitutions/09-include_vars_without_substs.input.yaml @@ -0,0 +1,7 @@ +# Test that include_vars with vars works even when there are no substitutions key defined. +packages: + - !include + file: inc1.yaml + vars: + a: 10 + c: 79 diff --git a/tests/unit_tests/test_substitutions.py b/tests/unit_tests/test_substitutions.py index 1d8cb7631d..db46a27dfb 100644 --- a/tests/unit_tests/test_substitutions.py +++ b/tests/unit_tests/test_substitutions.py @@ -10,9 +10,10 @@ from esphome import config as config_module, yaml_util from esphome.components import substitutions from esphome.components.packages import do_packages_pass, merge_packages from esphome.config import resolve_extend_remove -from esphome.config_helpers import merge_config +from esphome.config_helpers import Extend, merge_config +import esphome.config_validation as cv from esphome.const import CONF_SUBSTITUTIONS -from esphome.core import CORE +from esphome.core import CORE, Lambda from esphome.util import OrderedDict _LOGGER = logging.getLogger(__name__) @@ -144,7 +145,7 @@ def test_substitutions_fixtures( config = do_packages_pass(config) - substitutions.do_substitution_pass(config, command_line_substitutions) + config = substitutions.do_substitution_pass(config, command_line_substitutions) config = merge_packages(config) @@ -206,7 +207,7 @@ def test_substitutions_with_command_line_maintains_ordered_dict() -> None: command_line_subs = {"var2": "override", "var3": "new_value"} # Call do_substitution_pass with command line substitutions - substitutions.do_substitution_pass(config, command_line_subs) + config = substitutions.do_substitution_pass(config, command_line_subs) # Verify that config is still an OrderedDict assert isinstance(config, OrderedDict), "Config should remain an OrderedDict" @@ -234,7 +235,7 @@ def test_substitutions_without_command_line_maintains_ordered_dict() -> None: config["other_key"] = "other_value" # Call without command line substitutions - substitutions.do_substitution_pass(config, None) + config = substitutions.do_substitution_pass(config, None) # Verify that config is still an OrderedDict assert isinstance(config, OrderedDict), "Config should remain an OrderedDict" @@ -268,7 +269,7 @@ def test_substitutions_after_merge_config_maintains_ordered_dict() -> None: ) # Now try to run substitution pass on the merged config - substitutions.do_substitution_pass(merged_config, None) + merged_config = substitutions.do_substitution_pass(merged_config, None) # Should not raise AttributeError assert isinstance(merged_config, OrderedDict), ( @@ -279,7 +280,7 @@ def test_substitutions_after_merge_config_maintains_ordered_dict() -> None: def test_validate_config_with_command_line_substitutions_maintains_ordered_dict( - tmp_path, + tmp_path: Path, ) -> None: """Test that validate_config preserves OrderedDict when merging command-line substitutions. @@ -288,7 +289,7 @@ def test_validate_config_with_command_line_substitutions_maintains_ordered_dict( """ # Create a minimal valid config test_config = OrderedDict() - test_config["esphome"] = {"name": "test_device", "platform": "ESP32"} + test_config["esphome"] = {"name": "test_device"} test_config[CONF_SUBSTITUTIONS] = OrderedDict({"var1": "value1", "var2": "value2"}) test_config["esp32"] = {"board": "esp32dev"} @@ -314,17 +315,11 @@ def test_validate_config_with_command_line_substitutions_maintains_ordered_dict( assert result[CONF_SUBSTITUTIONS]["var3"] == "new_value" -def test_validate_config_without_command_line_substitutions_maintains_ordered_dict( - tmp_path, -) -> None: - """Test that validate_config preserves OrderedDict without command-line substitutions. - - This tests the code path in config.py where result[CONF_SUBSTITUTIONS] is set - using merge_dicts_ordered() when command_line_substitutions is None. - """ +def _get_test_minimal_valid_config(tmp_path: Path) -> OrderedDict: + """Helper to create a minimal valid config for testing.""" # Create a minimal valid config test_config = OrderedDict() - test_config["esphome"] = {"name": "test_device", "platform": "ESP32"} + test_config["esphome"] = {"name": "test_device"} test_config[CONF_SUBSTITUTIONS] = OrderedDict({"var1": "value1", "var2": "value2"}) test_config["esp32"] = {"board": "esp32dev"} @@ -332,6 +327,19 @@ def test_validate_config_without_command_line_substitutions_maintains_ordered_di test_yaml = tmp_path / "test.yaml" test_yaml.write_text("# test config") CORE.config_path = test_yaml + return test_config + + +def test_validate_config_without_command_line_substitutions_maintains_ordered_dict( + tmp_path: Path, +) -> None: + """Test that validate_config preserves OrderedDict without command-line substitutions. + + This tests the code path in config.py where result[CONF_SUBSTITUTIONS] is set + using merge_dicts_ordered() when command_line_substitutions is None. + """ + + test_config = _get_test_minimal_valid_config(tmp_path) # Call validate_config without command line substitutions result = config_module.validate_config(test_config, None) @@ -384,3 +392,205 @@ def test_merge_config_preserves_ordered_dict() -> None: assert not isinstance(result, OrderedDict), ( "dict + dict should not return OrderedDict" ) + + +def test_substitution_pass_error_gets_captured( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """vol.Invalid from do_substitution_pass is captured by validate_config.""" + + # Patch the target: in config_module.do_substitution_pass (NOT where it's defined) + def fake_do_substitution_pass(*args, **kwargs): + raise cv.Invalid("Error in do_substitutions_pass!!") + + monkeypatch.setattr( + config_module, "do_substitution_pass", fake_do_substitution_pass + ) + + # Prepare minimal config + no CLI substitutions + config = _get_test_minimal_valid_config(tmp_path) + + # Call the function under test + result = config_module.validate_config(config, None) + + # Now assert that add_error was called with the vol.Invalid + + assert "Error in do_substitutions_pass!!" in str(result.get_error_for_path([])) + + +@pytest.mark.parametrize( + "value", ["", " ", "1foo", "9VAR", "0abc", "$1foo", "$9VAR", "$0abc"] +) +def test_validate_substitution_key_empty_raises(value: str) -> None: + """Empty (or all-whitespace) substitution keys are rejected.""" + with pytest.raises(cv.Invalid): + substitutions.validate_substitution_key(value) + + +@pytest.mark.parametrize( + "input_value, expected_output", + [ + ("$FOO_bar9", "FOO_bar9"), # Valid key with leading '$' + ("Foo_bar9", "Foo_bar9"), # Normal valid key + ], +) +def test_validate_substitution_key_valid( + input_value: str, expected_output: str +) -> None: + """Valid substitution keys are accepted with optional leading '$'.""" + result = substitutions.validate_substitution_key(input_value) + assert result == expected_output + + +def test_circular_dependency_warnings( + caplog: pytest.LogCaptureFixture, +) -> None: + """Circular substitution references produce warnings naming the cause.""" + config = OrderedDict( + { + CONF_SUBSTITUTIONS: OrderedDict({"x": "${y}", "y": "${x}"}), + "key": "value", + } + ) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert "Could not resolve substitution variable 'x'" in caplog.text + assert "'y' is undefined" in caplog.text + assert "Could not resolve substitution variable 'y'" in caplog.text + assert "'x' is undefined" in caplog.text + # Verify path includes location + assert "substitutions->x" in caplog.text + assert "substitutions->y" in caplog.text + + +def test_missing_dependency_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """A substitution referencing an undefined variable warns with the cause.""" + config = OrderedDict( + { + CONF_SUBSTITUTIONS: OrderedDict({"a": "${missing}"}), + "key": "value", + } + ) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert "Could not resolve substitution variable 'a'" in caplog.text + assert "'missing' is undefined" in caplog.text + assert "substitutions->a" in caplog.text + + +def test_undefined_variable_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """A reference to an undefined variable in config values produces a warning.""" + config = OrderedDict( + { + "key": "${undefined_var}", + } + ) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert "'undefined_var' is undefined" in caplog.text + + +def test_password_field_warnings_suppressed( + caplog: pytest.LogCaptureFixture, +) -> None: + """Undefined variables in password fields should not produce warnings.""" + config = OrderedDict( + { + "password": "${undefined_var}", + } + ) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert caplog.text == "" + + +def test_config_context_unresolvable_warns( + caplog: pytest.LogCaptureFixture, +) -> None: + """Unresolvable vars in a ConfigContext produce warnings via push_context.""" + inner = OrderedDict({"key": "${a}"}) + yaml_util.add_context(inner, {"a": "${undefined}"}) + config = OrderedDict({"items": [inner]}) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert "Could not resolve substitution variable 'a'" in caplog.text + assert "'undefined' is undefined" in caplog.text + + +def test_non_string_substitution_value_warning( + caplog: pytest.LogCaptureFixture, +) -> None: + """Undefined vars in non-string contexts (e.g. dict keys) produce warnings.""" + config = OrderedDict( + { + "items": {"${undefined_key}": "value"}, + } + ) + with caplog.at_level(logging.WARNING): + substitutions.do_substitution_pass(config) + + assert "'undefined_key' is undefined" in caplog.text + + +def test_lambda_substitution() -> None: + """Substitution inside a Lambda value should be expanded.""" + lam = Lambda("return ${var};") + config = OrderedDict( + { + CONF_SUBSTITUTIONS: OrderedDict({"var": "42"}), + "lambda": lam, + } + ) + substitutions.do_substitution_pass(config) + assert lam.value == "return 42;" + + +def test_lambda_no_substitution_unchanged() -> None: + """A Lambda with no variable references should not be mutated.""" + lam = Lambda("return 1;") + original_value = lam.value + config = OrderedDict( + { + CONF_SUBSTITUTIONS: OrderedDict({"var": "42"}), + "lambda": lam, + } + ) + substitutions.do_substitution_pass(config) + assert lam.value is original_value + + +def test_extend_substitution() -> None: + """Substitution inside an Extend value should be expanded.""" + ext = Extend("${component_id}") + config = OrderedDict( + { + CONF_SUBSTITUTIONS: OrderedDict({"component_id": "my_sensor"}), + "sensor": ext, + } + ) + substitutions.do_substitution_pass(config) + assert ext.value == "my_sensor" + + +def test_do_substitution_pass_substitutions_must_be_mapping_from_config() -> None: + """Non-mapping substitutions raises cv.Invalid.""" + config = OrderedDict( + { + CONF_SUBSTITUTIONS: ["not", "a", "mapping"], + "other": "value", + } + ) + + with pytest.raises( + cv.Invalid, match="Substitutions must be a key to value mapping" + ): + substitutions.do_substitution_pass(config) diff --git a/tests/unit_tests/test_yaml_util.py b/tests/unit_tests/test_yaml_util.py index adb7658bfd..35a4bc3707 100644 --- a/tests/unit_tests/test_yaml_util.py +++ b/tests/unit_tests/test_yaml_util.py @@ -25,7 +25,7 @@ def test_include_with_vars(fixture_path: Path) -> None: yaml_file = fixture_path / "yaml_util" / "includetest.yaml" actual = yaml_util.load_yaml(yaml_file) - substitutions.do_substitution_pass(actual, None) + actual = substitutions.do_substitution_pass(actual, None) assert actual["esphome"]["name"] == "original" assert actual["esphome"]["libraries"][0] == "Wire" assert actual["esp8266"]["board"] == "nodemcu"