diff --git a/esphome/__main__.py b/esphome/__main__.py index 7879cdad0ca..8c80dab90af 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -39,6 +39,7 @@ from esphome.const import ( CONF_MDNS, CONF_MQTT, CONF_NAME, + CONF_NAME_ADD_MAC_SUFFIX, CONF_OTA, CONF_PASSWORD, CONF_PLATFORM, @@ -71,6 +72,7 @@ from esphome.util import ( run_external_process, safe_print, ) +from esphome.zeroconf import discover_mdns_devices _LOGGER = logging.getLogger(__name__) @@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]: return [address] +def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None: + """Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``. + + Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so + the downstream resolution path (``resolve_ip_address``) can skip opening a + second Zeroconf client. + """ + from esphome.address_cache import AddressCache + + if CORE.address_cache is None: + CORE.address_cache = AddressCache() + for host, addresses in hosts_to_addresses.items(): + if addresses: + _LOGGER.debug("Caching mDNS result %s -> %s", host, addresses) + CORE.address_cache.add_mdns_addresses(host, addresses) + + +def _discover_mac_suffix_devices() -> list[str] | None: + """Discover ``-.local`` devices and cache their IPs. + + Returns: + - ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off, + mDNS disabled, or ``CORE.address`` is already an IP). Callers should + then fall back to whatever default OTA address they normally use. + - ``[]`` when discovery ran but found nothing. Callers should NOT fall + back to the base name: with ``name_add_mac_suffix`` enabled, the base + name by definition doesn't exist on the network. + - A non-empty sorted list of ``.local`` hostnames on success. + + Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or + ``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we + already have without opening a second Zeroconf client. + """ + if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()): + return None + _LOGGER.info("Discovering devices...") + if not (discovered := discover_mdns_devices(CORE.name)): + _LOGGER.warning( + "No devices matching '%s-.local' were discovered.", CORE.name + ) + return [] + _populate_mdns_cache(discovered) + return list(discovered) + + +def _ota_hostnames_for_default(purpose: Purpose) -> list[str]: + """Return OTA hostname(s) for the ``--device OTA`` / default-resolve path. + + When ``name_add_mac_suffix`` is enabled, returns discovered + ``-.local`` hostnames (possibly empty — in which case the + caller should not fall back to the base name). Otherwise falls back to + the cache-resolved ``CORE.address``. + """ + if (discovered := _discover_mac_suffix_devices()) is not None: + return discovered + return _resolve_with_cache(CORE.address, purpose) + + def choose_upload_log_host( default: list[str] | str | None, check_default: str | None, @@ -242,14 +302,14 @@ def choose_upload_log_host( resolved.append("MQTT") if has_api() and has_non_ip_address() and has_resolvable_address(): - resolved.extend(_resolve_with_cache(CORE.address, purpose)) + resolved.extend(_ota_hostnames_for_default(purpose)) elif purpose == Purpose.UPLOADING: if has_ota() and has_mqtt_ip_lookup(): resolved.append("MQTTIP") if has_ota() and has_non_ip_address() and has_resolvable_address(): - resolved.extend(_resolve_with_cache(CORE.address, purpose)) + resolved.extend(_ota_hostnames_for_default(purpose)) else: resolved.append(device) if not resolved: @@ -281,22 +341,29 @@ def choose_upload_log_host( elif bootsel.permission_error: bootsel_permission_error = True + def add_ota_options() -> None: + """Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled.""" + if (discovered := _discover_mac_suffix_devices()) is not None: + # Discovery was applicable. Use whatever we found — on empty, + # intentionally skip the base-name fallback since with + # name_add_mac_suffix on, the base name doesn't exist on the net. + for host in discovered: + options.append((f"Over The Air ({host})", host)) + elif has_resolvable_address(): + options.append((f"Over The Air ({CORE.address})", CORE.address)) + if has_mqtt_ip_lookup(): + options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + if purpose == Purpose.LOGGING: if has_mqtt_logging(): mqtt_config = CORE.config[CONF_MQTT] options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT")) if has_api(): - if has_resolvable_address(): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if has_mqtt_ip_lookup(): - options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + add_ota_options() elif purpose == Purpose.UPLOADING and has_ota(): - if has_resolvable_address(): - options.append((f"Over The Air ({CORE.address})", CORE.address)) - if has_mqtt_ip_lookup(): - options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + add_ota_options() # Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found if ( @@ -407,7 +474,17 @@ def has_resolvable_address() -> bool: return not CORE.address.endswith(".local") -def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str): +def has_name_add_mac_suffix() -> bool: + """Check if name_add_mac_suffix is enabled in the config.""" + if CORE.config is None: + return False + esphome_config = CORE.config.get(CONF_ESPHOME, {}) + return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False) + + +def mqtt_get_ip( + config: ConfigType, username: str, password: str, client_id: str +) -> list[str]: from esphome import mqtt return mqtt.get_esphome_device_ip(config, username, password, client_id) @@ -420,6 +497,9 @@ def _resolve_network_devices( This function filters the devices list to: - Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup + - Expand hostnames that are already in ``CORE.address_cache`` to their + cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second + Zeroconf client to resolve them - Deduplicate addresses while preserving order - Only resolve MQTT once even if multiple MQTT strings are present - If MQTT resolution fails, log a warning and continue with other devices @@ -444,13 +524,29 @@ def _resolve_network_devices( mqtt_ips = mqtt_get_ip( config, args.username, args.password, args.client_id ) - network_devices.extend(mqtt_ips) + # pylint can't infer mqtt_get_ip's return through its + # lazy ``from esphome import mqtt`` import, so it flags + # the genexpr below. + network_devices.extend( + addr + for addr in mqtt_ips # pylint: disable=not-an-iterable + if addr not in network_devices + ) except EsphomeError as err: _LOGGER.warning( "MQTT IP discovery failed (%s), will try other devices if available", err, ) mqtt_resolved = True + continue + + # If the hostname is already in the address cache (e.g. populated by + # mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't + # open its own Zeroconf to re-resolve it. + if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)): + network_devices.extend( + addr for addr in cached if addr not in network_devices + ) elif device not in network_devices: # Regular network address or IP - add if not already present network_devices.append(device) diff --git a/esphome/address_cache.py b/esphome/address_cache.py index 7c20be90f00..4fb3689818b 100644 --- a/esphome/address_cache.py +++ b/esphome/address_cache.py @@ -101,6 +101,17 @@ class AddressCache: """Check if any cache entries exist.""" return bool(self.mdns_cache or self.dns_cache) + def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None: + """Store resolved mDNS addresses for ``hostname`` in the cache. + + Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use + this to avoid a second resolution round-trip during the upload path. + No-op when ``addresses`` is empty. + """ + if not addresses: + return + self.mdns_cache[normalize_hostname(hostname)] = addresses + @classmethod def from_cli_args( cls, mdns_args: Iterable[str], dns_args: Iterable[str] diff --git a/esphome/async_thread.py b/esphome/async_thread.py new file mode 100644 index 00000000000..7be3c83a9a5 --- /dev/null +++ b/esphome/async_thread.py @@ -0,0 +1,56 @@ +"""Helpers for running an async coroutine from sync code via a daemon thread. + +``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup +cycle finishes, which can add hundreds of milliseconds before the caller +receives the result. Running the loop in a daemon thread lets the caller +observe the result as soon as the coroutine completes while cleanup finishes +in the background. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +import threading +from typing import Generic, TypeVar + +_T = TypeVar("_T") + + +class AsyncThreadRunner(threading.Thread, Generic[_T]): + """Run an async coroutine in a daemon thread and expose its result. + + The runner catches all exceptions from the coroutine and stores them in + ``exception`` so ``event`` is always set — this prevents callers waiting + on ``event`` from hanging forever when the coroutine crashes. + + Typical usage:: + + runner = AsyncThreadRunner(lambda: my_coro(arg)) + runner.start() + if not runner.event.wait(timeout=5.0): + ... # timed out + if runner.exception is not None: + raise runner.exception + result = runner.result + """ + + def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None: + super().__init__(daemon=True) + self._coro_factory = coro_factory + self.result: _T | None = None + self.exception: BaseException | None = None + self.event = threading.Event() + + async def _runner(self) -> None: + try: + self.result = await self._coro_factory() + except Exception as exc: # pylint: disable=broad-except + # Capture all exceptions so ``event`` is always set — otherwise a + # crash would hang the waiter forever. + self.exception = exc + finally: + self.event.set() + + def run(self) -> None: + asyncio.run(self._runner()) diff --git a/esphome/resolver.py b/esphome/resolver.py index 99482aa20e9..9fb596ce7b5 100644 --- a/esphome/resolver.py +++ b/esphome/resolver.py @@ -2,66 +2,52 @@ from __future__ import annotations -import asyncio -import threading - from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError import aioesphomeapi.host_resolver as hr +from esphome.async_thread import AsyncThreadRunner from esphome.core import EsphomeError RESOLVE_TIMEOUT = 10.0 # seconds -class AsyncResolver(threading.Thread): +class AsyncResolver: """Resolver using aioesphomeapi that runs in a thread for faster results. - This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution, - including proper .local domain fallback. Running in a thread allows us to get - the result immediately without waiting for asyncio.run() to complete its - cleanup cycle, which can take significant time. + This resolver uses aioesphomeapi's async_resolve_host to handle DNS + resolution, including proper .local domain fallback. Running in a thread + (via :class:`AsyncThreadRunner`) allows us to get the result immediately + without waiting for ``asyncio.run()`` to complete its cleanup cycle, which + can take significant time. """ def __init__(self, hosts: list[str], port: int) -> None: """Initialize the resolver.""" - super().__init__(daemon=True) self.hosts = hosts self.port = port - self.result: list[hr.AddrInfo] | None = None - self.exception: Exception | None = None - self.event = threading.Event() - async def _resolve(self) -> None: + async def _resolve(self) -> list[hr.AddrInfo]: """Resolve hostnames to IP addresses.""" - try: - self.result = await hr.async_resolve_host( - self.hosts, self.port, timeout=RESOLVE_TIMEOUT - ) - except Exception as e: # pylint: disable=broad-except - # We need to catch all exceptions to ensure the event is set - # Otherwise the thread could hang forever - self.exception = e - finally: - self.event.set() - - def run(self) -> None: - """Run the DNS resolution.""" - asyncio.run(self._resolve()) + return await hr.async_resolve_host( + self.hosts, self.port, timeout=RESOLVE_TIMEOUT + ) def resolve(self) -> list[hr.AddrInfo]: """Start the thread and wait for the result.""" - self.start() + runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve) + runner.start() - if not self.event.wait( + if not runner.event.wait( timeout=RESOLVE_TIMEOUT + 1.0 ): # Give it 1 second more than the resolver timeout raise EsphomeError("Timeout resolving IP address") - if exc := self.exception: + if exc := runner.exception: if isinstance(exc, ResolveTimeoutAPIError): raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc if isinstance(exc, ResolveAPIError): raise EsphomeError(f"Error resolving IP address: {exc}") from exc raise exc - return self.result + assert runner.result is not None # guaranteed when event set and no exception + return runner.result diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index dd45b58a6cb..6f5d33c8085 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -14,8 +14,13 @@ from zeroconf import ( ) from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf +from esphome.async_thread import AsyncThreadRunner from esphome.storage_json import StorageJSON, ext_storage_path +# Length of the MAC suffix appended when name_add_mac_suffix is enabled. +MAC_SUFFIX_LEN = 6 +_HEX_CHARS = frozenset("0123456789abcdef") + _LOGGER = logging.getLogger(__name__) DEFAULT_TIMEOUT = 10.0 @@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf): return None +async def async_resolve_hosts( + zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT +) -> dict[str, list[str]]: + """Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance. + + Tries the cache synchronously first (so hosts already primed by a recent + browse return immediately with no network round-trip), then issues + ``async_request`` for the remaining misses in parallel via + ``asyncio.gather``. Returns a dict mapping each host to its list of + addresses (empty list when unresolved). Only ``.local`` form is + queried, matching the name scheme the resolvers below expect. + """ + resolvers: dict[str, AddressResolver] = {} + pending: list[str] = [] + for host in hosts: + resolver = AddressResolver(f"{host.partition('.')[0]}.local.") + resolvers[host] = resolver + if not resolver.load_from_cache(zeroconf): + pending.append(host) + + if pending and timeout: + results = await asyncio.gather( + *( + resolvers[host].async_request(zeroconf, timeout * 1000) + for host in pending + ), + return_exceptions=True, + ) + for host, result in zip(pending, results): + if isinstance(result, BaseException): + _LOGGER.debug("Failed to resolve %s: %s", host, result) + + return { + host: resolver.parsed_scoped_addresses(IPVersion.All) + for host, resolver in resolvers.items() + } + + class AsyncEsphomeZeroconf(AsyncZeroconf): async def async_resolve_host( self, host: str, timeout: float = DEFAULT_TIMEOUT ) -> list[str] | None: """Resolve a host name to an IP address.""" - info = AddressResolver(f"{host.partition('.')[0]}.local.") - if ( - info.load_from_cache(self.zeroconf) - or (timeout and await info.async_request(self.zeroconf, timeout * 1000)) - ) and (addresses := info.parsed_scoped_addresses(IPVersion.All)): - return addresses - return None + addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host] + return addresses or None + + +def _is_mac_suffix_match(device_name: str, prefix: str) -> bool: + """Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC.""" + if not device_name.startswith(prefix): + return False + suffix = device_name[len(prefix) :] + return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix) + + +async def async_discover_mdns_devices( + base_name: str, timeout: float = 5.0 +) -> dict[str, list[str]]: + """Discover ESPHome devices via mDNS that match the base name + MAC suffix. + + When ``name_add_mac_suffix`` is enabled, devices advertise as + ``-<6-hex-mac>.local``. This function uses a single + ``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and + resolve their IP addresses, so callers get resolved addresses without + opening a second Zeroconf client. + + Args: + base_name: The base device name (without MAC suffix). + timeout: How long to wait for mDNS responses (default 5 seconds). + + Returns: + Mapping of ``.local`` hostnames to their resolved IP addresses + (may be empty for a device if resolution failed within the timeout). + """ + prefix = f"{base_name}-" + # Preserves insertion order for stable output and deduplicates + discovered: dict[str, list[str]] = {} + + def on_service_state_change( + zeroconf: Zeroconf, + service_type: str, + name: str, + state_change: ServiceStateChange, + ) -> None: + if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated): + return + device_name = name.partition(".")[0] + if not _is_mac_suffix_match(device_name, prefix): + _LOGGER.debug( + "Ignoring %s (%s): does not match '%s<6-hex>'", + device_name, + state_change.name, + prefix, + ) + return + host = f"{device_name}.local" + if host in discovered: + return + discovered[host] = [] + _LOGGER.debug("Discovered %s (%s)", host, state_change.name) + + _LOGGER.debug( + "Starting mDNS discovery for '%s.local' (timeout=%.1fs)", + prefix, + timeout, + ) + try: + aiozc = AsyncEsphomeZeroconf() + except Exception as err: # pylint: disable=broad-except + # Zeroconf init can raise OSError, NonUniqueNameException, etc. + # Any failure here just means we can't discover — log and move on. + _LOGGER.warning("mDNS discovery failed to initialize: %s", err) + return {} + + try: + browser = AsyncServiceBrowser( + aiozc.zeroconf, + ESPHOME_SERVICE_TYPE, + handlers=[on_service_state_change], + ) + try: + await asyncio.sleep(timeout) + finally: + await browser.async_cancel() + _LOGGER.debug( + "Browse finished: %d device(s) matched '%s'", + len(discovered), + prefix, + ) + + # Resolve each discovered hostname on the SAME Zeroconf instance so + # we don't spin up a second client. ``async_resolve_hosts`` tries the + # cache synchronously (the browse usually primes it) before issuing + # any ``async_request`` in parallel for misses. + resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered)) + for host, addresses in resolved.items(): + if addresses: + discovered[host] = addresses + _LOGGER.debug("Resolved %s -> %s", host, addresses) + else: + _LOGGER.debug("No addresses returned for %s", host) + finally: + await aiozc.async_close() + + return dict(sorted(discovered.items())) + + +def _await_discovery( + runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float +) -> dict[str, list[str]]: + """Wait for ``runner`` to finish and return its discovery result. + + Split out of :func:`discover_mdns_devices` so the timeout branch is + testable without patching ``asyncio`` or ``threading`` internals — a test + passes a stub whose ``event.wait`` returns ``False``. + """ + # Give the discovery an extra second over the browse timeout for the + # resolution + cleanup pass. + if not runner.event.wait(timeout=timeout + 2.0): + _LOGGER.warning("mDNS discovery timed out after %.1fs", timeout) + return {} + if runner.exception is not None: + _LOGGER.warning("mDNS discovery failed: %s", runner.exception) + return {} + return runner.result or {} + + +def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]: + """Synchronous wrapper around :func:`async_discover_mdns_devices`.""" + runner = AsyncThreadRunner( + lambda: async_discover_mdns_devices(base_name, timeout=timeout) + ) + runner.start() + return _await_discovery(runner, timeout) diff --git a/tests/unit_tests/test_address_cache.py b/tests/unit_tests/test_address_cache.py index de43830d532..1ca28c4f029 100644 --- a/tests/unit_tests/test_address_cache.py +++ b/tests/unit_tests/test_address_cache.py @@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None: assert cache.get_addresses("unknown.com") is None +def test_add_mdns_addresses_stores_and_normalizes() -> None: + """add_mdns_addresses inserts entries under the normalized hostname.""" + cache = AddressCache() + cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"]) + + assert cache.mdns_cache == { + normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"] + } + # Overwrites on subsequent calls for the same host + cache.add_mdns_addresses("device.local", ["10.0.0.1"]) + assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"] + + +def test_add_mdns_addresses_empty_is_noop() -> None: + """Passing an empty address list must not create an entry.""" + cache = AddressCache() + cache.add_mdns_addresses("device.local", []) + assert cache.mdns_cache == {} + + def test_has_cache() -> None: """Test checking if cache has entries.""" # Empty cache diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index e07b4accf23..8ec9e70cf85 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass import json import logging @@ -12,16 +12,18 @@ import re import sys import time from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from pytest import CaptureFixture +from zeroconf import ServiceStateChange from esphome import platformio_api from esphome.__main__ import ( Purpose, _get_configured_xtal_freq, _make_crystal_freq_callback, + _resolve_network_devices, choose_upload_log_host, command_analyze_memory, command_bundle, @@ -36,6 +38,7 @@ from esphome.__main__ import ( has_mqtt, has_mqtt_ip_lookup, has_mqtt_logging, + has_name_add_mac_suffix, has_non_ip_address, has_ota, has_resolvable_address, @@ -48,6 +51,7 @@ from esphome.__main__ import ( upload_using_picotool, upload_using_platformio, ) +from esphome.address_cache import AddressCache from esphome.bundle import BUNDLE_EXTENSION, BundleFile, BundleResult from esphome.components.esp32 import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32 from esphome.const import ( @@ -62,6 +66,7 @@ from esphome.const import ( CONF_MDNS, CONF_MQTT, CONF_NAME, + CONF_NAME_ADD_MAC_SUFFIX, CONF_OTA, CONF_PASSWORD, CONF_PLATFORM, @@ -79,6 +84,7 @@ from esphome.const import ( ) from esphome.core import CORE, EsphomeError from esphome.util import BootselResult +from esphome.zeroconf import _await_discovery, discover_mdns_devices def strip_ansi_codes(text: str) -> str: @@ -2218,6 +2224,509 @@ def test_has_resolvable_address() -> None: assert has_resolvable_address() is False +def test_has_name_add_mac_suffix() -> None: + """Test has_name_add_mac_suffix function.""" + + # Test with name_add_mac_suffix enabled + setup_core(config={CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True}}) + assert has_name_add_mac_suffix() is True + + # Test with name_add_mac_suffix disabled + setup_core(config={CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: False}}) + assert has_name_add_mac_suffix() is False + + # Test with name_add_mac_suffix not set (defaults to False) + setup_core(config={CONF_ESPHOME: {}}) + assert has_name_add_mac_suffix() is False + + # Test with no esphome config + setup_core(config={}) + assert has_name_add_mac_suffix() is False + + # Test with no config at all + CORE.config = None + assert has_name_add_mac_suffix() is False + + +@pytest.fixture +def mock_mdns_discovery() -> Generator[MagicMock]: + """Fixture to mock the async mDNS discovery infrastructure. + + Patches ``AsyncEsphomeZeroconf``, ``AsyncServiceBrowser`` and + ``AddressResolver`` in ``esphome.zeroconf`` and exposes hooks for tests to + stage browser events and control resolution results. The default + ``AddressResolver`` stub simulates a cache hit returning no addresses, so + matched hosts appear in the discovery output with empty address lists + unless the test overrides ``_resolver_setup``. + """ + with ( + patch("esphome.zeroconf.AsyncEsphomeZeroconf") as mock_aiozc_class, + patch("esphome.zeroconf.AsyncServiceBrowser") as mock_browser_class, + patch("esphome.zeroconf.AddressResolver") as mock_resolver_class, + ): + mock_aiozc = MagicMock() + mock_aiozc.zeroconf = MagicMock() + mock_aiozc.async_close = AsyncMock(return_value=None) + mock_aiozc_class.return_value = mock_aiozc + + mock_browser = MagicMock() + mock_browser.async_cancel = AsyncMock(return_value=None) + + # Default: each host gets a fresh resolver that hits the cache and + # returns no addresses. Tests can override via ``_resolver_setup``. + def default_resolver_factory(name: str) -> MagicMock: + resolver = MagicMock() + resolver._name = name + resolver.load_from_cache.return_value = True + resolver.async_request = AsyncMock(return_value=True) + resolver.parsed_scoped_addresses.return_value = [] + return resolver + + mock_resolver_class.side_effect = default_resolver_factory + + # Store references for test access + mock_aiozc._mock_browser_class = mock_browser_class + mock_aiozc._mock_browser = mock_browser + mock_aiozc._mock_class = mock_aiozc_class + mock_aiozc._mock_resolver_class = mock_resolver_class + yield mock_aiozc + + +@pytest.mark.parametrize( + ("discovered_services", "base_name", "expected_hosts"), + [ + # Matching devices; different-prefix device is filtered out + ( + [ + ("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Added), + ("mydevice-def456._esphomelib._tcp.local.", ServiceStateChange.Added), + ( + "otherdevice-abcdef._esphomelib._tcp.local.", + ServiceStateChange.Added, + ), + ], + "mydevice", + ["mydevice-abc123.local", "mydevice-def456.local"], + ), + # No matches at all + ( + [ + ( + "otherdevice-abcdef._esphomelib._tcp.local.", + ServiceStateChange.Added, + ), + ], + "mydevice", + [], + ), + # Deduplication (same device Added then Updated) + ( + [ + ("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Added), + ("mydevice-abc123._esphomelib._tcp.local.", ServiceStateChange.Updated), + ], + "mydevice", + ["mydevice-abc123.local"], + ), + # Suffix must be exactly 6 hex chars: wrong length and non-hex are rejected + ( + [ + # too short + ("mydevice-abcd._esphomelib._tcp.local.", ServiceStateChange.Added), + # too long + ( + "mydevice-abcdef1._esphomelib._tcp.local.", + ServiceStateChange.Added, + ), + # non-hex + ("mydevice-xyz123._esphomelib._tcp.local.", ServiceStateChange.Added), + # valid + ("mydevice-012345._esphomelib._tcp.local.", ServiceStateChange.Added), + ], + "mydevice", + ["mydevice-012345.local"], + ), + # Prefix-collision: base "foo" must not match "foo-bar-abc123" + ( + [ + ("foo-abcdef._esphomelib._tcp.local.", ServiceStateChange.Added), + ("foo-bar-abcdef._esphomelib._tcp.local.", ServiceStateChange.Added), + ], + "foo", + ["foo-abcdef.local"], + ), + ], + ids=[ + "matching_with_filter", + "no_matches", + "deduplication", + "hex_suffix_filter", + "prefix_collision", + ], +) +def test_discover_mdns_devices( + mock_mdns_discovery: MagicMock, + discovered_services: list[tuple[str, ServiceStateChange]], + base_name: str, + expected_hosts: list[str], +) -> None: + """Test discover_mdns_devices filtering and deduplication.""" + mock_browser = mock_mdns_discovery._mock_browser + + def capture_callback( + zc: MagicMock, + service_type: str, + handlers: list[Callable[..., None]], + ) -> MagicMock: + callback = handlers[0] + for service_name, state_change in discovered_services: + callback( + mock_mdns_discovery.zeroconf, service_type, service_name, state_change + ) + return mock_browser + + mock_mdns_discovery._mock_browser_class.side_effect = capture_callback + + # Each discovered host gets a resolver that returns a unique IP string + # derived from its server name so we can assert per-host. + def resolver_factory(name: str) -> MagicMock: + resolver = MagicMock() + resolver._name = name + resolver.load_from_cache.return_value = True + resolver.async_request = AsyncMock(return_value=True) + resolver.parsed_scoped_addresses.return_value = [f"10.0.0.1#{name}"] + return resolver + + mock_mdns_discovery._mock_resolver_class.side_effect = resolver_factory + + result = discover_mdns_devices(base_name, timeout=0) + + assert sorted(result) == expected_hosts + # Resolved addresses should be stored for matched hosts. AddressResolver + # receives the fully-qualified name (``.local.``). + for host in expected_hosts: + short = host.partition(".")[0] + assert result[host] == [f"10.0.0.1#{short}.local."] + mock_browser.async_cancel.assert_awaited_once() + mock_mdns_discovery.async_close.assert_awaited_once() + + +def test_discover_mdns_devices_init_failure(caplog: pytest.LogCaptureFixture) -> None: + """If AsyncEsphomeZeroconf fails to init, return empty dict and log warning.""" + with ( + patch( + "esphome.zeroconf.AsyncEsphomeZeroconf", + side_effect=OSError("no network"), + ), + caplog.at_level(logging.WARNING, logger="esphome.zeroconf"), + ): + result = discover_mdns_devices("mydevice", timeout=0) + + assert result == {} + assert "mDNS discovery failed to initialize" in caplog.text + + +def test_discover_mdns_devices_resolution_failure( + mock_mdns_discovery: MagicMock, +) -> None: + """If resolution raises, the host is still listed with an empty address list.""" + mock_browser = mock_mdns_discovery._mock_browser + + def capture_callback( + zc: MagicMock, + service_type: str, + handlers: list[Callable[..., None]], + ) -> MagicMock: + handlers[0]( + mock_mdns_discovery.zeroconf, + service_type, + "mydevice-abc123._esphomelib._tcp.local.", + ServiceStateChange.Added, + ) + return mock_browser + + mock_mdns_discovery._mock_browser_class.side_effect = capture_callback + + # Resolver misses the cache, then async_request raises. + def failing_resolver_factory(name: str) -> MagicMock: + resolver = MagicMock() + resolver.load_from_cache.return_value = False + resolver.async_request = AsyncMock(side_effect=OSError("boom")) + resolver.parsed_scoped_addresses.return_value = [] + return resolver + + mock_mdns_discovery._mock_resolver_class.side_effect = failing_resolver_factory + + result = discover_mdns_devices("mydevice", timeout=0) + + assert result == {"mydevice-abc123.local": []} + + +def test_discover_mdns_devices_ignores_removed_state( + mock_mdns_discovery: MagicMock, +) -> None: + """``Removed`` state changes are ignored and do not appear in the result.""" + mock_browser = mock_mdns_discovery._mock_browser + + def capture_callback( + zc: MagicMock, + service_type: str, + handlers: list[Callable[..., None]], + ) -> MagicMock: + handlers[0]( + mock_mdns_discovery.zeroconf, + service_type, + "mydevice-abc123._esphomelib._tcp.local.", + ServiceStateChange.Removed, + ) + return mock_browser + + mock_mdns_discovery._mock_browser_class.side_effect = capture_callback + + result = discover_mdns_devices("mydevice", timeout=0) + + assert result == {} + # No AddressResolver should have been constructed since no host matched. + mock_mdns_discovery._mock_resolver_class.assert_not_called() + + +def test_discover_mdns_devices_empty_resolution( + mock_mdns_discovery: MagicMock, +) -> None: + """Host is listed with empty addresses when resolver returns no addresses.""" + mock_browser = mock_mdns_discovery._mock_browser + + def capture_callback( + zc: MagicMock, + service_type: str, + handlers: list[Callable[..., None]], + ) -> MagicMock: + handlers[0]( + mock_mdns_discovery.zeroconf, + service_type, + "mydevice-abc123._esphomelib._tcp.local.", + ServiceStateChange.Added, + ) + return mock_browser + + mock_mdns_discovery._mock_browser_class.side_effect = capture_callback + # Default fixture resolver is a cache-hit with no addresses — simulates + # the "browse found it but no A/AAAA records are available" case. + + result = discover_mdns_devices("mydevice", timeout=0) + + assert result == {"mydevice-abc123.local": []} + + +def test_resolve_network_devices_expands_cached_mdns_hosts(tmp_path: Path) -> None: + """Hostnames in ``CORE.address_cache`` are expanded to their cached IPs.""" + setup_core(tmp_path=tmp_path) + CORE.address_cache = AddressCache( + mdns_cache={ + "device-abc123.local": ["10.0.0.1", "10.0.0.2"], + } + ) + + result = _resolve_network_devices( + ["device-abc123.local", "192.168.1.50", "device-abc123.local"], + CORE.config, + MockArgs(), + ) + + # Cached hostname is replaced with its IPs (deduplicated across repeats) + # and the literal IP is preserved after. + assert result == ["10.0.0.1", "10.0.0.2", "192.168.1.50"] + + +def test_resolve_network_devices_keeps_uncached_hosts(tmp_path: Path) -> None: + """Hostnames not in the cache pass through unchanged.""" + setup_core(tmp_path=tmp_path) + CORE.address_cache = AddressCache() + + result = _resolve_network_devices( + ["unknown.local", "192.168.1.50"], + CORE.config, + MockArgs(), + ) + + assert result == ["unknown.local", "192.168.1.50"] + + +def test_await_discovery_timeout_returns_empty( + caplog: pytest.LogCaptureFixture, +) -> None: + """If the discovery runner never sets its event, return {} and warn.""" + stub = MagicMock() + stub.event.wait.return_value = False + stub.exception = None + stub.result = {"should_not_be_read": ["1.2.3.4"]} + + with caplog.at_level(logging.WARNING, logger="esphome.zeroconf"): + result = _await_discovery(stub, timeout=0.01) + + assert result == {} + assert "mDNS discovery timed out after 0.0s" in caplog.text + stub.event.wait.assert_called_once_with(timeout=pytest.approx(2.01)) + + +def test_await_discovery_propagates_exception_as_empty( + caplog: pytest.LogCaptureFixture, +) -> None: + """If the coroutine raised, log and return {} rather than re-raise.""" + stub = MagicMock() + stub.event.wait.return_value = True + stub.exception = RuntimeError("boom") + stub.result = None + + with caplog.at_level(logging.WARNING, logger="esphome.zeroconf"): + result = _await_discovery(stub, timeout=5.0) + + assert result == {} + assert "mDNS discovery failed: boom" in caplog.text + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_discovers_mac_suffix_devices(tmp_path: Path) -> None: + """Interactive mode discovers MAC-suffixed devices and populates the cache.""" + setup_core( + config={ + CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True}, + CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}], + }, + address="mydevice.local", + tmp_path=tmp_path, + name="mydevice", + ) + CORE.address_cache = None + + discovered = { + "mydevice-abc123.local": ["10.0.0.1"], + "mydevice-def456.local": ["10.0.0.2"], + } + with ( + patch( + "esphome.__main__.discover_mdns_devices", return_value=discovered + ) as mock_discover, + patch( + "esphome.__main__.choose_prompt", return_value="mydevice-abc123.local" + ) as mock_prompt, + ): + result = choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.UPLOADING, + ) + + assert result == ["mydevice-abc123.local"] + mock_discover.assert_called_once_with("mydevice") + mock_prompt.assert_called_once_with( + [ + ("Over The Air (mydevice-abc123.local)", "mydevice-abc123.local"), + ("Over The Air (mydevice-def456.local)", "mydevice-def456.local"), + ], + purpose=Purpose.UPLOADING, + ) + # Resolved IPs should be cached so downstream resolution skips a second + # Zeroconf lookup. + assert CORE.address_cache is not None + assert CORE.address_cache.get_mdns_addresses("mydevice-abc123.local") == [ + "10.0.0.1" + ] + assert CORE.address_cache.get_mdns_addresses("mydevice-def456.local") == [ + "10.0.0.2" + ] + + +@pytest.mark.usefixtures("mock_no_serial_ports") +def test_choose_upload_log_host_mac_suffix_no_devices_found( + tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """When discovery finds nothing, no OTA option is offered and a warning logs.""" + setup_core( + config={ + CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True}, + CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}], + }, + address="mydevice.local", + tmp_path=tmp_path, + name="mydevice", + ) + + with ( + patch("esphome.__main__.discover_mdns_devices", return_value={}), + caplog.at_level(logging.WARNING, logger="esphome.__main__"), + pytest.raises(EsphomeError), + ): + choose_upload_log_host( + default=None, + check_default=None, + purpose=Purpose.UPLOADING, + ) + + assert "No devices matching 'mydevice-.local'" in caplog.text + + +def test_choose_upload_log_host_default_ota_discovers_mac_suffix( + tmp_path: Path, +) -> None: + """``--device OTA`` also runs mDNS discovery when name_add_mac_suffix is on.""" + setup_core( + config={ + CONF_ESPHOME: {CONF_NAME_ADD_MAC_SUFFIX: True}, + CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}], + }, + address="mydevice.local", + tmp_path=tmp_path, + name="mydevice", + ) + CORE.address_cache = None + + discovered = { + "mydevice-abc123.local": ["10.0.0.1"], + "mydevice-def456.local": ["10.0.0.2"], + } + with patch( + "esphome.__main__.discover_mdns_devices", return_value=discovered + ) as mock_discover: + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + + # Both discovered hostnames are returned so aioesphomeapi / espota2 can + # try each in turn with the cached IPs. + assert result == ["mydevice-abc123.local", "mydevice-def456.local"] + mock_discover.assert_called_once_with("mydevice") + assert CORE.address_cache is not None + assert CORE.address_cache.get_mdns_addresses("mydevice-abc123.local") == [ + "10.0.0.1" + ] + + +def test_choose_upload_log_host_default_ota_no_suffix_discovery( + tmp_path: Path, +) -> None: + """``--device OTA`` without name_add_mac_suffix uses CORE.address as-is.""" + setup_core( + config={CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}]}, + address="192.168.1.100", + tmp_path=tmp_path, + name="mydevice", + ) + + with patch("esphome.__main__.discover_mdns_devices") as mock_discover: + result = choose_upload_log_host( + default="OTA", + check_default=None, + purpose=Purpose.UPLOADING, + ) + + assert result == ["192.168.1.100"] + # Discovery must NOT run when name_add_mac_suffix is disabled. + mock_discover.assert_not_called() + + def test_command_wizard(tmp_path: Path) -> None: """Test command_wizard function.""" config_file = tmp_path / "test.yaml" diff --git a/tests/unit_tests/test_resolver.py b/tests/unit_tests/test_resolver.py index b4cca05d9fd..7862c268ca1 100644 --- a/tests/unit_tests/test_resolver.py +++ b/tests/unit_tests/test_resolver.py @@ -4,7 +4,7 @@ from __future__ import annotations import re import socket -from unittest.mock import patch +from unittest.mock import MagicMock, patch from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr @@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None: def test_async_resolver_thread_timeout() -> None: - """Test timeout when thread doesn't complete in time.""" - # Mock the start method to prevent actual thread execution - with ( - patch.object(AsyncResolver, "start"), - patch("esphome.resolver.hr.async_resolve_host"), - ): - resolver = AsyncResolver(["test.local"], 6053) - # Override event.wait to simulate timeout (return False = timeout occurred) - with ( - patch.object(resolver.event, "wait", return_value=False), - pytest.raises( - EsphomeError, match=re.escape("Timeout resolving IP address") - ), - ): - resolver.resolve() + """Test timeout when the runner thread doesn't complete in time.""" + # Patch AsyncThreadRunner inside esphome.resolver so we never actually + # start a thread and can control the wait return value directly. + fake_runner = MagicMock() + fake_runner.start = MagicMock() + fake_runner.event.wait.return_value = False # simulate timeout - # Verify thread start was called - resolver.start.assert_called_once() + with ( + patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner), + patch("esphome.resolver.hr.async_resolve_host"), + pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")), + ): + AsyncResolver(["test.local"], 6053).resolve() + + fake_runner.start.assert_called_once() def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: