[core] Allow finding all devices as target that match mac suffix (#13135)

This commit is contained in:
Paulus Schoutsen
2026-04-23 09:43:32 -04:00
committed by GitHub
parent 70ae614abd
commit 9b45b046a8
8 changed files with 912 additions and 70 deletions
+108 -12
View File
@@ -39,6 +39,7 @@ from esphome.const import (
CONF_MDNS,
CONF_MQTT,
CONF_NAME,
CONF_NAME_ADD_MAC_SUFFIX,
CONF_OTA,
CONF_PASSWORD,
CONF_PLATFORM,
@@ -71,6 +72,7 @@ from esphome.util import (
run_external_process,
safe_print,
)
from esphome.zeroconf import discover_mdns_devices
_LOGGER = logging.getLogger(__name__)
@@ -204,6 +206,64 @@ def _resolve_with_cache(address: str, purpose: Purpose) -> list[str]:
return [address]
def _populate_mdns_cache(hosts_to_addresses: dict[str, list[str]]) -> None:
"""Store discovered ``host -> [ips]`` entries in ``CORE.address_cache``.
Ensures ``CORE.address_cache`` exists, then records each mDNS hostname so
the downstream resolution path (``resolve_ip_address``) can skip opening a
second Zeroconf client.
"""
from esphome.address_cache import AddressCache
if CORE.address_cache is None:
CORE.address_cache = AddressCache()
for host, addresses in hosts_to_addresses.items():
if addresses:
_LOGGER.debug("Caching mDNS result %s -> %s", host, addresses)
CORE.address_cache.add_mdns_addresses(host, addresses)
def _discover_mac_suffix_devices() -> list[str] | None:
"""Discover ``<name>-<mac>.local`` devices and cache their IPs.
Returns:
- ``None`` when discovery isn't applicable (``name_add_mac_suffix`` off,
mDNS disabled, or ``CORE.address`` is already an IP). Callers should
then fall back to whatever default OTA address they normally use.
- ``[]`` when discovery ran but found nothing. Callers should NOT fall
back to the base name: with ``name_add_mac_suffix`` enabled, the base
name by definition doesn't exist on the network.
- A non-empty sorted list of ``.local`` hostnames on success.
Populates ``CORE.address_cache`` so downstream resolution (``espota2`` or
``aioesphomeapi`` via :func:`_resolve_network_devices`) reuses the IPs we
already have without opening a second Zeroconf client.
"""
if not (has_name_add_mac_suffix() and has_mdns() and has_non_ip_address()):
return None
_LOGGER.info("Discovering devices...")
if not (discovered := discover_mdns_devices(CORE.name)):
_LOGGER.warning(
"No devices matching '%s-<mac>.local' were discovered.", CORE.name
)
return []
_populate_mdns_cache(discovered)
return list(discovered)
def _ota_hostnames_for_default(purpose: Purpose) -> list[str]:
"""Return OTA hostname(s) for the ``--device OTA`` / default-resolve path.
When ``name_add_mac_suffix`` is enabled, returns discovered
``<name>-<mac>.local`` hostnames (possibly empty — in which case the
caller should not fall back to the base name). Otherwise falls back to
the cache-resolved ``CORE.address``.
"""
if (discovered := _discover_mac_suffix_devices()) is not None:
return discovered
return _resolve_with_cache(CORE.address, purpose)
def choose_upload_log_host(
default: list[str] | str | None,
check_default: str | None,
@@ -242,14 +302,14 @@ def choose_upload_log_host(
resolved.append("MQTT")
if has_api() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose))
resolved.extend(_ota_hostnames_for_default(purpose))
elif purpose == Purpose.UPLOADING:
if has_ota() and has_mqtt_ip_lookup():
resolved.append("MQTTIP")
if has_ota() and has_non_ip_address() and has_resolvable_address():
resolved.extend(_resolve_with_cache(CORE.address, purpose))
resolved.extend(_ota_hostnames_for_default(purpose))
else:
resolved.append(device)
if not resolved:
@@ -281,22 +341,29 @@ def choose_upload_log_host(
elif bootsel.permission_error:
bootsel_permission_error = True
def add_ota_options() -> None:
"""Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled."""
if (discovered := _discover_mac_suffix_devices()) is not None:
# Discovery was applicable. Use whatever we found — on empty,
# intentionally skip the base-name fallback since with
# name_add_mac_suffix on, the base name doesn't exist on the net.
for host in discovered:
options.append((f"Over The Air ({host})", host))
elif has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
if purpose == Purpose.LOGGING:
if has_mqtt_logging():
mqtt_config = CORE.config[CONF_MQTT]
options.append((f"MQTT ({mqtt_config[CONF_BROKER]})", "MQTT"))
if has_api():
if has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
add_ota_options()
elif purpose == Purpose.UPLOADING and has_ota():
if has_resolvable_address():
options.append((f"Over The Air ({CORE.address})", CORE.address))
if has_mqtt_ip_lookup():
options.append(("Over The Air (MQTT IP lookup)", "MQTTIP"))
add_ota_options()
# Show helpful BOOTSEL instructions for RP2040 when no BOOTSEL device is found
if (
@@ -407,7 +474,17 @@ def has_resolvable_address() -> bool:
return not CORE.address.endswith(".local")
def mqtt_get_ip(config: ConfigType, username: str, password: str, client_id: str):
def has_name_add_mac_suffix() -> bool:
"""Check if name_add_mac_suffix is enabled in the config."""
if CORE.config is None:
return False
esphome_config = CORE.config.get(CONF_ESPHOME, {})
return esphome_config.get(CONF_NAME_ADD_MAC_SUFFIX, False)
def mqtt_get_ip(
config: ConfigType, username: str, password: str, client_id: str
) -> list[str]:
from esphome import mqtt
return mqtt.get_esphome_device_ip(config, username, password, client_id)
@@ -420,6 +497,9 @@ def _resolve_network_devices(
This function filters the devices list to:
- Replace MQTT/MQTTIP magic strings with actual IP addresses via MQTT lookup
- Expand hostnames that are already in ``CORE.address_cache`` to their
cached IPs so downstream code (e.g. aioesphomeapi) doesn't open a second
Zeroconf client to resolve them
- Deduplicate addresses while preserving order
- Only resolve MQTT once even if multiple MQTT strings are present
- If MQTT resolution fails, log a warning and continue with other devices
@@ -444,13 +524,29 @@ def _resolve_network_devices(
mqtt_ips = mqtt_get_ip(
config, args.username, args.password, args.client_id
)
network_devices.extend(mqtt_ips)
# pylint can't infer mqtt_get_ip's return through its
# lazy ``from esphome import mqtt`` import, so it flags
# the genexpr below.
network_devices.extend(
addr
for addr in mqtt_ips # pylint: disable=not-an-iterable
if addr not in network_devices
)
except EsphomeError as err:
_LOGGER.warning(
"MQTT IP discovery failed (%s), will try other devices if available",
err,
)
mqtt_resolved = True
continue
# If the hostname is already in the address cache (e.g. populated by
# mDNS discovery), substitute the cached IPs so aioesphomeapi doesn't
# open its own Zeroconf to re-resolve it.
if CORE.address_cache and (cached := CORE.address_cache.get_addresses(device)):
network_devices.extend(
addr for addr in cached if addr not in network_devices
)
elif device not in network_devices:
# Regular network address or IP - add if not already present
network_devices.append(device)
+11
View File
@@ -101,6 +101,17 @@ class AddressCache:
"""Check if any cache entries exist."""
return bool(self.mdns_cache or self.dns_cache)
def add_mdns_addresses(self, hostname: str, addresses: list[str]) -> None:
"""Store resolved mDNS addresses for ``hostname`` in the cache.
Callers that discover ``.local`` hosts (e.g. via mDNS browse) can use
this to avoid a second resolution round-trip during the upload path.
No-op when ``addresses`` is empty.
"""
if not addresses:
return
self.mdns_cache[normalize_hostname(hostname)] = addresses
@classmethod
def from_cli_args(
cls, mdns_args: Iterable[str], dns_args: Iterable[str]
+56
View File
@@ -0,0 +1,56 @@
"""Helpers for running an async coroutine from sync code via a daemon thread.
``asyncio.run(coro())`` in the main thread blocks until the loop's cleanup
cycle finishes, which can add hundreds of milliseconds before the caller
receives the result. Running the loop in a daemon thread lets the caller
observe the result as soon as the coroutine completes while cleanup finishes
in the background.
"""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
import threading
from typing import Generic, TypeVar
_T = TypeVar("_T")
class AsyncThreadRunner(threading.Thread, Generic[_T]):
"""Run an async coroutine in a daemon thread and expose its result.
The runner catches all exceptions from the coroutine and stores them in
``exception`` so ``event`` is always set — this prevents callers waiting
on ``event`` from hanging forever when the coroutine crashes.
Typical usage::
runner = AsyncThreadRunner(lambda: my_coro(arg))
runner.start()
if not runner.event.wait(timeout=5.0):
... # timed out
if runner.exception is not None:
raise runner.exception
result = runner.result
"""
def __init__(self, coro_factory: Callable[[], Awaitable[_T]]) -> None:
super().__init__(daemon=True)
self._coro_factory = coro_factory
self.result: _T | None = None
self.exception: BaseException | None = None
self.event = threading.Event()
async def _runner(self) -> None:
try:
self.result = await self._coro_factory()
except Exception as exc: # pylint: disable=broad-except
# Capture all exceptions so ``event`` is always set — otherwise a
# crash would hang the waiter forever.
self.exception = exc
finally:
self.event.set()
def run(self) -> None:
asyncio.run(self._runner())
+17 -31
View File
@@ -2,66 +2,52 @@
from __future__ import annotations
import asyncio
import threading
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
import aioesphomeapi.host_resolver as hr
from esphome.async_thread import AsyncThreadRunner
from esphome.core import EsphomeError
RESOLVE_TIMEOUT = 10.0 # seconds
class AsyncResolver(threading.Thread):
class AsyncResolver:
"""Resolver using aioesphomeapi that runs in a thread for faster results.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS resolution,
including proper .local domain fallback. Running in a thread allows us to get
the result immediately without waiting for asyncio.run() to complete its
cleanup cycle, which can take significant time.
This resolver uses aioesphomeapi's async_resolve_host to handle DNS
resolution, including proper .local domain fallback. Running in a thread
(via :class:`AsyncThreadRunner`) allows us to get the result immediately
without waiting for ``asyncio.run()`` to complete its cleanup cycle, which
can take significant time.
"""
def __init__(self, hosts: list[str], port: int) -> None:
"""Initialize the resolver."""
super().__init__(daemon=True)
self.hosts = hosts
self.port = port
self.result: list[hr.AddrInfo] | None = None
self.exception: Exception | None = None
self.event = threading.Event()
async def _resolve(self) -> None:
async def _resolve(self) -> list[hr.AddrInfo]:
"""Resolve hostnames to IP addresses."""
try:
self.result = await hr.async_resolve_host(
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
)
except Exception as e: # pylint: disable=broad-except
# We need to catch all exceptions to ensure the event is set
# Otherwise the thread could hang forever
self.exception = e
finally:
self.event.set()
def run(self) -> None:
"""Run the DNS resolution."""
asyncio.run(self._resolve())
return await hr.async_resolve_host(
self.hosts, self.port, timeout=RESOLVE_TIMEOUT
)
def resolve(self) -> list[hr.AddrInfo]:
"""Start the thread and wait for the result."""
self.start()
runner: AsyncThreadRunner[list[hr.AddrInfo]] = AsyncThreadRunner(self._resolve)
runner.start()
if not self.event.wait(
if not runner.event.wait(
timeout=RESOLVE_TIMEOUT + 1.0
): # Give it 1 second more than the resolver timeout
raise EsphomeError("Timeout resolving IP address")
if exc := self.exception:
if exc := runner.exception:
if isinstance(exc, ResolveTimeoutAPIError):
raise EsphomeError(f"Timeout resolving IP address: {exc}") from exc
if isinstance(exc, ResolveAPIError):
raise EsphomeError(f"Error resolving IP address: {exc}") from exc
raise exc
return self.result
assert runner.result is not None # guaranteed when event set and no exception
return runner.result
+174 -7
View File
@@ -14,8 +14,13 @@ from zeroconf import (
)
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.async_thread import AsyncThreadRunner
from esphome.storage_json import StorageJSON, ext_storage_path
# Length of the MAC suffix appended when name_add_mac_suffix is enabled.
MAC_SUFFIX_LEN = 6
_HEX_CHARS = frozenset("0123456789abcdef")
_LOGGER = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 10.0
@@ -188,15 +193,177 @@ class EsphomeZeroconf(Zeroconf):
return None
async def async_resolve_hosts(
zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
) -> dict[str, list[str]]:
"""Resolve ``hosts`` to IPs using a shared ``Zeroconf`` instance.
Tries the cache synchronously first (so hosts already primed by a recent
browse return immediately with no network round-trip), then issues
``async_request`` for the remaining misses in parallel via
``asyncio.gather``. Returns a dict mapping each host to its list of
addresses (empty list when unresolved). Only ``<short>.local`` form is
queried, matching the name scheme the resolvers below expect.
"""
resolvers: dict[str, AddressResolver] = {}
pending: list[str] = []
for host in hosts:
resolver = AddressResolver(f"{host.partition('.')[0]}.local.")
resolvers[host] = resolver
if not resolver.load_from_cache(zeroconf):
pending.append(host)
if pending and timeout:
results = await asyncio.gather(
*(
resolvers[host].async_request(zeroconf, timeout * 1000)
for host in pending
),
return_exceptions=True,
)
for host, result in zip(pending, results):
if isinstance(result, BaseException):
_LOGGER.debug("Failed to resolve %s: %s", host, result)
return {
host: resolver.parsed_scoped_addresses(IPVersion.All)
for host, resolver in resolvers.items()
}
class AsyncEsphomeZeroconf(AsyncZeroconf):
async def async_resolve_host(
self, host: str, timeout: float = DEFAULT_TIMEOUT
) -> list[str] | None:
"""Resolve a host name to an IP address."""
info = AddressResolver(f"{host.partition('.')[0]}.local.")
if (
info.load_from_cache(self.zeroconf)
or (timeout and await info.async_request(self.zeroconf, timeout * 1000))
) and (addresses := info.parsed_scoped_addresses(IPVersion.All)):
return addresses
return None
addresses = (await async_resolve_hosts(self.zeroconf, [host], timeout))[host]
return addresses or None
def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
"""Return True if ``device_name`` is ``prefix`` followed by a 6-char hex MAC."""
if not device_name.startswith(prefix):
return False
suffix = device_name[len(prefix) :]
return len(suffix) == MAC_SUFFIX_LEN and all(c in _HEX_CHARS for c in suffix)
async def async_discover_mdns_devices(
base_name: str, timeout: float = 5.0
) -> dict[str, list[str]]:
"""Discover ESPHome devices via mDNS that match the base name + MAC suffix.
When ``name_add_mac_suffix`` is enabled, devices advertise as
``<base_name>-<6-hex-mac>.local``. This function uses a single
``AsyncEsphomeZeroconf`` lifecycle to both browse for matching services and
resolve their IP addresses, so callers get resolved addresses without
opening a second Zeroconf client.
Args:
base_name: The base device name (without MAC suffix).
timeout: How long to wait for mDNS responses (default 5 seconds).
Returns:
Mapping of ``<device>.local`` hostnames to their resolved IP addresses
(may be empty for a device if resolution failed within the timeout).
"""
prefix = f"{base_name}-"
# Preserves insertion order for stable output and deduplicates
discovered: dict[str, list[str]] = {}
def on_service_state_change(
zeroconf: Zeroconf,
service_type: str,
name: str,
state_change: ServiceStateChange,
) -> None:
if state_change not in (ServiceStateChange.Added, ServiceStateChange.Updated):
return
device_name = name.partition(".")[0]
if not _is_mac_suffix_match(device_name, prefix):
_LOGGER.debug(
"Ignoring %s (%s): does not match '%s<6-hex>'",
device_name,
state_change.name,
prefix,
)
return
host = f"{device_name}.local"
if host in discovered:
return
discovered[host] = []
_LOGGER.debug("Discovered %s (%s)", host, state_change.name)
_LOGGER.debug(
"Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
prefix,
timeout,
)
try:
aiozc = AsyncEsphomeZeroconf()
except Exception as err: # pylint: disable=broad-except
# Zeroconf init can raise OSError, NonUniqueNameException, etc.
# Any failure here just means we can't discover — log and move on.
_LOGGER.warning("mDNS discovery failed to initialize: %s", err)
return {}
try:
browser = AsyncServiceBrowser(
aiozc.zeroconf,
ESPHOME_SERVICE_TYPE,
handlers=[on_service_state_change],
)
try:
await asyncio.sleep(timeout)
finally:
await browser.async_cancel()
_LOGGER.debug(
"Browse finished: %d device(s) matched '%s<mac>'",
len(discovered),
prefix,
)
# Resolve each discovered hostname on the SAME Zeroconf instance so
# we don't spin up a second client. ``async_resolve_hosts`` tries the
# cache synchronously (the browse usually primes it) before issuing
# any ``async_request`` in parallel for misses.
resolved = await async_resolve_hosts(aiozc.zeroconf, list(discovered))
for host, addresses in resolved.items():
if addresses:
discovered[host] = addresses
_LOGGER.debug("Resolved %s -> %s", host, addresses)
else:
_LOGGER.debug("No addresses returned for %s", host)
finally:
await aiozc.async_close()
return dict(sorted(discovered.items()))
def _await_discovery(
runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
) -> dict[str, list[str]]:
"""Wait for ``runner`` to finish and return its discovery result.
Split out of :func:`discover_mdns_devices` so the timeout branch is
testable without patching ``asyncio`` or ``threading`` internals — a test
passes a stub whose ``event.wait`` returns ``False``.
"""
# Give the discovery an extra second over the browse timeout for the
# resolution + cleanup pass.
if not runner.event.wait(timeout=timeout + 2.0):
_LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
return {}
if runner.exception is not None:
_LOGGER.warning("mDNS discovery failed: %s", runner.exception)
return {}
return runner.result or {}
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
"""Synchronous wrapper around :func:`async_discover_mdns_devices`."""
runner = AsyncThreadRunner(
lambda: async_discover_mdns_devices(base_name, timeout=timeout)
)
runner.start()
return _await_discovery(runner, timeout)
+20
View File
@@ -121,6 +121,26 @@ def test_get_addresses_auto_detection() -> None:
assert cache.get_addresses("unknown.com") is None
def test_add_mdns_addresses_stores_and_normalizes() -> None:
"""add_mdns_addresses inserts entries under the normalized hostname."""
cache = AddressCache()
cache.add_mdns_addresses("Device.Local.", ["192.168.1.10", "192.168.1.11"])
assert cache.mdns_cache == {
normalize_hostname("Device.Local."): ["192.168.1.10", "192.168.1.11"]
}
# Overwrites on subsequent calls for the same host
cache.add_mdns_addresses("device.local", ["10.0.0.1"])
assert cache.mdns_cache[normalize_hostname("device.local")] == ["10.0.0.1"]
def test_add_mdns_addresses_empty_is_noop() -> None:
"""Passing an empty address list must not create an entry."""
cache = AddressCache()
cache.add_mdns_addresses("device.local", [])
assert cache.mdns_cache == {}
def test_has_cache() -> None:
"""Test checking if cache has entries."""
# Empty cache
File diff suppressed because it is too large Load Diff
+15 -18
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import re
import socket
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from aioesphomeapi.core import ResolveAPIError, ResolveTimeoutAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr, IPv6Sockaddr
@@ -115,24 +115,21 @@ def test_async_resolver_generic_exception() -> None:
def test_async_resolver_thread_timeout() -> None:
"""Test timeout when thread doesn't complete in time."""
# Mock the start method to prevent actual thread execution
with (
patch.object(AsyncResolver, "start"),
patch("esphome.resolver.hr.async_resolve_host"),
):
resolver = AsyncResolver(["test.local"], 6053)
# Override event.wait to simulate timeout (return False = timeout occurred)
with (
patch.object(resolver.event, "wait", return_value=False),
pytest.raises(
EsphomeError, match=re.escape("Timeout resolving IP address")
),
):
resolver.resolve()
"""Test timeout when the runner thread doesn't complete in time."""
# Patch AsyncThreadRunner inside esphome.resolver so we never actually
# start a thread and can control the wait return value directly.
fake_runner = MagicMock()
fake_runner.start = MagicMock()
fake_runner.event.wait.return_value = False # simulate timeout
# Verify thread start was called
resolver.start.assert_called_once()
with (
patch("esphome.resolver.AsyncThreadRunner", return_value=fake_runner),
patch("esphome.resolver.hr.async_resolve_host"),
pytest.raises(EsphomeError, match=re.escape("Timeout resolving IP address")),
):
AsyncResolver(["test.local"], 6053).resolve()
fake_runner.start.assert_called_once()
def test_async_resolver_ip_addresses(mock_addr_info_ipv4: AddrInfo) -> None: