mirror of
https://github.com/esphome/esphome.git
synced 2026-05-10 05:37:55 +08:00
[core] Fix WiFi connection in safe mode (#16269)
Co-authored-by: J. Nick Koston <nick@home-assistant.io> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
+11
-16
@@ -568,14 +568,9 @@ async def _add_controller_registry_define() -> None:
|
||||
|
||||
@coroutine_with_priority(CoroPriority.FINAL)
|
||||
async def _add_looping_components() -> None:
|
||||
# Emit a constexpr that computes the looping component count at C++ compile time
|
||||
# and pre-init the FixedVector with the exact capacity. Uses std::is_same_v to
|
||||
# detect loop() overrides. The constexpr goes in main.cpp's global section where
|
||||
# all component types are in scope. calculate_looping_components_() then skips
|
||||
# the counting pass and only does the two population passes.
|
||||
# Emit ESPHOME_LOOPING_COMPONENT_COUNT. Sizing of looping_components_
|
||||
# happens in core to_code() so it lands before safe_mode's early return.
|
||||
entries = CORE.data.get("looping_component_entries", [])
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# Build constexpr sum for the exact count, deduplicating by type
|
||||
# Uses HasLoopOverride<T> which handles ambiguous &T::loop from multiple inheritance
|
||||
@@ -583,7 +578,7 @@ async def _add_looping_components() -> None:
|
||||
terms = [
|
||||
f"({count} * HasLoopOverride<{cpp_type}>::value)"
|
||||
for cpp_type, count in type_counts.items()
|
||||
]
|
||||
] or ["0"]
|
||||
constexpr_expr = " + \\\n ".join(terms)
|
||||
cg.add_global(
|
||||
cg.RawStatement(
|
||||
@@ -592,14 +587,6 @@ async def _add_looping_components() -> None:
|
||||
)
|
||||
)
|
||||
|
||||
# Pre-init FixedVector with exact capacity so calculate_looping_components_()
|
||||
# can skip the counting pass
|
||||
cg.add(
|
||||
cg.RawExpression(
|
||||
"App.looping_components_.init(ESPHOME_LOOPING_COMPONENT_COUNT)"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@coroutine_with_priority(CoroPriority.CORE)
|
||||
async def to_code(config: ConfigType) -> None:
|
||||
@@ -642,6 +629,14 @@ async def to_code(config: ConfigType) -> None:
|
||||
# Define component count for static allocation
|
||||
cg.add_define("ESPHOME_COMPONENT_COUNT", len(CORE.component_ids))
|
||||
|
||||
# Pre-init FixedVector with exact capacity so calculate_looping_components_()
|
||||
# can skip the counting pass
|
||||
cg.add(
|
||||
cg.RawExpression(
|
||||
"App.looping_components_.init(ESPHOME_LOOPING_COMPONENT_COUNT)"
|
||||
)
|
||||
)
|
||||
|
||||
CORE.add_job(_add_platform_defines)
|
||||
CORE.add_job(_add_controller_registry_define)
|
||||
CORE.add_job(_add_looping_components)
|
||||
|
||||
@@ -501,14 +501,15 @@ async def _read_stream_lines(
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_binary_and_wait_for_port(
|
||||
async def run_binary(
|
||||
binary_path: Path,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: float = PORT_WAIT_TIMEOUT,
|
||||
line_callback: Callable[[str], None] | None = None,
|
||||
) -> AsyncGenerator[None]:
|
||||
"""Run a binary, wait for it to open a port, and clean up on exit."""
|
||||
) -> AsyncGenerator[tuple[asyncio.subprocess.Process, list[str]]]:
|
||||
"""Run a binary under a PTY, capture log output, and clean up on exit.
|
||||
|
||||
Yields the running ``Process`` and a live list of captured log lines.
|
||||
No port wait -- callers that need that should use
|
||||
``run_binary_and_wait_for_port``."""
|
||||
# Create a pseudo-terminal to make the binary think it's running interactively
|
||||
# This is needed because the ESPHome host logger checks isatty()
|
||||
controller_fd, device_fd = pty.openpty()
|
||||
@@ -535,7 +536,6 @@ async def run_binary_and_wait_for_port(
|
||||
controller_transport, _ = await loop.connect_read_pipe(
|
||||
lambda: controller_protocol, os.fdopen(controller_fd, "rb", 0)
|
||||
)
|
||||
output_reader = controller_reader
|
||||
|
||||
if process.returncode is not None:
|
||||
raise RuntimeError(
|
||||
@@ -543,27 +543,59 @@ async def run_binary_and_wait_for_port(
|
||||
"Ensure the binary is valid and can run successfully."
|
||||
)
|
||||
|
||||
# Wait for the API server to start listening
|
||||
loop = asyncio.get_running_loop()
|
||||
start_time = loop.time()
|
||||
|
||||
# Start collecting output
|
||||
stdout_lines: list[str] = []
|
||||
output_tasks: list[asyncio.Task] = []
|
||||
output_task = asyncio.create_task(
|
||||
_read_stream_lines(controller_reader, stdout_lines, sys.stdout, line_callback)
|
||||
)
|
||||
|
||||
try:
|
||||
# Read from output stream
|
||||
output_tasks = [
|
||||
asyncio.create_task(
|
||||
_read_stream_lines(
|
||||
output_reader, stdout_lines, sys.stdout, line_callback
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Small yield to ensure the process has a chance to start
|
||||
await asyncio.sleep(0)
|
||||
yield process, stdout_lines
|
||||
finally:
|
||||
output_task.cancel()
|
||||
result = await asyncio.gather(output_task, return_exceptions=True)
|
||||
if isinstance(result[0], Exception) and not isinstance(
|
||||
result[0], asyncio.CancelledError
|
||||
):
|
||||
print(f"Error reading from PTY: {result[0]}", file=sys.stderr)
|
||||
|
||||
# Close the PTY transport (Unix only)
|
||||
if controller_transport is not None:
|
||||
controller_transport.close()
|
||||
|
||||
# Cleanup: terminate the process gracefully
|
||||
if process.returncode is None:
|
||||
# Send SIGINT (Ctrl+C) for graceful shutdown
|
||||
process.send_signal(signal.SIGINT)
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=SIGINT_TIMEOUT)
|
||||
except TimeoutError:
|
||||
# If SIGINT didn't work, try SIGTERM
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=SIGTERM_TIMEOUT)
|
||||
except TimeoutError:
|
||||
# Last resort: SIGKILL
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_binary_and_wait_for_port(
|
||||
binary_path: Path,
|
||||
host: str,
|
||||
port: int,
|
||||
timeout: float = PORT_WAIT_TIMEOUT,
|
||||
line_callback: Callable[[str], None] | None = None,
|
||||
) -> AsyncGenerator[None]:
|
||||
"""Run a binary, wait for it to open a port, and clean up on exit."""
|
||||
async with run_binary(binary_path, line_callback=line_callback) as (
|
||||
process,
|
||||
stdout_lines,
|
||||
):
|
||||
loop = asyncio.get_running_loop()
|
||||
start_time = loop.time()
|
||||
while loop.time() - start_time < timeout:
|
||||
try:
|
||||
# Try to connect to the port
|
||||
@@ -593,41 +625,6 @@ async def run_binary_and_wait_for_port(
|
||||
|
||||
raise TimeoutError(error_msg)
|
||||
|
||||
finally:
|
||||
# Cancel output collection tasks
|
||||
for task in output_tasks:
|
||||
task.cancel()
|
||||
# Wait for tasks to complete and check for exceptions
|
||||
results = await asyncio.gather(*output_tasks, return_exceptions=True)
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception) and not isinstance(
|
||||
result, asyncio.CancelledError
|
||||
):
|
||||
print(
|
||||
f"Error reading from PTY: {result}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Close the PTY transport (Unix only)
|
||||
if controller_transport is not None:
|
||||
controller_transport.close()
|
||||
|
||||
# Cleanup: terminate the process gracefully
|
||||
if process.returncode is None:
|
||||
# Send SIGINT (Ctrl+C) for graceful shutdown
|
||||
process.send_signal(signal.SIGINT)
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=SIGINT_TIMEOUT)
|
||||
except TimeoutError:
|
||||
# If SIGINT didn't work, try SIGTERM
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=SIGTERM_TIMEOUT)
|
||||
except TimeoutError:
|
||||
# Last resort: SIGKILL
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_compiled_context(
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
esphome:
  name: safe-mode-loop-runs

host:

logger:

safe_mode:
  num_attempts: 10
  on_safe_mode:
    - lambda: |-
        // Spawn a detached thread that logs a unique marker. The
        // non-main-thread log goes through the task log buffer, which
        // is only drained by Logger::loop(). If looping components
        // weren't initialized (the bug fixed in #16269), the buffer is
        // never read and the marker never reaches the console.
        struct MarkerThread {
          static void *thread_func(void *) {
            ESP_LOGI("safe_mode_test", "looping component ran in safe mode");
            return nullptr;
          }
        };
        pthread_t t;
        pthread_create(&t, nullptr, MarkerThread::thread_func, nullptr);
        pthread_detach(t);
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Helpers for manipulating the host platform's preferences file.
|
||||
|
||||
ESPHome's host platform stores preferences in
|
||||
``~/.esphome/prefs/<app_name>.prefs`` using a simple binary layout that
|
||||
mirrors ``HostPreferences::sync()``:
|
||||
``[uint32_t key][uint8_t len][uint8_t data[len]]`` per entry.
|
||||
|
||||
Tests use these helpers to pre-populate state the binary will see at
|
||||
boot (e.g. forcing safe mode) or to clear stale state between runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import struct
|
||||
|
||||
|
||||
def host_prefs_path(device_name: str) -> Path:
    """Return the on-disk prefs file path for a host-platform device.

    The result is ``~/.esphome/prefs/<device_name>.prefs``. The path is only
    computed; nothing is created on disk and existence is not checked.
    """
    return Path.home() / ".esphome" / "prefs" / f"{device_name}.prefs"
|
||||
|
||||
|
||||
def clear_host_prefs(device_name: str) -> None:
    """Delete the prefs file for a host-platform device, if it exists.

    A missing file is not an error (``missing_ok=True``), so this is safe to
    call both before a test run and during cleanup.
    """
    host_prefs_path(device_name).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def write_host_pref(device_name: str, key: int, data: bytes) -> Path:
    """Write a single preference entry, replacing the file's contents.

    The entry is serialized as a little-endian ``uint32`` key, a ``uint8``
    length, then the raw data bytes. Any entries previously stored in the
    file are discarded.

    Returns the path that was written.

    Raises:
        ValueError: if ``data`` is longer than 255 bytes (the length field
            is a single byte).
        struct.error: if ``key`` does not fit in an unsigned 32-bit integer.
    """
    if len(data) > 255:
        raise ValueError(f"Preference data too long: {len(data)} bytes (max 255)")
    path = host_prefs_path(device_name)
    # The prefs directory may not exist yet on a fresh machine / CI runner.
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = struct.pack("<IB", key, len(data)) + data
    path.write_bytes(payload)
    return path
|
||||
@@ -0,0 +1,94 @@
|
||||
"""Regression test for safe_mode + looping_components init ordering.
|
||||
|
||||
Reproduces the bug fixed in https://github.com/esphome/esphome/pull/16269:
|
||||
``App.looping_components_.init(...)`` was emitted at ``CoroPriority.FINAL``,
|
||||
which placed it *after* the ``safe_mode`` early-return in ``setup_app()``.
|
||||
When safe mode was entered, the ``FixedVector`` backing the looping-component
|
||||
list was never sized, ``looping_components_active_end_`` stayed at 0, and
|
||||
``loop()`` iterated zero components -- so any looping component above
|
||||
``CoroPriority.APPLICATION`` (e.g. wifi, logger) never ran.
|
||||
|
||||
The test forces safe mode by writing ``ENTER_SAFE_MODE_MAGIC`` to the host
|
||||
preferences file before booting, then asserts that ``Logger::loop()`` runs
|
||||
by logging from a non-main thread. Non-main-thread logs are buffered in
|
||||
``TaskLogBuffer`` and only emitted to the console when ``Logger::loop()``
|
||||
drains the buffer. Without the fix, the marker stays in the buffer
|
||||
forever; with the fix, it reaches the console.
|
||||
|
||||
The API server (``CoroPriority.WEB``, 40) is registered below safe_mode
|
||||
(``CoroPriority.APPLICATION``, 50), so it's never set up when safe mode
|
||||
is active and ``run_compiled`` would hang waiting for the API port.
|
||||
This test uses ``run_binary`` directly to skip the port wait.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import run_binary
|
||||
from .host_prefs import clear_host_prefs, write_host_pref
|
||||
from .types import CompileFunction, ConfigWriter
|
||||
|
||||
# Must match esphome::safe_mode::RTC_KEY in safe_mode.h
|
||||
SAFE_MODE_RTC_KEY = 233825507
|
||||
# Must match esphome::safe_mode::SafeModeComponent::ENTER_SAFE_MODE_MAGIC
|
||||
ENTER_SAFE_MODE_MAGIC = 0x5AFE5AFE
|
||||
|
||||
DEVICE_NAME = "safe-mode-loop-runs"
|
||||
THREAD_LOG_MARKER = "looping component ran in safe mode"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_safe_mode_loop_runs(
    yaml_config: str,
    write_yaml_config: ConfigWriter,
    compile_esphome: CompileFunction,
) -> None:
    """When safe mode is active, ``App.loop()`` must still iterate looping
    components -- proven here by a thread-logged marker reaching the
    console (which requires ``Logger::loop()`` to run)."""
    config_path = await write_yaml_config(yaml_config)
    binary_path = await compile_esphome(config_path)

    # Compile finished successfully; pre-populate prefs so the *next* run
    # enters safe mode immediately.
    write_host_pref(
        DEVICE_NAME, SAFE_MODE_RTC_KEY, struct.pack("<I", ENTER_SAFE_MODE_MAGIC)
    )

    try:
        loop = asyncio.get_running_loop()
        safe_mode_active = loop.create_future()
        thread_log_seen = loop.create_future()
        safe_mode_pattern = re.compile(r"SAFE MODE IS ACTIVE")
        thread_log_pattern = re.compile(re.escape(THREAD_LOG_MARKER))

        def on_log(line: str) -> None:
            # Resolve each future at most once; later matching lines are ignored.
            if not safe_mode_active.done() and safe_mode_pattern.search(line):
                safe_mode_active.set_result(True)
            if not thread_log_seen.done() and thread_log_pattern.search(line):
                thread_log_seen.set_result(True)

        async with run_binary(binary_path, line_callback=on_log):
            # First confirm safe mode was actually entered; otherwise the
            # second check would pass without exercising the regression.
            try:
                await asyncio.wait_for(safe_mode_active, timeout=15.0)
            except TimeoutError:
                pytest.fail(
                    "Did not observe 'SAFE MODE IS ACTIVE' -- safe mode "
                    "didn't trigger, so this test isn't exercising the bug."
                )
            try:
                await asyncio.wait_for(thread_log_seen, timeout=10.0)
            except TimeoutError:
                pytest.fail(
                    f"Did not observe thread-logged marker {THREAD_LOG_MARKER!r} "
                    "within timeout. Logger::loop() never drained the task "
                    "log buffer, meaning App.looping_components_ was never "
                    "sized -- this is the regression #16269 fixed."
                )
    finally:
        # Always remove the forced-safe-mode pref so it cannot leak into
        # later runs of the same device name.
        clear_host_prefs(DEVICE_NAME)
|
||||
@@ -9,7 +9,6 @@ Tests that:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import socket
|
||||
from typing import Any
|
||||
|
||||
@@ -17,9 +16,12 @@ from aioesphomeapi import TextInfo, TextState
|
||||
import pytest
|
||||
|
||||
from .conftest import run_binary_and_wait_for_port, wait_and_connect_api_client
|
||||
from .host_prefs import clear_host_prefs
|
||||
from .state_utils import InitialStateHelper, require_entity
|
||||
from .types import CompileFunction, ConfigWriter
|
||||
|
||||
DEVICE_NAME = "host-template-text-save-test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_template_text_save(
|
||||
@@ -32,11 +34,7 @@ async def test_template_text_save(
|
||||
port, port_socket = reserved_tcp_port
|
||||
|
||||
# Clean up any stale preference file from previous runs
|
||||
prefs_file = (
|
||||
Path.home() / ".esphome" / "prefs" / "host-template-text-save-test.prefs"
|
||||
)
|
||||
if prefs_file.exists():
|
||||
prefs_file.unlink()
|
||||
clear_host_prefs(DEVICE_NAME)
|
||||
|
||||
# Write and compile once
|
||||
config_path = await write_yaml_config(yaml_config)
|
||||
@@ -59,7 +57,7 @@ async def test_template_text_save(
|
||||
wait_and_connect_api_client(port=port) as client,
|
||||
):
|
||||
device_info = await client.device_info()
|
||||
assert device_info.name == "host-template-text-save-test"
|
||||
assert device_info.name == DEVICE_NAME
|
||||
|
||||
entities, _ = await client.list_entities_services()
|
||||
text_entity = require_entity(
|
||||
@@ -127,5 +125,4 @@ async def test_template_text_save(
|
||||
)
|
||||
|
||||
# Clean up preference file
|
||||
if prefs_file.exists():
|
||||
prefs_file.unlink()
|
||||
clear_host_prefs(DEVICE_NAME)
|
||||
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
import pytest
|
||||
|
||||
from esphome import config_validation as cv, core
|
||||
from esphome.components.safe_mode import to_code as safe_mode_to_code
|
||||
from esphome.const import (
|
||||
CONF_AREA,
|
||||
CONF_AREAS,
|
||||
@@ -312,6 +313,75 @@ def test_add_platform_defines_priority() -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_to_code_priority_above_safe_mode() -> None:
    """Test that core to_code emits the looping_components_ init before safe_mode.

    Regression test for https://github.com/esphome/esphome/issues/16262.
    safe_mode emits an `if (should_enter_safe_mode(...)) return;` line in main()
    at APPLICATION priority. The `App.looping_components_.init(...)` call must be
    emitted at a higher priority than APPLICATION so it lands in main() before
    the early return; otherwise the FixedVector is never sized when safe mode is
    active and loop() never runs (Wi-Fi never connects).
    """
    assert config.to_code.priority > safe_mode_to_code.priority, (
        f"core to_code priority ({config.to_code.priority}) must be greater than "
        f"safe_mode to_code priority ({safe_mode_to_code.priority}) so that "
        "App.looping_components_.init() is emitted before safe_mode's early return"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_looping_components_handles_empty_entries() -> None:
    """Test that _add_looping_components emits a valid constexpr when there are
    no looping component entries.

    With zero entries the generated constexpr must still be syntactically valid
    C++ (`= 0;`), not an empty expression (`= ;`). This guards the empty-list
    case that would otherwise produce uncompilable main.cpp output.
    """
    CORE.data["looping_component_entries"] = []

    await config._add_looping_components()

    constexpr_lines = [
        str(s)
        for s in CORE.global_statements
        if "ESPHOME_LOOPING_COMPONENT_COUNT" in str(s)
    ]
    assert len(constexpr_lines) == 1
    text = constexpr_lines[0]
    assert "static constexpr size_t ESPHOME_LOOPING_COMPONENT_COUNT" in text
    # The right-hand side must contain a literal `0`, not be empty.
    rhs = text.split("=", 1)[1]
    assert "0" in rhs
    assert rhs.strip().rstrip(";").strip(), (
        f"constexpr right-hand side must not be empty, got: {text!r}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_looping_components_with_entries() -> None:
    """Test that _add_looping_components builds a HasLoopOverride sum from entries."""
    # WiFiComponent appears twice on purpose to exercise per-type counting.
    CORE.data["looping_component_entries"] = [
        "esphome::wifi::WiFiComponent",
        "esphome::logger::Logger",
        "esphome::wifi::WiFiComponent",
    ]

    await config._add_looping_components()

    constexpr_lines = [
        str(s)
        for s in CORE.global_statements
        if "ESPHOME_LOOPING_COMPONENT_COUNT" in str(s)
    ]
    assert len(constexpr_lines) == 1
    text = constexpr_lines[0]
    # Deduplicated by type, with per-type counts as multiplier.
    assert "(2 * HasLoopOverride<esphome::wifi::WiFiComponent>::value)" in text
    assert "(1 * HasLoopOverride<esphome::logger::Logger>::value)" in text
|
||||
|
||||
|
||||
def test_valid_include_with_angle_brackets() -> None:
|
||||
"""Test valid_include accepts angle bracket includes."""
|
||||
assert valid_include("<ArduinoJson.h>") == "<ArduinoJson.h>"
|
||||
|
||||
Reference in New Issue
Block a user