mirror of
https://github.com/esphome/esphome.git
synced 2026-05-28 13:37:24 +08:00
[core] Allow finding all devices as target that match mac suffix (#13135)
This commit is contained in:
+108
-12
@@ -39,6 +39,7 @@ from esphome.const import (
|
|||||||
CONF_MDNS,
|
CONF_MDNS,
|
||||||
CONF_MQTT,
|
CONF_MQTT,
|
||||||
CONF_NAME,
|
CONF_NAME,
|
||||||
|
CONF_NAME_ADD_MAC_SUFFIX,
|
||||||
CONF_OTA,
|
CONF_OTA,
|
||||||
CONF_PASSWORD,
|
CONF_PASSWORD,
|
||||||
CONF_PLATFORM,
|
CONF_PLATFORM,
|
||||||
@@ -71,6 +72,7 @@ from esphome.util import (
|
|||||||
run_external_process,
|
run_external_process,
|
||||||
safe_print,
|
safe_print,
|
||||||
)
|
)
|
||||||
|
from esphome.zeroconf import discover_mdns_devices
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
|
|||||||
return [address]
|
return [address]
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None:
|
||||||
|
"""Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``.
|
||||||
|
|
||||||
|
Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so
|
||||||
|
the downstream resolution path (``resolve_ip_address``) can skip opening a
|
||||||
|
second Zeroconf client.
|
||||||
|
"""
|
||||||
|
from esphome.address_cache import AddressCache
|
||||||
|
|
||||||
|
if CORE.address_cache is None:
|
||||||
|
CORE.address_cache = AddressCache()
|
||||||
|
for host, addresses in hosts_to_addresses.items():
|
||||||
|
if addresses:
|
||||||
|
_LOGGER.debug("Caching mDNS result %s -> %s", host, addresses)
|
||||||
|
CORE.address_cache.add_mdns_addresses(host, addresses)
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_mac_suffix_devices() -> list[str] | None:
|
||||||
|
"""Discover ``<name>-<mac>.local`` devices and cache their IPs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off,
|
||||||
|
mDNS disabled, or ``CORE.address`` is already an IP). Callers should
|
||||||
|
then fall back to whatever default OTA address they normally use.
|
||||||
|
- ``[]`` when discovery ran but found nothing. Callers should NOT fall
|
||||||
|
back to the base name: with ``name_add_mac_suffix`` enabled, the base
|
||||||
|
name by definition doesn't exist on the network.
|
||||||
|
- A non-empty sorted list of ``.local`` hostnames on success.
|
||||||
|
|
||||||
|
Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or
|
||||||
|
``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we
|
||||||
|
already have without opening a second Zeroconf client.
|
||||||
|
"""
|
||||||
|
if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()):
|
||||||
|
return None
|
||||||
|
_LOGGER.info("Discovering devices...")
|
||||||
|
if not (discovered := discover_mdns_devices(CORE.name)):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"No devices matching '%s-<mac>.local' were discovered.", CORE.name
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
_populate_mdns_cache(discovered)
|
||||||
|
return list(discovered)
|
||||||
|
|
||||||
|
|
||||||
|
def _ota_hostnames_for_default(purpose: Purpose) -> list[str]:
|
||||||
|
"""Return OTA hostname(s) for the ``--device OTA`` / default-resolve path.
|
||||||
|
|
||||||
|
When ``name_add_mac_suffix`` is enabled, returns discovered
|
||||||
|
``<name>-<mac>.local`` hostnames (possibly empty — in which case the
|
||||||
|
caller should not fall back to the base name). Otherwise falls back to
|
||||||
|
the cache-resolved ``CORE.address``.
|
||||||
|
"""
|
||||||
|
if (discovered := _discover_mac_suffix_devices()) is not None:
|
||||||
|
return discovered
|
||||||
|
return _resolve_with_cache(CORE.address, purpose)
|
||||||
|
|
||||||
|
|
||||||
def choose_upload_log_host(
|
def choose_upload_log_host(
|
||||||
default: list[str] | str | None,
|
default: list[str] | str | None,
|
||||||
check_default: str | None,
|
check_default: str | None,
|
||||||
@@ -242,14 +302,14 @@ def choose_upload_log_host(
|
|||||||
resolved.append("MQTT")
|
resolved.append("MQTT")
|
||||||
|
|
||||||
if has_api() and has_non_ip_address() and has_resolvable_address():
|
if has_api() and has_non_ip_address() and has_resolvable_address():
|
||||||
resolved.extend(_resolve_with_cache(CORE.address, purpose))
|
resolved.extend(_ota_hostnames_for_default(purpose))
|
||||||
|
|
||||||
elif purpose == Purpose.UPLOADING:
|
elif purpose == Purpose.UPLOADING:
|
||||||
if has_ota() and has_mqtt_ip_lookup():
|
if has_ota() and has_mqtt_ip_lookup():
|
||||||
resolved.append("MQTTIP")
|
resolved.append("MQTTIP")
|
||||||
|
|
||||||
if has_ota() and has_non_ip_address() and has_resolvable_address():
|
if has_ota() and has_non_ip_address() and has_resolvable_address():
|
||||||
resolved.extend(_resolve_with_cache(CORE.address, purpose))
|
resolved.extend(_ota_hostnames_for_default(purpose))
|
||||||
else:
|
else:
|
||||||
resolved.append(device)
|
resolved.append(device)
|
||||||
if not resolved:
|
if not resolved:
|
||||||
@@ -281,22 +341,29 @@ def choose_upload_log_host(
|
|||||||
elif bootsel.permission_error:
|
elif bootsel.permission_error:
|
||||||
bootsel_permission_error = True
|
bootsel_permission_error = True
|
||||||
|
|
||||||
|
def add_ota_options() -> None:
|
||||||
|
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
|
||||||
|
if (discovered := _discover_mac_suffix_devices()) is not None:
|
||||||
|
# Discovery was applicable. Use whatever we found — on empty,
|
||||||
|
# intentionally skip the base-name fallback since with
|
||||||
|
# name_add_mac_suffix on, the base name doesn't exist on the net.
|
||||||
|
for host in discovered:
|
||||||
|
options.append((f"Over The Air ({host})", host))
|
||||||
|
elif has_resolvable_address():
|
||||||
|
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
||||||
|
if has_mqtt_ip_lookup():
|
||||||
|
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
||||||
|
|
||||||
if purpose == Purpose.LOGGING:
|
if purpose == Purpose.LOGGING:
|
||||||
if has_mqtt_logging():
|
if has_mqtt_logging():
|
||||||
mqtt_config = CORE.config[CONF_MQTT]
|
mqtt_config = CORE.config[CONF_MQTT]
|
||||||
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
|
||||||
|
|
||||||
if has_api():
|
if has_api():
|
||||||
if has_resolvable_address():
|
add_ota_options()
|
||||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
|
||||||
if has_mqtt_ip_lookup():
|
|
||||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
|
||||||
|
|
||||||
elif purpose == Purpose.UPLOADING and has_ota():
|
elif purpose == Purpose.UPLOADING and has_ota():
|
||||||
if has_resolvable_address():
|
add_ota_options()
|
||||||
options.append((f"Over The Air ({CORE.address})", CORE.address))
|
|
||||||
if has_mqtt_ip_lookup():
|
|
||||||
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
|
|
||||||
|
|
||||||
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
|
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
|
||||||
if (
|
if (
|
||||||
@@ -407,7 +474,17 @@ def has_resolvable_address() -> bool:
|
|||||||
return not CORE.address.endswith(".local")
|
return not CORE.address.endswith(".local")
|
||||||
|
|
||||||
|
|
||||||
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
|
def has_name_add_mac_suffix() -> bool:
|
||||||
|
"""Check if name_add_mac_suffix is enabled in the config."""
|
||||||
|
if CORE.config is None:
|
||||||
|
return False
|
||||||
|
esphome_config = CORE.config.get(CONF_ESPHOME, {})
|
||||||
|
return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False)
|
||||||
|
|
||||||
|
|
||||||
|
def mqtt_get_ip(
|
||||||
|
config: ConfigType, username: str, password: str, client_id: str
|
||||||
|
) -> list[str]:
|
||||||
from esphome import mqtt
|
from esphome import mqtt
|
||||||
|
|
||||||
return mqtt.get_esphome_device_ip(config, username, password, client_id)
|
return mqtt.get_esphome_device_ip(config, username, password, client_id)
|
||||||
@@ -420,6 +497,9 @@ def _resolve_network_devices(
|
|||||||
|
|
||||||
This function filters the devices list to:
|
This function filters the devices list to:
|
||||||
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
|
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
|
||||||
|
- Expand hostnames that are already in ``CORE.address_cache`` to their
|
||||||
|
cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second
|
||||||
|
Zeroconf client to resolve them
|
||||||
- Deduplicate addresses while preserving order
|
- Deduplicate addresses while preserving order
|
||||||
- Only resolve MQTT once even if multiple MQTT strings are present
|
- Only resolve MQTT once even if multiple MQTT strings are present
|
||||||
- If MQTT resolution fails, log a warning and continue with other devices
|
- If MQTT resolution fails, log a warning and continue with other devices
|
||||||
@@ -444,13 +524,29 @@ def _resolve_network_devices(
|
|||||||
mqtt_ips = mqtt_get_ip(
|
mqtt_ips = mqtt_get_ip(
|
||||||
config, args.username, args.password, args.client_id
|
config, args.username, args.password, args.client_id
|
||||||
)
|
)
|
||||||
network_devices.extend(mqtt_ips)
|
# pylint can't infer mqtt_get_ip's return through its
|
||||||
|
# lazy ``from esphome import mqtt`` import, so it flags
|
||||||
|
# the genexpr below.
|
||||||
|
network_devices.extend(
|
||||||
|
addr
|
||||||
|
for addr in mqtt_ips # pylint: disable=not-an-iterable
|
||||||
|
if addr not in network_devices
|
||||||
|
)
|
||||||
except EsphomeError as err:
|
except EsphomeError as err:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"MQTT IP discovery failed (%s), will try other devices if available",
|
"MQTT IP discovery failed (%s), will try other devices if available",
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
mqtt_resolved = True
|
mqtt_resolved = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If the hostname is already in the address cache (e.g. populated by
|
||||||
|
# mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't
|
||||||
|
# open its own Zeroconf to re-resolve it.
|
||||||
|
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)):
|
||||||
|
network_devices.extend(
|
||||||
|
addr for addr in cached if addr not in network_devices
|
||||||
|
)
|
||||||
elif device not in network_devices:
|
elif device not in network_devices:
|
||||||
# Regular network address or IP - add if not already present
|
# Regular network address or IP - add if not already present
|
||||||
network_devices.append(device)
|
network_devices.append(device)
|
||||||
|
|||||||
@@ -101,6 +101,17 @@ class AddressCache:
|
|||||||
"""Check if any cache entries exist."""
|
"""Check if any cache entries exist."""
|
||||||
return bool(self.mdns_cache or self.dns_cache)
|
return bool(self.mdns_cache or self.dns_cache)
|
||||||
|
|
||||||
|
def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None:
|
||||||
|
"""Store resolved mDNS addresses for ``hostname`` in the cache.
|
||||||
|
|
||||||
|
Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use
|
||||||
|
this to avoid a second resolution round-trip during the upload path.
|
||||||
|
No-op when ``addresses`` is empty.
|
||||||
|
"""
|
||||||
|
if not addresses:
|
||||||
|
return
|
||||||
|
self.mdns_cache[normalize_hostname(hostname)] = addresses
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(
|
def from_cli_args(
|
||||||
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
|
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Helpers for running an async coroutine from sync code via a daemon thread.
|
||||||
|
|
||||||
|
``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup
|
||||||
|
cycle finishes, which can add hundreds of milliseconds before the caller
|
||||||
|
receives the result. Running the loop in a daemon thread lets the caller
|
||||||
|
observe the result as soon as the coroutine completes while cleanup finishes
|
||||||
|
in the background.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
import threading
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncThreadRunner(threading.Thread, Generic[_T]):
|
||||||
|
"""Run an async coroutine in a daemon thread and expose its result.
|
||||||
|
|
||||||
|
The runner catches all exceptions from the coroutine and stores them in
|
||||||
|
``exception`` so ``event`` is always set — this prevents callers waiting
|
||||||
|
on ``event`` from hanging forever when the coroutine crashes.
|
||||||
|
|
||||||
|
Typical usage::
|
||||||
|
|
||||||
|
runner = AsyncThreadRunner(lambda: my_coro(arg))
|
||||||
|
runner.start()
|
||||||
|
if not runner.event.wait(timeout=5.0):
|
||||||
|
... # timed out
|
||||||
|
if runner.exception is not None:
|
||||||
|
raise runner.exception
|
||||||
|
result = runner.result
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None:
|
||||||
|
super().__init__(daemon=True)
|
||||||
|
self._coro_factory = coro_factory
|
||||||
|
self.result: _T | None = None
|
||||||
|
self.exception: BaseException | None = None
|
||||||
|
self.event = threading.Event()
|
||||||
|
|
||||||
|
async def _runner(self) -> None:
|
||||||
|
try:
|
||||||
|
self.result = await self._coro_factory()
|
||||||
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
# Capture all exceptions so ``event`` is always set — otherwise a
|
||||||
|
# crash would hang the waiter forever.
|
||||||
|
self.exception = exc
|
||||||
|
finally:
|
||||||
|
self.event.set()
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
asyncio.run(self._runner())
|
||||||
+17
-31
@@ -2,66 +2,52 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
||||||
import aioesphomeapi.host_resolver as hr
|
import aioesphomeapi.host_resolver as hr
|
||||||
|
|
||||||
|
from esphome.async_thread import AsyncThreadRunner
|
||||||
from esphome.core import EsphomeError
|
from esphome.core import EsphomeError
|
||||||
|
|
||||||
RESOLVE_TIMEOUT = 10.0 # seconds
|
RESOLVE_TIMEOUT = 10.0 # seconds
|
||||||
|
|
||||||
|
|
||||||
class AsyncResolver(threading.Thread):
|
class AsyncResolver:
|
||||||
"""Resolver using aioesphomeapi that runs in a thread for faster results.
|
"""Resolver using aioesphomeapi that runs in a thread for faster results.
|
||||||
|
|
||||||
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
|
This resolver uses aioesphomeapi's async_resolve_host to handle DNS
|
||||||
including proper .local domain fallback. Running in a thread allows us to get
|
resolution, including proper .local domain fallback. Running in a thread
|
||||||
the result immediately without waiting for asyncio.run() to complete its
|
(via :class:`AsyncThreadRunner`) allows us to get the result immediately
|
||||||
cleanup cycle, which can take significant time.
|
without waiting for ``asyncio.run()`` to complete its cleanup cycle, which
|
||||||
|
can take significant time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hosts: list[str], port: int) -> None:
|
def __init__(self, hosts: list[str], port: int) -> None:
|
||||||
"""Initialize the resolver."""
|
"""Initialize the resolver."""
|
||||||
super().__init__(daemon=True)
|
|
||||||
self.hosts = hosts
|
self.hosts = hosts
|
||||||
self.port = port
|
self.port = port
|
||||||
self.result: list[hr.AddrInfo] | None = None
|
|
||||||
self.exception: Exception | None = None
|
|
||||||
self.event = threading.Event()
|
|
||||||
|
|
||||||
async def _resolve(self) -> None:
|
async def _resolve(self) -> list[hr.AddrInfo]:
|
||||||
"""Resolve hostnames to IP addresses."""
|
"""Resolve hostnames to IP addresses."""
|
||||||
try:
|
return await hr.async_resolve_host(
|
||||||
self.result = await hr.async_resolve_host(
|
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
|
||||||
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
|
)
|
||||||
)
|
|
||||||
except Exception as e: # pylint: disable=broad-except
|
|
||||||
# We need to catch all exceptions to ensure the event is set
|
|
||||||
# Otherwise the thread could hang forever
|
|
||||||
self.exception = e
|
|
||||||
finally:
|
|
||||||
self.event.set()
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
"""Run the DNS resolution."""
|
|
||||||
asyncio.run(self._resolve())
|
|
||||||
|
|
||||||
def resolve(self) -> list[hr.AddrInfo]:
|
def resolve(self) -> list[hr.AddrInfo]:
|
||||||
"""Start the thread and wait for the result."""
|
"""Start the thread and wait for the result."""
|
||||||
self.start()
|
runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve)
|
||||||
|
runner.start()
|
||||||
|
|
||||||
if not self.event.wait(
|
if not runner.event.wait(
|
||||||
timeout=RESOLVE_TIMEOUT + 1.0
|
timeout=RESOLVE_TIMEOUT + 1.0
|
||||||
): # Give it 1 second more than the resolver timeout
|
): # Give it 1 second more than the resolver timeout
|
||||||
raise EsphomeError("Timeout resolving IP address")
|
raise EsphomeError("Timeout resolving IP address")
|
||||||
|
|
||||||
if exc := self.exception:
|
if exc := runner.exception:
|
||||||
if isinstance(exc, ResolveTimeoutAPIError):
|
if isinstance(exc, ResolveTimeoutAPIError):
|
||||||
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
|
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
|
||||||
if isinstance(exc, ResolveAPIError):
|
if isinstance(exc, ResolveAPIError):
|
||||||
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
|
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
return self.result
|
assert runner.result is not None # guaranteed when event set and no exception
|
||||||
|
return runner.result
|
||||||
|
|||||||
+174
-7
@@ -14,8 +14,13 @@ from zeroconf import (
|
|||||||
)
|
)
|
||||||
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
||||||
|
|
||||||
|
from esphome.async_thread import AsyncThreadRunner
|
||||||
from esphome.storage_json import StorageJSON, ext_storage_path
|
from esphome.storage_json import StorageJSON, ext_storage_path
|
||||||
|
|
||||||
|
# Length of the MAC suffix appended when name_add_mac_suffix is enabled.
|
||||||
|
MAC_SUFFIX_LEN = 6
|
||||||
|
_HEX_CHARS = frozenset("0123456789abcdef")
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 10.0
|
DEFAULT_TIMEOUT = 10.0
|
||||||
@@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def async_resolve_hosts(
|
||||||
|
zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
|
||||||
|
) -> dict[str, list[str]]:
|
||||||
|
"""Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance.
|
||||||
|
|
||||||
|
Tries the cache synchronously first (so hosts already primed by a recent
|
||||||
|
browse return immediately with no network round-trip), then issues
|
||||||
|
``async_request`` for the remaining misses in parallel via
|
||||||
|
``asyncio.gather``. Returns a dict mapping each host to its list of
|
||||||
|
addresses (empty list when unresolved). Only ``<short>.local`` form is
|
||||||
|
queried, matching the name scheme the resolvers below expect.
|
||||||
|
"""
|
||||||
|
resolvers: dict[str, AddressResolver] = {}
|
||||||
|
pending: list[str] = []
|
||||||
|
for host in hosts:
|
||||||
|
resolver = AddressResolver(f"{host.partition('.')[0]}.local.")
|
||||||
|
resolvers[host] = resolver
|
||||||
|
if not resolver.load_from_cache(zeroconf):
|
||||||
|
pending.append(host)
|
||||||
|
|
||||||
|
if pending and timeout:
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
resolvers[host].async_request(zeroconf, timeout * 1000)
|
||||||
|
for host in pending
|
||||||
|
),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
for host, result in zip(pending, results):
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
_LOGGER.debug("Failed to resolve %s: %s", host, result)
|
||||||
|
|
||||||
|
return {
|
||||||
|
host: resolver.parsed_scoped_addresses(IPVersion.All)
|
||||||
|
for host, resolver in resolvers.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AsyncEsphomeZeroconf(AsyncZeroconf):
|
class AsyncEsphomeZeroconf(AsyncZeroconf):
|
||||||
async def async_resolve_host(
|
async def async_resolve_host(
|
||||||
self, host: str, timeout: float = DEFAULT_TIMEOUT
|
self, host: str, timeout: float = DEFAULT_TIMEOUT
|
||||||
) -> list[str] | None:
|
) -> list[str] | None:
|
||||||
"""Resolve a host name to an IP address."""
|
"""Resolve a host name to an IP address."""
|
||||||
info = AddressResolver(f"{host.partition('.')[0]}.local.")
|
addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host]
|
||||||
if (
|
return addresses or None
|
||||||
info.load_from_cache(self.zeroconf)
|
|
||||||
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
|
|
||||||
) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
|
def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
|
||||||
return addresses
|
"""Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC."""
|
||||||
return None
|
if not device_name.startswith(prefix):
|
||||||
|
return False
|
||||||
|
suffix = device_name[len(prefix) :]
|
||||||
|
return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_discover_mdns_devices(
|
||||||
|
base_name: str, timeout: float = 5.0
|
||||||
|
) -> dict[str, list[str]]:
|
||||||
|
"""Discover ESPHome devices via mDNS that match the base name + MAC suffix.
|
||||||
|
|
||||||
|
When ``name_add_mac_suffix`` is enabled, devices advertise as
|
||||||
|
``<base_name>-<6-hex-mac>.local``. This function uses a single
|
||||||
|
``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and
|
||||||
|
resolve their IP addresses, so callers get resolved addresses without
|
||||||
|
opening a second Zeroconf client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_name: The base device name (without MAC suffix).
|
||||||
|
timeout: How long to wait for mDNS responses (default 5 seconds).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Mapping of ``<device>.local`` hostnames to their resolved IP addresses
|
||||||
|
(may be empty for a device if resolution failed within the timeout).
|
||||||
|
"""
|
||||||
|
prefix = f"{base_name}-"
|
||||||
|
# Preserves insertion order for stable output and deduplicates
|
||||||
|
discovered: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
def on_service_state_change(
|
||||||
|
zeroconf: Zeroconf,
|
||||||
|
service_type: str,
|
||||||
|
name: str,
|
||||||
|
state_change: ServiceStateChange,
|
||||||
|
) -> None:
|
||||||
|
if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated):
|
||||||
|
return
|
||||||
|
device_name = name.partition(".")[0]
|
||||||
|
if not _is_mac_suffix_match(device_name, prefix):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Ignoring %s (%s): does not match '%s<6-hex>'",
|
||||||
|
device_name,
|
||||||
|
state_change.name,
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
host = f"{device_name}.local"
|
||||||
|
if host in discovered:
|
||||||
|
return
|
||||||
|
discovered[host] = []
|
||||||
|
_LOGGER.debug("Discovered %s (%s)", host, state_change.name)
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
|
||||||
|
prefix,
|
||||||
|
timeout,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
aiozc = AsyncEsphomeZeroconf()
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
# Zeroconf init can raise OSError, NonUniqueNameException, etc.
|
||||||
|
# Any failure here just means we can't discover — log and move on.
|
||||||
|
_LOGGER.warning("mDNS discovery failed to initialize: %s", err)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
browser = AsyncServiceBrowser(
|
||||||
|
aiozc.zeroconf,
|
||||||
|
ESPHOME_SERVICE_TYPE,
|
||||||
|
handlers=[on_service_state_change],
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(timeout)
|
||||||
|
finally:
|
||||||
|
await browser.async_cancel()
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Browse finished: %d device(s) matched '%s<mac>'",
|
||||||
|
len(discovered),
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve each discovered hostname on the SAME Zeroconf instance so
|
||||||
|
# we don't spin up a second client. ``async_resolve_hosts`` tries the
|
||||||
|
# cache synchronously (the browse usually primes it) before issuing
|
||||||
|
# any ``async_request`` in parallel for misses.
|
||||||
|
resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered))
|
||||||
|
for host, addresses in resolved.items():
|
||||||
|
if addresses:
|
||||||
|
discovered[host] = addresses
|
||||||
|
_LOGGER.debug("Resolved %s -> %s", host, addresses)
|
||||||
|
else:
|
||||||
|
_LOGGER.debug("No addresses returned for %s", host)
|
||||||
|
finally:
|
||||||
|
await aiozc.async_close()
|
||||||
|
|
||||||
|
return dict(sorted(discovered.items()))
|
||||||
|
|
||||||
|
|
||||||
|
def _await_discovery(
|
||||||
|
runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
|
||||||
|
) -> dict[str, list[str]]:
|
||||||
|
"""Wait for ``runner`` to finish and return its discovery result.
|
||||||
|
|
||||||
|
Split out of :func:`discover_mdns_devices` so the timeout branch is
|
||||||
|
testable without patching ``asyncio`` or ``threading`` internals — a test
|
||||||
|
passes a stub whose ``event.wait`` returns ``False``.
|
||||||
|
"""
|
||||||
|
# Give the discovery an extra second over the browse timeout for the
|
||||||
|
# resolution + cleanup pass.
|
||||||
|
if not runner.event.wait(timeout=timeout + 2.0):
|
||||||
|
_LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
|
||||||
|
return {}
|
||||||
|
if runner.exception is not None:
|
||||||
|
_LOGGER.warning("mDNS discovery failed: %s", runner.exception)
|
||||||
|
return {}
|
||||||
|
return runner.result or {}
|
||||||
|
|
||||||
|
|
||||||
|
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
|
||||||
|
"""Synchronous wrapper around :func:`async_discover_mdns_devices`."""
|
||||||
|
runner = AsyncThreadRunner(
|
||||||
|
lambda: async_discover_mdns_devices(base_name, timeout=timeout)
|
||||||
|
)
|
||||||
|
runner.start()
|
||||||
|
return _await_discovery(runner, timeout)
|
||||||
|
|||||||
@@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None:
|
|||||||
assert cache.get_addresses("unknown.com") is None
|
assert cache.get_addresses("unknown.com") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_mdns_addresses_stores_and_normalizes() -> None:
|
||||||
|
"""add_mdns_addresses inserts entries under the normalized hostname."""
|
||||||
|
cache = AddressCache()
|
||||||
|
cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"])
|
||||||
|
|
||||||
|
assert cache.mdns_cache == {
|
||||||
|
normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"]
|
||||||
|
}
|
||||||
|
# Overwrites on subsequent calls for the same host
|
||||||
|
cache.add_mdns_addresses("device.local", ["10.0.0.1"])
|
||||||
|
assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_mdns_addresses_empty_is_noop() -> None:
|
||||||
|
"""Passing an empty address list must not create an entry."""
|
||||||
|
cache = AddressCache()
|
||||||
|
cache.add_mdns_addresses("device.local", [])
|
||||||
|
assert cache.mdns_cache == {}
|
||||||
|
|
||||||
|
|
||||||
def test_has_cache() -> None:
|
def test_has_cache() -> None:
|
||||||
"""Test checking if cache has entries."""
|
"""Test checking if cache has entries."""
|
||||||
# Empty cache
|
# Empty cache
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
|
||||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
|
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
|
||||||
@@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_async_resolver_thread_timeout() -> None:
|
def test_async_resolver_thread_timeout() -> None:
|
||||||
"""Test timeout when thread doesn't complete in time."""
|
"""Test timeout when the runner thread doesn't complete in time."""
|
||||||
# Mock the start method to prevent actual thread execution
|
# Patch AsyncThreadRunner inside esphome.resolver so we never actually
|
||||||
with (
|
# start a thread and can control the wait return value directly.
|
||||||
patch.object(AsyncResolver, "start"),
|
fake_runner = MagicMock()
|
||||||
patch("esphome.resolver.hr.async_resolve_host"),
|
fake_runner.start = MagicMock()
|
||||||
):
|
fake_runner.event.wait.return_value = False # simulate timeout
|
||||||
resolver = AsyncResolver(["test.local"], 6053)
|
|
||||||
# Override event.wait to simulate timeout (return False = timeout occurred)
|
|
||||||
with (
|
|
||||||
patch.object(resolver.event, "wait", return_value=False),
|
|
||||||
pytest.raises(
|
|
||||||
EsphomeError, match=re.escape("Timeout resolving IP address")
|
|
||||||
),
|
|
||||||
):
|
|
||||||
resolver.resolve()
|
|
||||||
|
|
||||||
# Verify thread start was called
|
with (
|
||||||
resolver.start.assert_called_once()
|
patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner),
|
||||||
|
patch("esphome.resolver.hr.async_resolve_host"),
|
||||||
|
pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")),
|
||||||
|
):
|
||||||
|
AsyncResolver(["test.local"], 6053).resolve()
|
||||||
|
|
||||||
|
fake_runner.start.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
|
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user