[core] Allow finding all devices as target that match mac suffix (#13135)

This commit is contained in:
Paulus Schoutsen
2026-04-23 09:43:32 -04:00
committed by GitHub
parent 70ae614abd
commit 9b45b046a8
8 changed files with 912 additions and 70 deletions
+108 -12
View File
@@ -39,6 +39,7 @@ from esphome.const import (
CONF_MDNS, CONF_MDNS,
CONF_MQTT, CONF_MQTT,
CONF_NAME, CONF_NAME,
CONF_NAME_ADD_MAC_SUFFIX,
CONF_OTA, CONF_OTA,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PLATFORM, CONF_PLATFORM,
@@ -71,6 +72,7 @@ from esphome.util import (
run_external_process, run_external_process,
safe_print, safe_print,
) )
from esphome.zeroconf import discover_mdns_devices
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
return [address] return [address]
def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None:
"""Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``.
Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so
the downstream resolution path (``resolve_ip_address``) can skip opening a
second Zeroconf client.
"""
from esphome.address_cache import AddressCache
if CORE.address_cache is None:
CORE.address_cache = AddressCache()
for host, addresses in hosts_to_addresses.items():
if addresses:
_LOGGER.debug("Caching mDNS result %s -> %s", host, addresses)
CORE.address_cache.add_mdns_addresses(host, addresses)
def _discover_mac_suffix_devices() -> list[str] | None:
"""Discover ``<name>-<mac>.local`` devices and cache their IPs.
Returns:
- ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off,
mDNS disabled, or ``CORE.address`` is already an IP). Callers should
then fall back to whatever default OTA address they normally use.
- ``[]`` when discovery ran but found nothing. Callers should NOT fall
back to the base name: with ``name_add_mac_suffix`` enabled, the base
name by definition doesn't exist on the network.
- A non-empty sorted list of ``.local`` hostnames on success.
Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or
``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we
already have without opening a second Zeroconf client.
"""
if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()):
return None
_LOGGER.info("Discovering devices...")
if not (discovered := discover_mdns_devices(CORE.name)):
_LOGGER.warning(
"No devices matching '%s-<mac>.local' were discovered.", CORE.name
)
return []
_populate_mdns_cache(discovered)
return list(discovered)
def _ota_hostnames_for_default(purpose: Purpose) -> list[str]:
"""Return OTA hostname(s) for the ``--device OTA`` / default-resolve path.
When ``name_add_mac_suffix`` is enabled, returns discovered
``<name>-<mac>.local`` hostnames (possibly empty — in which case the
caller should not fall back to the base name). Otherwise falls back to
the cache-resolved ``CORE.address``.
"""
if (discovered := _discover_mac_suffix_devices()) is not None:
return discovered
return _resolve_with_cache(CORE.address, purpose)
def choose_upload_log_host( def choose_upload_log_host(
default: list[str] | str | None, default: list[str] | str | None,
check_default: str | None, check_default: str | None,
@@ -242,14 +302,14 @@ def choose_upload_log_host(
resolved.append("MQTT") resolved.append("MQTT")
if has_api() and has_non_ip_address() and has_resolvable_address(): if has_api() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose)) resolved.extend(_ota_hostnames_for_default(purpose))
elif purpose == Purpose.UPLOADING: elif purpose == Purpose.UPLOADING:
if has_ota() and has_mqtt_ip_lookup(): if has_ota() and has_mqtt_ip_lookup():
resolved.append("MQTTIP") resolved.append("MQTTIP")
if has_ota() and has_non_ip_address() and has_resolvable_address(): if has_ota() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose)) resolved.extend(_ota_hostnames_for_default(purpose))
else: else:
resolved.append(device) resolved.append(device)
if not resolved: if not resolved:
@@ -281,22 +341,29 @@ def choose_upload_log_host(
elif bootsel.permission_error: elif bootsel.permission_error:
bootsel_permission_error = True bootsel_permission_error = True
def add_ota_options() -> None:
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
if (discovered := _discover_mac_suffix_devices()) is not None:
# Discovery was applicable. Use whatever we found — on empty,
# intentionally skip the base-name fallback since with
# name_add_mac_suffix on, the base name doesn't exist on the net.
for host in discovered:
options.append((f"Over The Air ({host})", host))
elif has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
if purpose == Purpose.LOGGING: if purpose == Purpose.LOGGING:
if has_mqtt_logging(): if has_mqtt_logging():
mqtt_config = CORE.config[CONF_MQTT] mqtt_config = CORE.config[CONF_MQTT]
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
if has_api(): if has_api():
if has_resolvable_address(): add_ota_options()
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
elif purpose == Purpose.UPLOADING and has_ota(): elif purpose == Purpose.UPLOADING and has_ota():
if has_resolvable_address(): add_ota_options()
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found # Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
if ( if (
@@ -407,7 +474,17 @@ def has_resolvable_address() -> bool:
return not CORE.address.endswith(".local") return not CORE.address.endswith(".local")
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str): def has_name_add_mac_suffix() -> bool:
"""Check if name_add_mac_suffix is enabled in the config."""
if CORE.config is None:
return False
esphome_config = CORE.config.get(CONF_ESPHOME, {})
return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False)
def mqtt_get_ip(
config: ConfigType, username: str, password: str, client_id: str
) -> list[str]:
from esphome import mqtt from esphome import mqtt
return mqtt.get_esphome_device_ip(config, username, password, client_id) return mqtt.get_esphome_device_ip(config, username, password, client_id)
@@ -420,6 +497,9 @@ def _resolve_network_devices(
This function filters the devices list to: This function filters the devices list to:
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup - Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
- Expand hostnames that are already in ``CORE.address_cache`` to their
cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second
Zeroconf client to resolve them
- Deduplicate addresses while preserving order - Deduplicate addresses while preserving order
- Only resolve MQTT once even if multiple MQTT strings are present - Only resolve MQTT once even if multiple MQTT strings are present
- If MQTT resolution fails, log a warning and continue with other devices - If MQTT resolution fails, log a warning and continue with other devices
@@ -444,13 +524,29 @@ def _resolve_network_devices(
mqtt_ips = mqtt_get_ip( mqtt_ips = mqtt_get_ip(
config, args.username, args.password, args.client_id config, args.username, args.password, args.client_id
) )
network_devices.extend(mqtt_ips) # pylint can't infer mqtt_get_ip's return through its
# lazy ``from esphome import mqtt`` import, so it flags
# the genexpr below.
network_devices.extend(
addr
for addr in mqtt_ips # pylint: disable=not-an-iterable
if addr not in network_devices
)
except EsphomeError as err: except EsphomeError as err:
_LOGGER.warning( _LOGGER.warning(
"MQTT IP discovery failed (%s), will try other devices if available", "MQTT IP discovery failed (%s), will try other devices if available",
err, err,
) )
mqtt_resolved = True mqtt_resolved = True
continue
# If the hostname is already in the address cache (e.g. populated by
# mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't
# open its own Zeroconf to re-resolve it.
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)):
network_devices.extend(
addr for addr in cached if addr not in network_devices
)
elif device not in network_devices: elif device not in network_devices:
# Regular network address or IP - add if not already present # Regular network address or IP - add if not already present
network_devices.append(device) network_devices.append(device)
+11
View File
@@ -101,6 +101,17 @@ class AddressCache:
"""Check if any cache entries exist.""" """Check if any cache entries exist."""
return bool(self.mdns_cache or self.dns_cache) return bool(self.mdns_cache or self.dns_cache)
def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None:
"""Store resolved mDNS addresses for ``hostname`` in the cache.
Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use
this to avoid a second resolution round-trip during the upload path.
No-op when ``addresses`` is empty.
"""
if not addresses:
return
self.mdns_cache[normalize_hostname(hostname)] = addresses
@classmethod @classmethod
def from_cli_args( def from_cli_args(
cls, mdns_args: Iterable[str], dns_args: Iterable[str] cls, mdns_args: Iterable[str], dns_args: Iterable[str]
+56
View File
@@ -0,0 +1,56 @@
"""Helpers for running an async coroutine from sync code via a daemon thread.
``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup
cycle finishes, which can add hundreds of milliseconds before the caller
receives the result. Running the loop in a daemon thread lets the caller
observe the result as soon as the coroutine completes while cleanup finishes
in the background.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
import threading
from typing import Generic, TypeVar
_T = TypeVar("_T")
class AsyncThreadRunner(threading.Thread, Generic[_T]):
"""Run an async coroutine in a daemon thread and expose its result.
The runner catches all exceptions from the coroutine and stores them in
``exception`` so ``event`` is always set — this prevents callers waiting
on ``event`` from hanging forever when the coroutine crashes.
Typical usage::
runner = AsyncThreadRunner(lambda: my_coro(arg))
runner.start()
if not runner.event.wait(timeout=5.0):
... # timed out
if runner.exception is not None:
raise runner.exception
result = runner.result
"""
def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None:
super().__init__(daemon=True)
self._coro_factory = coro_factory
self.result: _T | None = None
self.exception: BaseException | None = None
self.event = threading.Event()
async def _runner(self) -> None:
try:
self.result = await self._coro_factory()
except Exception as exc: # pylint: disable=broad-except
# Capture all exceptions so ``event`` is always set — otherwise a
# crash would hang the waiter forever.
self.exception = exc
finally:
self.event.set()
def run(self) -> None:
asyncio.run(self._runner())
+17 -31
View File
@@ -2,66 +2,52 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import threading
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
import aioesphomeapi.host_resolver as hr import aioesphomeapi.host_resolver as hr
from esphome.async_thread import AsyncThreadRunner
from esphome.core import EsphomeError from esphome.core import EsphomeError
RESOLVE_TIMEOUT = 10.0 # seconds RESOLVE_TIMEOUT = 10.0 # seconds
class AsyncResolver(threading.Thread): class AsyncResolver:
"""Resolver using aioesphomeapi that runs in a thread for faster results. """Resolver using aioesphomeapi that runs in a thread for faster results.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution, This resolver uses aioesphomeapi's async_resolve_host to handle DNS
including proper .local domain fallback. Running in a thread allows us to get resolution, including proper .local domain fallback. Running in a thread
the result immediately without waiting for asyncio.run() to complete its (via :class:`AsyncThreadRunner`) allows us to get the result immediately
cleanup cycle, which can take significant time. without waiting for ``asyncio.run()`` to complete its cleanup cycle, which
can take significant time.
""" """
def __init__(self, hosts: list[str], port: int) -> None: def __init__(self, hosts: list[str], port: int) -> None:
"""Initialize the resolver.""" """Initialize the resolver."""
super().__init__(daemon=True)
self.hosts = hosts self.hosts = hosts
self.port = port self.port = port
self.result: list[hr.AddrInfo] | None = None
self.exception: Exception | None = None
self.event = threading.Event()
async def _resolve(self) -> None: async def _resolve(self) -> list[hr.AddrInfo]:
"""Resolve hostnames to IP addresses.""" """Resolve hostnames to IP addresses."""
try: return await hr.async_resolve_host(
self.result = await hr.async_resolve_host( self.hosts, self.port, timeout=RESOLVE_TIMEOUT
self.hosts, self.port, timeout=RESOLVE_TIMEOUT )
)
except Exception as e: # pylint: disable=broad-except
# We need to catch all exceptions to ensure the event is set
# Otherwise the thread could hang forever
self.exception = e
finally:
self.event.set()
def run(self) -> None:
"""Run the DNS resolution."""
asyncio.run(self._resolve())
def resolve(self) -> list[hr.AddrInfo]: def resolve(self) -> list[hr.AddrInfo]:
"""Start the thread and wait for the result.""" """Start the thread and wait for the result."""
self.start() runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve)
runner.start()
if not self.event.wait( if not runner.event.wait(
timeout=RESOLVE_TIMEOUT + 1.0 timeout=RESOLVE_TIMEOUT + 1.0
): # Give it 1 second more than the resolver timeout ): # Give it 1 second more than the resolver timeout
raise EsphomeError("Timeout resolving IP address") raise EsphomeError("Timeout resolving IP address")
if exc := self.exception: if exc := runner.exception:
if isinstance(exc, ResolveTimeoutAPIError): if isinstance(exc, ResolveTimeoutAPIError):
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
if isinstance(exc, ResolveAPIError): if isinstance(exc, ResolveAPIError):
raise EsphomeError(f"Error resolving IP address: {exc}") from exc raise EsphomeError(f"Error resolving IP address: {exc}") from exc
raise exc raise exc
return self.result assert runner.result is not None # guaranteed when event set and no exception
return runner.result
+174 -7
View File
@@ -14,8 +14,13 @@ from zeroconf import (
) )
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.async_thread import AsyncThreadRunner
from esphome.storage_json import StorageJSON, ext_storage_path from esphome.storage_json import StorageJSON, ext_storage_path
# Length of the MAC suffix appended when name_add_mac_suffix is enabled.
MAC_SUFFIX_LEN = 6
_HEX_CHARS = frozenset("0123456789abcdef")
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 10.0 DEFAULT_TIMEOUT = 10.0
@@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf):
return None return None
async def async_resolve_hosts(
zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
) -> dict[str, list[str]]:
"""Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance.
Tries the cache synchronously first (so hosts already primed by a recent
browse return immediately with no network round-trip), then issues
``async_request`` for the remaining misses in parallel via
``asyncio.gather``. Returns a dict mapping each host to its list of
addresses (empty list when unresolved). Only ``<short>.local`` form is
queried, matching the name scheme the resolvers below expect.
"""
resolvers: dict[str, AddressResolver] = {}
pending: list[str] = []
for host in hosts:
resolver = AddressResolver(f"{host.partition('.')[0]}.local.")
resolvers[host] = resolver
if not resolver.load_from_cache(zeroconf):
pending.append(host)
if pending and timeout:
results = await asyncio.gather(
*(
resolvers[host].async_request(zeroconf, timeout * 1000)
for host in pending
),
return_exceptions=True,
)
for host, result in zip(pending, results):
if isinstance(result, BaseException):
_LOGGER.debug("Failed to resolve %s: %s", host, result)
return {
host: resolver.parsed_scoped_addresses(IPVersion.All)
for host, resolver in resolvers.items()
}
class AsyncEsphomeZeroconf(AsyncZeroconf): class AsyncEsphomeZeroconf(AsyncZeroconf):
async def async_resolve_host( async def async_resolve_host(
self, host: str, timeout: float = DEFAULT_TIMEOUT self, host: str, timeout: float = DEFAULT_TIMEOUT
) -> list[str] | None: ) -> list[str] | None:
"""Resolve a host name to an IP address.""" """Resolve a host name to an IP address."""
info = AddressResolver(f"{host.partition('.')[0]}.local.") addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host]
if ( return addresses or None
info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
) and (addresses := info.parsed_scoped_addresses(IPVersion.All)): def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
return addresses """Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC."""
return None if not device_name.startswith(prefix):
return False
suffix = device_name[len(prefix) :]
return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix)
async def async_discover_mdns_devices(
base_name: str, timeout: float = 5.0
) -> dict[str, list[str]]:
"""Discover ESPHome devices via mDNS that match the base name + MAC suffix.
When ``name_add_mac_suffix`` is enabled, devices advertise as
``<base_name>-<6-hex-mac>.local``. This function uses a single
``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and
resolve their IP addresses, so callers get resolved addresses without
opening a second Zeroconf client.
Args:
base_name: The base device name (without MAC suffix).
timeout: How long to wait for mDNS responses (default 5 seconds).
Returns:
Mapping of ``<device>.local`` hostnames to their resolved IP addresses
(may be empty for a device if resolution failed within the timeout).
"""
prefix = f"{base_name}-"
# Preserves insertion order for stable output and deduplicates
discovered: dict[str, list[str]] = {}
def on_service_state_change(
zeroconf: Zeroconf,
service_type: str,
name: str,
state_change: ServiceStateChange,
) -> None:
if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated):
return
device_name = name.partition(".")[0]
if not _is_mac_suffix_match(device_name, prefix):
_LOGGER.debug(
"Ignoring %s (%s): does not match '%s<6-hex>'",
device_name,
state_change.name,
prefix,
)
return
host = f"{device_name}.local"
if host in discovered:
return
discovered[host] = []
_LOGGER.debug("Discovered %s (%s)", host, state_change.name)
_LOGGER.debug(
"Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
prefix,
timeout,
)
try:
aiozc = AsyncEsphomeZeroconf()
except Exception as err: # pylint: disable=broad-except
# Zeroconf init can raise OSError, NonUniqueNameException, etc.
# Any failure here just means we can't discover — log and move on.
_LOGGER.warning("mDNS discovery failed to initialize: %s", err)
return {}
try:
browser = AsyncServiceBrowser(
aiozc.zeroconf,
ESPHOME_SERVICE_TYPE,
handlers=[on_service_state_change],
)
try:
await asyncio.sleep(timeout)
finally:
await browser.async_cancel()
_LOGGER.debug(
"Browse finished: %d device(s) matched '%s<mac>'",
len(discovered),
prefix,
)
# Resolve each discovered hostname on the SAME Zeroconf instance so
# we don't spin up a second client. ``async_resolve_hosts`` tries the
# cache synchronously (the browse usually primes it) before issuing
# any ``async_request`` in parallel for misses.
resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered))
for host, addresses in resolved.items():
if addresses:
discovered[host] = addresses
_LOGGER.debug("Resolved %s -> %s", host, addresses)
else:
_LOGGER.debug("No addresses returned for %s", host)
finally:
await aiozc.async_close()
return dict(sorted(discovered.items()))
def _await_discovery(
runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
) -> dict[str, list[str]]:
"""Wait for ``runner`` to finish and return its discovery result.
Split out of :func:`discover_mdns_devices` so the timeout branch is
testable without patching ``asyncio`` or ``threading`` internals — a test
passes a stub whose ``event.wait`` returns ``False``.
"""
# Give the discovery an extra second over the browse timeout for the
# resolution + cleanup pass.
if not runner.event.wait(timeout=timeout + 2.0):
_LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
return {}
if runner.exception is not None:
_LOGGER.warning("mDNS discovery failed: %s", runner.exception)
return {}
return runner.result or {}
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
"""Synchronous wrapper around :func:`async_discover_mdns_devices`."""
runner = AsyncThreadRunner(
lambda: async_discover_mdns_devices(base_name, timeout=timeout)
)
runner.start()
return _await_discovery(runner, timeout)
+20
View File
@@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None:
assert cache.get_addresses("unknown.com") is None assert cache.get_addresses("unknown.com") is None
def test_add_mdns_addresses_stores_and_normalizes() -> None:
"""add_mdns_addresses inserts entries under the normalized hostname."""
cache = AddressCache()
cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"])
assert cache.mdns_cache == {
normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"]
}
# Overwrites on subsequent calls for the same host
cache.add_mdns_addresses("device.local", ["10.0.0.1"])
assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"]
def test_add_mdns_addresses_empty_is_noop() -> None:
"""Passing an empty address list must not create an entry."""
cache = AddressCache()
cache.add_mdns_addresses("device.local", [])
assert cache.mdns_cache == {}
def test_has_cache() -> None: def test_has_cache() -> None:
"""Test checking if cache has entries.""" """Test checking if cache has entries."""
# Empty cache # Empty cache
File diff suppressed because it is too large Load Diff
+15 -18
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import re import re
import socket import socket
from unittest.mock import patch from unittest.mock import MagicMock, patch
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
@@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None:
def test_async_resolver_thread_timeout() -> None: def test_async_resolver_thread_timeout() -> None:
"""Test timeout when thread doesn't complete in time.""" """Test timeout when the runner thread doesn't complete in time."""
# Mock the start method to prevent actual thread execution # Patch AsyncThreadRunner inside esphome.resolver so we never actually
with ( # start a thread and can control the wait return value directly.
patch.object(AsyncResolver, "start"), fake_runner = MagicMock()
patch("esphome.resolver.hr.async_resolve_host"), fake_runner.start = MagicMock()
): fake_runner.event.wait.return_value = False # simulate timeout
resolver = AsyncResolver(["test.local"], 6053)
# Override event.wait to simulate timeout (return False = timeout occurred)
with (
patch.object(resolver.event, "wait", return_value=False),
pytest.raises(
EsphomeError, match=re.escape("Timeout resolving IP address")
),
):
resolver.resolve()
# Verify thread start was called with (
resolver.start.assert_called_once() patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner),
patch("esphome.resolver.hr.async_resolve_host"),
pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")),
):
AsyncResolver(["test.local"], 6053).resolve()
fake_runner.start.assert_called_once()
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: