[core] Fix WiFi connection in safe mode (#16269)

Co-authored-by: J. Nick Koston <nick@home-assistant.io>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Mat931
2026-05-06 14:56:33 +00:00
committed by GitHub
parent 6e1a59da3e
commit 90693fb39a
7 changed files with 299 additions and 82 deletions
+11 -16
View File
@@ -568,14 +568,9 @@ async def _add_controller_registry_define() -> None:
@coroutine_with_priority(CoroPriority.FINAL)
async def _add_looping_components() -> None:
# Emit a constexpr that computes the looping component count at C++ compile time
# and pre-init the FixedVector with the exact capacity. Uses std::is_same_v to
# detect loop() overrides. The constexpr goes in main.cpp's global section where
# all component types are in scope. calculate_looping_components_() then skips
# the counting pass and only does the two population passes.
# Emit ESPHOME_LOOPING_COMPONENT_COUNT. Sizing of looping_components_
# happens in core to_code() so it lands before safe_mode's early return.
entries = CORE.data.get("looping_component_entries", [])
if not entries:
return
# Build constexpr sum for the exact count, deduplicating by type
# Uses HasLoopOverride<T> which handles ambiguous &T::loop from multiple inheritance
@@ -583,7 +578,7 @@ async def _add_looping_components() -> None:
terms = [
f"({count} * HasLoopOverride<{cpp_type}>::value)"
for cpp_type, count in type_counts.items()
]
] or ["0"]
constexpr_expr = " + \\\n ".join(terms)
cg.add_global(
cg.RawStatement(
@@ -592,14 +587,6 @@ async def _add_looping_components() -> None:
)
)
# Pre-init FixedVector with exact capacity so calculate_looping_components_()
# can skip the counting pass
cg.add(
cg.RawExpression(
"App.looping_components_.init(ESPHOME_LOOPING_COMPONENT_COUNT)"
)
)
@coroutine_with_priority(CoroPriority.CORE)
async def to_code(config: ConfigType) -> None:
@@ -642,6 +629,14 @@ async def to_code(config: ConfigType) -> None:
# Define component count for static allocation
cg.add_define("ESPHOME_COMPONENT_COUNT", len(CORE.component_ids))
# Pre-init FixedVector with exact capacity so calculate_looping_components_()
# can skip the counting pass
cg.add(
cg.RawExpression(
"App.looping_components_.init(ESPHOME_LOOPING_COMPONENT_COUNT)"
)
)
CORE.add_job(_add_platform_defines)
CORE.add_job(_add_controller_registry_define)
CORE.add_job(_add_looping_components)
+54 -57
View File
@@ -501,14 +501,15 @@ async def _read_stream_lines(
@asynccontextmanager
async def run_binary_and_wait_for_port(
async def run_binary(
binary_path: Path,
host: str,
port: int,
timeout: float = PORT_WAIT_TIMEOUT,
line_callback: Callable[[str], None] | None = None,
) -> AsyncGenerator[None]:
"""Run a binary, wait for it to open a port, and clean up on exit."""
) -> AsyncGenerator[tuple[asyncio.subprocess.Process, list[str]]]:
"""Run a binary under a PTY, capture log output, and clean up on exit.
Yields the running ``Process`` and a live list of captured log lines.
No port wait -- callers that need that should use
``run_binary_and_wait_for_port``."""
# Create a pseudo-terminal to make the binary think it's running interactively
# This is needed because the ESPHome host logger checks isatty()
controller_fd, device_fd = pty.openpty()
@@ -535,7 +536,6 @@ async def run_binary_and_wait_for_port(
controller_transport, _ = await loop.connect_read_pipe(
lambda: controller_protocol, os.fdopen(controller_fd, "rb", 0)
)
output_reader = controller_reader
if process.returncode is not None:
raise RuntimeError(
@@ -543,27 +543,59 @@ async def run_binary_and_wait_for_port(
"Ensure the binary is valid and can run successfully."
)
# Wait for the API server to start listening
loop = asyncio.get_running_loop()
start_time = loop.time()
# Start collecting output
stdout_lines: list[str] = []
output_tasks: list[asyncio.Task] = []
output_task = asyncio.create_task(
_read_stream_lines(controller_reader, stdout_lines, sys.stdout, line_callback)
)
try:
# Read from output stream
output_tasks = [
asyncio.create_task(
_read_stream_lines(
output_reader, stdout_lines, sys.stdout, line_callback
)
)
]
# Small yield to ensure the process has a chance to start
await asyncio.sleep(0)
yield process, stdout_lines
finally:
output_task.cancel()
result = await asyncio.gather(output_task, return_exceptions=True)
if isinstance(result[0], Exception) and not isinstance(
result[0], asyncio.CancelledError
):
print(f"Error reading from PTY: {result[0]}", file=sys.stderr)
# Close the PTY transport (Unix only)
if controller_transport is not None:
controller_transport.close()
# Cleanup: terminate the process gracefully
if process.returncode is None:
# Send SIGINT (Ctrl+C) for graceful shutdown
process.send_signal(signal.SIGINT)
try:
await asyncio.wait_for(process.wait(), timeout=SIGINT_TIMEOUT)
except TimeoutError:
# If SIGINT didn't work, try SIGTERM
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=SIGTERM_TIMEOUT)
except TimeoutError:
# Last resort: SIGKILL
process.kill()
await process.wait()
@asynccontextmanager
async def run_binary_and_wait_for_port(
binary_path: Path,
host: str,
port: int,
timeout: float = PORT_WAIT_TIMEOUT,
line_callback: Callable[[str], None] | None = None,
) -> AsyncGenerator[None]:
"""Run a binary, wait for it to open a port, and clean up on exit."""
async with run_binary(binary_path, line_callback=line_callback) as (
process,
stdout_lines,
):
loop = asyncio.get_running_loop()
start_time = loop.time()
while loop.time() - start_time < timeout:
try:
# Try to connect to the port
@@ -593,41 +625,6 @@ async def run_binary_and_wait_for_port(
raise TimeoutError(error_msg)
finally:
# Cancel output collection tasks
for task in output_tasks:
task.cancel()
# Wait for tasks to complete and check for exceptions
results = await asyncio.gather(*output_tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception) and not isinstance(
result, asyncio.CancelledError
):
print(
f"Error reading from PTY: {result}",
file=sys.stderr,
)
# Close the PTY transport (Unix only)
if controller_transport is not None:
controller_transport.close()
# Cleanup: terminate the process gracefully
if process.returncode is None:
# Send SIGINT (Ctrl+C) for graceful shutdown
process.send_signal(signal.SIGINT)
try:
await asyncio.wait_for(process.wait(), timeout=SIGINT_TIMEOUT)
except TimeoutError:
# If SIGINT didn't work, try SIGTERM
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=SIGTERM_TIMEOUT)
except TimeoutError:
# Last resort: SIGKILL
process.kill()
await process.wait()
@asynccontextmanager
async def run_compiled_context(
@@ -0,0 +1,25 @@
esphome:
name: safe-mode-loop-runs
host:
logger:
safe_mode:
num_attempts: 10
on_safe_mode:
- lambda: |-
// Spawn a detached thread that logs a unique marker. The
// non-main-thread log goes through the task log buffer, which
// is only drained by Logger::loop(). If looping components
// weren't initialized (the bug fixed in #16269), the buffer is
// never read and the marker never reaches the console.
struct MarkerThread {
static void *thread_func(void *) {
ESP_LOGI("safe_mode_test", "looping component ran in safe mode");
return nullptr;
}
};
pthread_t t;
pthread_create(&t, nullptr, MarkerThread::thread_func, nullptr);
pthread_detach(t);
+39
View File
@@ -0,0 +1,39 @@
"""Helpers for manipulating the host platform's preferences file.
ESPHome's host platform stores preferences in
``~/.esphome/prefs/<app_name>.prefs`` using a simple binary layout that
mirrors ``HostPreferences::sync()``:
``[uint32_t key][uint8_t len][uint8_t data[len]]`` per entry.
Tests use these helpers to pre-populate state the binary will see at
boot (e.g. forcing safe mode) or to clear stale state between runs.
"""
from __future__ import annotations
from pathlib import Path
import struct
def host_prefs_path(device_name: str) -> Path:
    """Locate the preferences file of a host-platform device.

    The result is ``<home>/.esphome/prefs/<device_name>.prefs``; the file is
    not required to exist.
    """
    prefs_dir = Path.home().joinpath(".esphome", "prefs")
    return prefs_dir / f"{device_name}.prefs"
def clear_host_prefs(device_name: str) -> None:
    """Remove a host-platform device's prefs file.

    A file that is already absent is silently ignored.
    """
    prefs_file = host_prefs_path(device_name)
    prefs_file.unlink(missing_ok=True)
def write_host_pref(device_name: str, key: int, data: bytes) -> Path:
    """Overwrite a device's prefs file with exactly one preference entry.

    The entry is serialized as a little-endian ``uint32`` key, a one-byte
    length, and then the raw ``data`` bytes. Missing parent directories are
    created.

    Raises:
        ValueError: if ``data`` is longer than 255 bytes and so cannot be
            described by the one-byte length field.

    Returns the path that was written.
    """
    size = len(data)
    if size > 255:
        raise ValueError(f"Preference data too long: {size} bytes (max 255)")

    target = host_prefs_path(device_name)
    target.parent.mkdir(parents=True, exist_ok=True)

    header = struct.pack("<IB", key, size)
    target.write_bytes(header + data)
    return target
@@ -0,0 +1,94 @@
"""Regression test for safe_mode + looping_components init ordering.
Reproduces the bug fixed in https://github.com/esphome/esphome/pull/16269:
``App.looping_components_.init(...)`` was emitted at ``CoroPriority.FINAL``,
which placed it *after* the ``safe_mode`` early-return in ``setup_app()``.
When safe mode was entered, the ``FixedVector`` backing the looping-component
list was never sized, ``looping_components_active_end_`` stayed at 0, and
``loop()`` iterated zero components -- so any looping component above
``CoroPriority.APPLICATION`` (e.g. wifi, logger) never ran.
The test forces safe mode by writing ``ENTER_SAFE_MODE_MAGIC`` to the host
preferences file before booting, then asserts that ``Logger::loop()`` runs
by logging from a non-main thread. Non-main-thread logs are buffered in
``TaskLogBuffer`` and only emitted to the console when ``Logger::loop()``
drains the buffer. Without the fix, the marker stays in the buffer
forever; with the fix, it reaches the console.
The API server (``CoroPriority.WEB``, 40) is registered below safe_mode
(``CoroPriority.APPLICATION``, 50), so it's never set up when safe mode
is active and ``run_compiled`` would hang waiting for the API port.
This test uses ``run_binary`` directly to skip the port wait.
"""
from __future__ import annotations
import asyncio
import re
import struct
import pytest
from .conftest import run_binary
from .host_prefs import clear_host_prefs, write_host_pref
from .types import CompileFunction, ConfigWriter
# Must match esphome::safe_mode::RTC_KEY in safe_mode.h
SAFE_MODE_RTC_KEY = 233825507
# Must match esphome::safe_mode::SafeModeComponent::ENTER_SAFE_MODE_MAGIC
ENTER_SAFE_MODE_MAGIC = 0x5AFE5AFE
DEVICE_NAME = "safe-mode-loop-runs"
THREAD_LOG_MARKER = "looping component ran in safe mode"
@pytest.mark.asyncio
async def test_safe_mode_loop_runs(
    yaml_config: str,
    write_yaml_config: ConfigWriter,
    compile_esphome: CompileFunction,
) -> None:
    """Looping components must keep running while safe mode is active.

    The binary is booted with the safe-mode magic already stored in its
    prefs. The test then waits for two log lines: the safe-mode banner, and a
    marker logged from a non-main thread, which only reaches the console if
    ``Logger::loop()`` runs.
    """
    binary_path = await compile_esphome(await write_yaml_config(yaml_config))

    # Seed the prefs only once the build has succeeded, so that the upcoming
    # launch of the binary goes straight into safe mode.
    magic_payload = struct.pack("<I", ENTER_SAFE_MODE_MAGIC)
    write_host_pref(DEVICE_NAME, SAFE_MODE_RTC_KEY, magic_payload)

    try:
        event_loop = asyncio.get_running_loop()
        safe_mode_active = event_loop.create_future()
        thread_log_seen = event_loop.create_future()

        # Each log pattern is paired with the future it resolves.
        watchers = (
            (re.compile(r"SAFE MODE IS ACTIVE"), safe_mode_active),
            (re.compile(re.escape(THREAD_LOG_MARKER)), thread_log_seen),
        )

        def on_log(line: str) -> None:
            for pattern, seen in watchers:
                if seen.done() or not pattern.search(line):
                    continue
                seen.set_result(True)

        async with run_binary(binary_path, line_callback=on_log):
            # Step 1: confirm the binary really entered safe mode.
            try:
                await asyncio.wait_for(safe_mode_active, timeout=15.0)
            except TimeoutError:
                pytest.fail(
                    "Did not observe 'SAFE MODE IS ACTIVE' -- safe mode "
                    "didn't trigger, so this test isn't exercising the bug."
                )

            # Step 2: the thread-logged marker must be flushed to the console.
            try:
                await asyncio.wait_for(thread_log_seen, timeout=10.0)
            except TimeoutError:
                pytest.fail(
                    f"Did not observe thread-logged marker {THREAD_LOG_MARKER!r} "
                    "within timeout. Logger::loop() never drained the task "
                    "log buffer, meaning App.looping_components_ was never "
                    "sized -- this is the regression #16269 fixed."
                )
    finally:
        # Always drop the forced safe-mode state, whatever the outcome.
        clear_host_prefs(DEVICE_NAME)
+6 -9
View File
@@ -9,7 +9,6 @@ Tests that:
from __future__ import annotations
import asyncio
from pathlib import Path
import socket
from typing import Any
@@ -17,9 +16,12 @@ from aioesphomeapi import TextInfo, TextState
import pytest
from .conftest import run_binary_and_wait_for_port, wait_and_connect_api_client
from .host_prefs import clear_host_prefs
from .state_utils import InitialStateHelper, require_entity
from .types import CompileFunction, ConfigWriter
DEVICE_NAME = "host-template-text-save-test"
@pytest.mark.asyncio
async def test_template_text_save(
@@ -32,11 +34,7 @@ async def test_template_text_save(
port, port_socket = reserved_tcp_port
# Clean up any stale preference file from previous runs
prefs_file = (
Path.home() / ".esphome" / "prefs" / "host-template-text-save-test.prefs"
)
if prefs_file.exists():
prefs_file.unlink()
clear_host_prefs(DEVICE_NAME)
# Write and compile once
config_path = await write_yaml_config(yaml_config)
@@ -59,7 +57,7 @@ async def test_template_text_save(
wait_and_connect_api_client(port=port) as client,
):
device_info = await client.device_info()
assert device_info.name == "host-template-text-save-test"
assert device_info.name == DEVICE_NAME
entities, _ = await client.list_entities_services()
text_entity = require_entity(
@@ -127,5 +125,4 @@ async def test_template_text_save(
)
# Clean up preference file
if prefs_file.exists():
prefs_file.unlink()
clear_host_prefs(DEVICE_NAME)
+70
View File
@@ -10,6 +10,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from esphome import config_validation as cv, core
from esphome.components.safe_mode import to_code as safe_mode_to_code
from esphome.const import (
CONF_AREA,
CONF_AREAS,
@@ -312,6 +313,75 @@ def test_add_platform_defines_priority() -> None:
)
def test_to_code_priority_above_safe_mode() -> None:
    """Core ``to_code`` must be scheduled at a higher priority than safe_mode.

    Regression test for https://github.com/esphome/esphome/issues/16262.

    safe_mode adds an ``if (should_enter_safe_mode(...)) return;`` line to
    main() at APPLICATION priority. Core ``to_code`` emits the
    ``App.looping_components_.init(...)`` call, so it needs a priority above
    APPLICATION for that call to land before the early return. Otherwise the
    FixedVector is never sized in safe mode and loop() never runs (Wi-Fi
    never connects).
    """
    core_priority = config.to_code.priority
    safe_mode_priority = safe_mode_to_code.priority

    assert core_priority > safe_mode_priority, (
        f"core to_code priority ({core_priority}) must be greater than "
        f"safe_mode to_code priority ({safe_mode_priority}) so that "
        "App.looping_components_.init() is emitted before safe_mode's early return"
    )
@pytest.mark.asyncio
async def test_add_looping_components_handles_empty_entries() -> None:
    """Test that _add_looping_components emits a valid constexpr when there are
    no looping component entries.

    With zero entries the generated constexpr must still be syntactically valid
    C++ (`= 0;`), not an empty expression (`= ;`). This guards the empty-list
    case that would otherwise produce uncompilable main.cpp output.
    """
    CORE.data["looping_component_entries"] = []

    await config._add_looping_components()

    constexpr_lines = [
        str(s)
        for s in CORE.global_statements
        if "ESPHOME_LOOPING_COMPONENT_COUNT" in str(s)
    ]
    assert len(constexpr_lines) == 1
    text = constexpr_lines[0]
    assert "static constexpr size_t ESPHOME_LOOPING_COMPONENT_COUNT" in text

    rhs = text.split("=", 1)[1]
    # Check for an empty right-hand side first: an empty rhs also lacks "0",
    # so checking for the literal first would hide this descriptive failure
    # message behind a bare assertion error.
    assert rhs.strip().rstrip(";").strip(), (
        f"constexpr right-hand side must not be empty, got: {text!r}"
    )
    # The right-hand side must contain a literal `0`.
    assert "0" in rhs
@pytest.mark.asyncio
async def test_add_looping_components_with_entries() -> None:
"""Test that _add_looping_components builds a HasLoopOverride sum from entries."""
CORE.data["looping_component_entries"] = [
"esphome::wifi::WiFiComponent",
"esphome::logger::Logger",
"esphome::wifi::WiFiComponent",
]
await config._add_looping_components()
constexpr_lines = [
str(s)
for s in CORE.global_statements
if "ESPHOME_LOOPING_COMPONENT_COUNT" in str(s)
]
assert len(constexpr_lines) == 1
text = constexpr_lines[0]
# Deduplicated by type, with per-type counts as multiplier.
assert "(2 * HasLoopOverride<esphome::wifi::WiFiComponent>::value)" in text
assert "(1 * HasLoopOverride<esphome::logger::Logger>::value)" in text
def test_valid_include_with_angle_brackets() -> None:
"""Test valid_include accepts angle bracket includes."""
assert valid_include("<ArduinoJson.h>") == "<ArduinoJson.h>"