mirror of
https://github.com/esphome/esphome.git
synced 2026-05-20 09:31:56 +08:00
[core] Allow finding all devices as target that match mac suffix (#13135)
This commit is contained in:
+108
-12
@@ -39,6 +39,7 @@ from esphome.const import (
|
||||
CONF_MDNS,
|
||||
CONF_MQTT,
|
||||
CONF_NAME,
|
||||
CONF_NAME_ADD_MAC_SUFFIX,
|
||||
CONF_OTA,
|
||||
CONF_PASSWORD,
|
||||
CONF_PLATFORM,
|
||||
@@ -71,6 +72,7 @@ from esphome.util import (
|
||||
run_external_process,
|
||||
safe_print,
|
||||
)
|
||||
from esphome.zeroconf import discover_mdns_devices
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
|
||||
return [address]
|
||||
|
||||
|
||||
def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None:
|
||||
"""Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``.
|
||||
|
||||
Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so
|
||||
the downstream resolution path (``resolve_ip_address``) can skip opening a
|
||||
second Zeroconf client.
|
||||
"""
|
||||
from esphome.address_cache import AddressCache
|
||||
|
||||
if CORE.address_cache is None:
|
||||
CORE.address_cache = AddressCache()
|
||||
for host, addresses in hosts_to_addresses.items():
|
||||
if addresses:
|
||||
_LOGGER.debug("Caching mDNS result %s -> %s", host, addresses)
|
||||
CORE.address_cache.add_mdns_addresses(host, addresses)
|
||||
|
||||
|
||||
def _discover_mac_suffix_devices() -> list[str] | None:
|
||||
"""Discover ``<name>-<mac>.local`` devices and cache their IPs.
|
||||
|
||||
Returns:
|
||||
- ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off,
|
||||
mDNS disabled, or ``CORE.address`` is already an IP). Callers should
|
||||
then fall back to whatever default OTA address they normally use.
|
||||
- ``[]`` when discovery ran but found nothing. Callers should NOT fall
|
||||
back to the base name: with ``name_add_mac_suffix`` enabled, the base
|
||||
name by definition doesn't exist on the network.
|
||||
- A non-empty sorted list of ``.local`` hostnames on success.
|
||||
|
||||
Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or
|
||||
``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we
|
||||
already have without opening a second Zeroconf client.
|
||||
"""
|
||||
if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()):
|
||||
return None
|
||||
_LOGGER.info("Discovering devices...")
|
||||
if not (discovered := discover_mdns_devices(CORE.name)):
|
||||
_LOGGER.warning(
|
||||
"No devices matching '%s-<mac>.local' were discovered.", CORE.name
|
||||
)
|
||||
return []
|
||||
_populate_mdns_cache(discovered)
|
||||
return list(discovered)
|
||||
|
||||
|
||||
def _ota_hostnames_for_default(purpose: Purpose) -> list[str]:
|
||||
"""Return OTA hostname(s) for the ``--device OTA`` / default-resolve path.
|
||||
|
||||
When ``name_add_mac_suffix`` is enabled, returns discovered
|
||||
``<name>-<mac>.local`` hostnames (possibly empty — in which case the
|
||||
caller should not fall back to the base name). Otherwise falls back to
|
||||
the cache-resolved ``CORE.address``.
|
||||
"""
|
||||
if (discovered := _discover_mac_suffix_devices()) is not None:
|
||||
return discovered
|
||||
return _resolve_with_cache(CORE.address, purpose)
|
||||
|
||||
|
||||
def choose_upload_log_host(
|
||||
default: list[str] | str | None,
|
||||
check_default: str | None,
|
||||
@@ -242,14 +302,14 @@ def choose_upload_log_host(
|
||||
resolved.append("MQTT")
|
||||
|
||||
if has_api() and has_non_ip_address() and has_resolvable_address():
|
||||
resolved.extend(_resolve_with_cache(CORE.address, purpose))
|
||||
resolved.extend(_ota_hostnames_for_default(purpose))
|
||||
|
||||
elif purpose == Purpose.UPLOADING:
|
||||
if has_ota() and has_mqtt_ip_lookup():
|
||||
resolved.append("MQTTIP")
|
||||
|
||||
if has_ota() and has_non_ip_address() and has_resolvable_address():
|
||||
resolved.extend(_resolve_with_cache(CORE.address, purpose))
|
||||
resolved.extend(_ota_hostnames_for_default(purpose))
|
||||
else:
|
||||
resolved.append(device)
|
||||
if not resolved:
|
||||
@@ -281,22 +341,29 @@ def choose_upload_log_host(
|
||||
elif bootsel.permission_error:
|
||||
bootsel_permission_error = True
|
||||
|
||||
def add_ota_options() -> None:
|
||||
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
|
||||
if (discovered := _discover_mac_suffix_devices()) is not None:
|
||||
# Discovery was applicable. Use whatever we found — on empty,
|
||||
# intentionally skip the base-name fallback since with
|
||||
# name_add_mac_suffix on, the base name doesn't exist on the net.
|
||||
for host in discovered:
|
||||
options.append((f"Over The Air ({host})", host))
|
||||
elif has_resolvable_address():
|
||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||
if has_mqtt_ip_lookup():
|
||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||
|
||||
if purpose == Purpose.LOGGING:
|
||||
if has_mqtt_logging():
|
||||
mqtt_config = CORE.config[CONF_MQTT]
|
||||
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
||||
|
||||
if has_api():
|
||||
if has_resolvable_address():
|
||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||
if has_mqtt_ip_lookup():
|
||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||
add_ota_options()
|
||||
|
||||
elif purpose == Purpose.UPLOADING and has_ota():
|
||||
if has_resolvable_address():
|
||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||
if has_mqtt_ip_lookup():
|
||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||
add_ota_options()
|
||||
|
||||
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
|
||||
if (
|
||||
@@ -407,7 +474,17 @@ def has_resolvable_address() -> bool:
|
||||
return not CORE.address.endswith(".local")
|
||||
|
||||
|
||||
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
|
||||
def has_name_add_mac_suffix() -> bool:
|
||||
"""Check if name_add_mac_suffix is enabled in the config."""
|
||||
if CORE.config is None:
|
||||
return False
|
||||
esphome_config = CORE.config.get(CONF_ESPHOME, {})
|
||||
return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False)
|
||||
|
||||
|
||||
def mqtt_get_ip(
|
||||
config: ConfigType, username: str, password: str, client_id: str
|
||||
) -> list[str]:
|
||||
from esphome import mqtt
|
||||
|
||||
return mqtt.get_esphome_device_ip(config, username, password, client_id)
|
||||
@@ -420,6 +497,9 @@ def _resolve_network_devices(
|
||||
|
||||
This function filters the devices list to:
|
||||
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
|
||||
- Expand hostnames that are already in ``CORE.address_cache`` to their
|
||||
cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second
|
||||
Zeroconf client to resolve them
|
||||
- Deduplicate addresses while preserving order
|
||||
- Only resolve MQTT once even if multiple MQTT strings are present
|
||||
- If MQTT resolution fails, log a warning and continue with other devices
|
||||
@@ -444,13 +524,29 @@ def _resolve_network_devices(
|
||||
mqtt_ips = mqtt_get_ip(
|
||||
config, args.username, args.password, args.client_id
|
||||
)
|
||||
network_devices.extend(mqtt_ips)
|
||||
# pylint can't infer mqtt_get_ip's return through its
|
||||
# lazy ``from esphome import mqtt`` import, so it flags
|
||||
# the genexpr below.
|
||||
network_devices.extend(
|
||||
addr
|
||||
for addr in mqtt_ips # pylint: disable=not-an-iterable
|
||||
if addr not in network_devices
|
||||
)
|
||||
except EsphomeError as err:
|
||||
_LOGGER.warning(
|
||||
"MQTT IP discovery failed (%s), will try other devices if available",
|
||||
err,
|
||||
)
|
||||
mqtt_resolved = True
|
||||
continue
|
||||
|
||||
# If the hostname is already in the address cache (e.g. populated by
|
||||
# mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't
|
||||
# open its own Zeroconf to re-resolve it.
|
||||
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)):
|
||||
network_devices.extend(
|
||||
addr for addr in cached if addr not in network_devices
|
||||
)
|
||||
elif device not in network_devices:
|
||||
# Regular network address or IP - add if not already present
|
||||
network_devices.append(device)
|
||||
|
||||
@@ -101,6 +101,17 @@ class AddressCache:
|
||||
"""Check if any cache entries exist."""
|
||||
return bool(self.mdns_cache or self.dns_cache)
|
||||
|
||||
def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None:
|
||||
"""Store resolved mDNS addresses for ``hostname`` in the cache.
|
||||
|
||||
Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use
|
||||
this to avoid a second resolution round-trip during the upload path.
|
||||
No-op when ``addresses`` is empty.
|
||||
"""
|
||||
if not addresses:
|
||||
return
|
||||
self.mdns_cache[normalize_hostname(hostname)] = addresses
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(
|
||||
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Helpers for running an async coroutine from sync code via a daemon thread.
|
||||
|
||||
``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup
|
||||
cycle finishes, which can add hundreds of milliseconds before the caller
|
||||
receives the result. Running the loop in a daemon thread lets the caller
|
||||
observe the result as soon as the coroutine completes while cleanup finishes
|
||||
in the background.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
import threading
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class AsyncThreadRunner(threading.Thread, Generic[_T]):
|
||||
"""Run an async coroutine in a daemon thread and expose its result.
|
||||
|
||||
The runner catches all exceptions from the coroutine and stores them in
|
||||
``exception`` so ``event`` is always set — this prevents callers waiting
|
||||
on ``event`` from hanging forever when the coroutine crashes.
|
||||
|
||||
Typical usage::
|
||||
|
||||
runner = AsyncThreadRunner(lambda: my_coro(arg))
|
||||
runner.start()
|
||||
if not runner.event.wait(timeout=5.0):
|
||||
... # timed out
|
||||
if runner.exception is not None:
|
||||
raise runner.exception
|
||||
result = runner.result
|
||||
"""
|
||||
|
||||
def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None:
|
||||
super().__init__(daemon=True)
|
||||
self._coro_factory = coro_factory
|
||||
self.result: _T | None = None
|
||||
self.exception: BaseException | None = None
|
||||
self.event = threading.Event()
|
||||
|
||||
async def _runner(self) -> None:
|
||||
try:
|
||||
self.result = await self._coro_factory()
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
# Capture all exceptions so ``event`` is always set — otherwise a
|
||||
# crash would hang the waiter forever.
|
||||
self.exception = exc
|
||||
finally:
|
||||
self.event.set()
|
||||
|
||||
def run(self) -> None:
|
||||
asyncio.run(self._runner())
|
||||
+17
-31
@@ -2,66 +2,52 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
||||
from esphome.async_thread import AsyncThreadRunner
|
||||
from esphome.core import EsphomeError
|
||||
|
||||
RESOLVE_TIMEOUT = 10.0 # seconds
|
||||
|
||||
|
||||
class AsyncResolver(threading.Thread):
|
||||
class AsyncResolver:
|
||||
"""Resolver using aioesphomeapi that runs in a thread for faster results.
|
||||
|
||||
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
|
||||
including proper .local domain fallback. Running in a thread allows us to get
|
||||
the result immediately without waiting for asyncio.run() to complete its
|
||||
cleanup cycle, which can take significant time.
|
||||
This resolver uses aioesphomeapi's async_resolve_host to handle DNS
|
||||
resolution, including proper .local domain fallback. Running in a thread
|
||||
(via :class:`AsyncThreadRunner`) allows us to get the result immediately
|
||||
without waiting for ``asyncio.run()`` to complete its cleanup cycle, which
|
||||
can take significant time.
|
||||
"""
|
||||
|
||||
def __init__(self, hosts: list[str], port: int) -> None:
|
||||
"""Initialize the resolver."""
|
||||
super().__init__(daemon=True)
|
||||
self.hosts = hosts
|
||||
self.port = port
|
||||
self.result: list[hr.AddrInfo] | None = None
|
||||
self.exception: Exception | None = None
|
||||
self.event = threading.Event()
|
||||
|
||||
async def _resolve(self) -> None:
|
||||
async def _resolve(self) -> list[hr.AddrInfo]:
|
||||
"""Resolve hostnames to IP addresses."""
|
||||
try:
|
||||
self.result = await hr.async_resolve_host(
|
||||
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
|
||||
)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# We need to catch all exceptions to ensure the event is set
|
||||
# Otherwise the thread could hang forever
|
||||
self.exception = e
|
||||
finally:
|
||||
self.event.set()
|
||||
|
||||
def run(self) -> None:
|
||||
"""Run the DNS resolution."""
|
||||
asyncio.run(self._resolve())
|
||||
return await hr.async_resolve_host(
|
||||
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
|
||||
)
|
||||
|
||||
def resolve(self) -> list[hr.AddrInfo]:
|
||||
"""Start the thread and wait for the result."""
|
||||
self.start()
|
||||
runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve)
|
||||
runner.start()
|
||||
|
||||
if not self.event.wait(
|
||||
if not runner.event.wait(
|
||||
timeout=RESOLVE_TIMEOUT + 1.0
|
||||
): # Give it 1 second more than the resolver timeout
|
||||
raise EsphomeError("Timeout resolving IP address")
|
||||
|
||||
if exc := self.exception:
|
||||
if exc := runner.exception:
|
||||
if isinstance(exc, ResolveTimeoutAPIError):
|
||||
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
|
||||
if isinstance(exc, ResolveAPIError):
|
||||
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
|
||||
raise exc
|
||||
|
||||
return self.result
|
||||
assert runner.result is not None # guaranteed when event set and no exception
|
||||
return runner.result
|
||||
|
||||
+174
-7
@@ -14,8 +14,13 @@ from zeroconf import (
|
||||
)
|
||||
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
from esphome.async_thread import AsyncThreadRunner
|
||||
from esphome.storage_json import StorageJSON, ext_storage_path
|
||||
|
||||
# Length of the MAC suffix appended when name_add_mac_suffix is enabled.
|
||||
MAC_SUFFIX_LEN = 6
|
||||
_HEX_CHARS = frozenset("0123456789abcdef")
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT = 10.0
|
||||
@@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf):
|
||||
return None
|
||||
|
||||
|
||||
async def async_resolve_hosts(
|
||||
zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
|
||||
) -> dict[str, list[str]]:
|
||||
"""Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance.
|
||||
|
||||
Tries the cache synchronously first (so hosts already primed by a recent
|
||||
browse return immediately with no network round-trip), then issues
|
||||
``async_request`` for the remaining misses in parallel via
|
||||
``asyncio.gather``. Returns a dict mapping each host to its list of
|
||||
addresses (empty list when unresolved). Only ``<short>.local`` form is
|
||||
queried, matching the name scheme the resolvers below expect.
|
||||
"""
|
||||
resolvers: dict[str, AddressResolver] = {}
|
||||
pending: list[str] = []
|
||||
for host in hosts:
|
||||
resolver = AddressResolver(f"{host.partition('.')[0]}.local.")
|
||||
resolvers[host] = resolver
|
||||
if not resolver.load_from_cache(zeroconf):
|
||||
pending.append(host)
|
||||
|
||||
if pending and timeout:
|
||||
results = await asyncio.gather(
|
||||
*(
|
||||
resolvers[host].async_request(zeroconf, timeout * 1000)
|
||||
for host in pending
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for host, result in zip(pending, results):
|
||||
if isinstance(result, BaseException):
|
||||
_LOGGER.debug("Failed to resolve %s: %s", host, result)
|
||||
|
||||
return {
|
||||
host: resolver.parsed_scoped_addresses(IPVersion.All)
|
||||
for host, resolver in resolvers.items()
|
||||
}
|
||||
|
||||
|
||||
class AsyncEsphomeZeroconf(AsyncZeroconf):
|
||||
async def async_resolve_host(
|
||||
self, host: str, timeout: float = DEFAULT_TIMEOUT
|
||||
) -> list[str] | None:
|
||||
"""Resolve a host name to an IP address."""
|
||||
info = AddressResolver(f"{host.partition('.')[0]}.local.")
|
||||
if (
|
||||
info.load_from_cache(self.zeroconf)
|
||||
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
|
||||
) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
|
||||
return addresses
|
||||
return None
|
||||
addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host]
|
||||
return addresses or None
|
||||
|
||||
|
||||
def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
|
||||
"""Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC."""
|
||||
if not device_name.startswith(prefix):
|
||||
return False
|
||||
suffix = device_name[len(prefix) :]
|
||||
return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix)
|
||||
|
||||
|
||||
async def async_discover_mdns_devices(
|
||||
base_name: str, timeout: float = 5.0
|
||||
) -> dict[str, list[str]]:
|
||||
"""Discover ESPHome devices via mDNS that match the base name + MAC suffix.
|
||||
|
||||
When ``name_add_mac_suffix`` is enabled, devices advertise as
|
||||
``<base_name>-<6-hex-mac>.local``. This function uses a single
|
||||
``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and
|
||||
resolve their IP addresses, so callers get resolved addresses without
|
||||
opening a second Zeroconf client.
|
||||
|
||||
Args:
|
||||
base_name: The base device name (without MAC suffix).
|
||||
timeout: How long to wait for mDNS responses (default 5 seconds).
|
||||
|
||||
Returns:
|
||||
Mapping of ``<device>.local`` hostnames to their resolved IP addresses
|
||||
(may be empty for a device if resolution failed within the timeout).
|
||||
"""
|
||||
prefix = f"{base_name}-"
|
||||
# Preserves insertion order for stable output and deduplicates
|
||||
discovered: dict[str, list[str]] = {}
|
||||
|
||||
def on_service_state_change(
|
||||
zeroconf: Zeroconf,
|
||||
service_type: str,
|
||||
name: str,
|
||||
state_change: ServiceStateChange,
|
||||
) -> None:
|
||||
if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated):
|
||||
return
|
||||
device_name = name.partition(".")[0]
|
||||
if not _is_mac_suffix_match(device_name, prefix):
|
||||
_LOGGER.debug(
|
||||
"Ignoring %s (%s): does not match '%s<6-hex>'",
|
||||
device_name,
|
||||
state_change.name,
|
||||
prefix,
|
||||
)
|
||||
return
|
||||
host = f"{device_name}.local"
|
||||
if host in discovered:
|
||||
return
|
||||
discovered[host] = []
|
||||
_LOGGER.debug("Discovered %s (%s)", host, state_change.name)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
|
||||
prefix,
|
||||
timeout,
|
||||
)
|
||||
try:
|
||||
aiozc = AsyncEsphomeZeroconf()
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
# Zeroconf init can raise OSError, NonUniqueNameException, etc.
|
||||
# Any failure here just means we can't discover — log and move on.
|
||||
_LOGGER.warning("mDNS discovery failed to initialize: %s", err)
|
||||
return {}
|
||||
|
||||
try:
|
||||
browser = AsyncServiceBrowser(
|
||||
aiozc.zeroconf,
|
||||
ESPHOME_SERVICE_TYPE,
|
||||
handlers=[on_service_state_change],
|
||||
)
|
||||
try:
|
||||
await asyncio.sleep(timeout)
|
||||
finally:
|
||||
await browser.async_cancel()
|
||||
_LOGGER.debug(
|
||||
"Browse finished: %d device(s) matched '%s<mac>'",
|
||||
len(discovered),
|
||||
prefix,
|
||||
)
|
||||
|
||||
# Resolve each discovered hostname on the SAME Zeroconf instance so
|
||||
# we don't spin up a second client. ``async_resolve_hosts`` tries the
|
||||
# cache synchronously (the browse usually primes it) before issuing
|
||||
# any ``async_request`` in parallel for misses.
|
||||
resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered))
|
||||
for host, addresses in resolved.items():
|
||||
if addresses:
|
||||
discovered[host] = addresses
|
||||
_LOGGER.debug("Resolved %s -> %s", host, addresses)
|
||||
else:
|
||||
_LOGGER.debug("No addresses returned for %s", host)
|
||||
finally:
|
||||
await aiozc.async_close()
|
||||
|
||||
return dict(sorted(discovered.items()))
|
||||
|
||||
|
||||
def _await_discovery(
|
||||
runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
|
||||
) -> dict[str, list[str]]:
|
||||
"""Wait for ``runner`` to finish and return its discovery result.
|
||||
|
||||
Split out of :func:`discover_mdns_devices` so the timeout branch is
|
||||
testable without patching ``asyncio`` or ``threading`` internals — a test
|
||||
passes a stub whose ``event.wait`` returns ``False``.
|
||||
"""
|
||||
# Give the discovery an extra second over the browse timeout for the
|
||||
# resolution + cleanup pass.
|
||||
if not runner.event.wait(timeout=timeout + 2.0):
|
||||
_LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
|
||||
return {}
|
||||
if runner.exception is not None:
|
||||
_LOGGER.warning("mDNS discovery failed: %s", runner.exception)
|
||||
return {}
|
||||
return runner.result or {}
|
||||
|
||||
|
||||
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
|
||||
"""Synchronous wrapper around :func:`async_discover_mdns_devices`."""
|
||||
runner = AsyncThreadRunner(
|
||||
lambda: async_discover_mdns_devices(base_name, timeout=timeout)
|
||||
)
|
||||
runner.start()
|
||||
return _await_discovery(runner, timeout)
|
||||
|
||||
@@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None:
|
||||
assert cache.get_addresses("unknown.com") is None
|
||||
|
||||
|
||||
def test_add_mdns_addresses_stores_and_normalizes() -> None:
|
||||
"""add_mdns_addresses inserts entries under the normalized hostname."""
|
||||
cache = AddressCache()
|
||||
cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"])
|
||||
|
||||
assert cache.mdns_cache == {
|
||||
normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"]
|
||||
}
|
||||
# Overwrites on subsequent calls for the same host
|
||||
cache.add_mdns_addresses("device.local", ["10.0.0.1"])
|
||||
assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"]
|
||||
|
||||
|
||||
def test_add_mdns_addresses_empty_is_noop() -> None:
|
||||
"""Passing an empty address list must not create an entry."""
|
||||
cache = AddressCache()
|
||||
cache.add_mdns_addresses("device.local", [])
|
||||
assert cache.mdns_cache == {}
|
||||
|
||||
|
||||
def test_has_cache() -> None:
|
||||
"""Test checking if cache has entries."""
|
||||
# Empty cache
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
|
||||
@@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None:
|
||||
|
||||
|
||||
def test_async_resolver_thread_timeout() -> None:
|
||||
"""Test timeout when thread doesn't complete in time."""
|
||||
# Mock the start method to prevent actual thread execution
|
||||
with (
|
||||
patch.object(AsyncResolver, "start"),
|
||||
patch("esphome.resolver.hr.async_resolve_host"),
|
||||
):
|
||||
resolver = AsyncResolver(["test.local"], 6053)
|
||||
# Override event.wait to simulate timeout (return False = timeout occurred)
|
||||
with (
|
||||
patch.object(resolver.event, "wait", return_value=False),
|
||||
pytest.raises(
|
||||
EsphomeError, match=re.escape("Timeout resolving IP address")
|
||||
),
|
||||
):
|
||||
resolver.resolve()
|
||||
"""Test timeout when the runner thread doesn't complete in time."""
|
||||
# Patch AsyncThreadRunner inside esphome.resolver so we never actually
|
||||
# start a thread and can control the wait return value directly.
|
||||
fake_runner = MagicMock()
|
||||
fake_runner.start = MagicMock()
|
||||
fake_runner.event.wait.return_value = False # simulate timeout
|
||||
|
||||
# Verify thread start was called
|
||||
resolver.start.assert_called_once()
|
||||
with (
|
||||
patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner),
|
||||
patch("esphome.resolver.hr.async_resolve_host"),
|
||||
pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")),
|
||||
):
|
||||
AsyncResolver(["test.local"], 6053).resolve()
|
||||
|
||||
fake_runner.start.assert_called_once()
|
||||
|
||||
|
||||
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
|
||||
|
||||
Reference in New Issue
Block a user