mirror of
https://github.com/esphome/esphome.git
synced 2026-05-10 05:37:55 +08:00
370 lines
13 KiB
Python
370 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
import logging
|
|
|
|
from zeroconf import (
|
|
AddressResolver,
|
|
IPVersion,
|
|
ServiceInfo,
|
|
ServiceStateChange,
|
|
Zeroconf,
|
|
)
|
|
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
|
|
|
|
from esphome.async_thread import AsyncThreadRunner
|
|
from esphome.storage_json import StorageJSON, ext_storage_path
|
|
|
|
# Number of characters in the MAC suffix that gets appended to a device name
# when name_add_mac_suffix is enabled.
MAC_SUFFIX_LEN = 6
# Characters accepted in that suffix: lowercase hexadecimal digits only.
_HEX_CHARS = frozenset("0123456789abcdef")

_LOGGER = logging.getLogger(__name__)

# Default resolve timeout; DEFAULT_TIMEOUT_MS is the same value scaled by 1000
# for the calls below that are given a millisecond timeout.
DEFAULT_TIMEOUT = 10.0
DEFAULT_TIMEOUT_MS = DEFAULT_TIMEOUT * 1000

# References to in-flight background tasks. Tasks are added when created and
# removed again by their done callback.
_BACKGROUND_TASKS: set[asyncio.Task] = set()
|
|
|
|
|
|
class DashboardStatus:
    """Translate zeroconf service state changes into online/offline updates.

    Each relevant state change is forwarded to ``on_update`` as a one-entry
    dict mapping the short device name to ``True`` (added/updated) or
    ``False`` (removed).
    """

    def __init__(self, on_update: Callable[[dict[str, bool | None]], None]) -> None:
        """Store the callback that receives status updates."""
        self.on_update = on_update

    def browser_callback(
        self,
        zeroconf: Zeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Report the device named in ``name`` as online or offline."""
        # The short name is everything before the first dot of the service name.
        device, _, _ = name.partition(".")
        if state_change == ServiceStateChange.Removed:
            online = False
        elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added):
            online = True
        else:
            # Any other state is not reported.
            return
        self.on_update({device: online})
|
|
|
|
|
|
# mDNS service type that is browsed to find ESPHome devices.
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."

# TXT record keys read from a discovered service. They are bytes because the
# service info properties are keyed by bytes.
TXT_RECORD_PACKAGE_IMPORT_URL = b"package_import_url"
TXT_RECORD_PROJECT_NAME = b"project_name"
TXT_RECORD_PROJECT_VERSION = b"project_version"
TXT_RECORD_NETWORK = b"network"
TXT_RECORD_FRIENDLY_NAME = b"friendly_name"
TXT_RECORD_VERSION = b"version"
|
|
|
|
|
|
@dataclass
class DiscoveredImport:
    """Dashboard-import metadata decoded from a device's mDNS TXT records."""

    # Decoded ``friendly_name`` TXT value, or None when the device sent none.
    friendly_name: str | None
    # Service name with the ESPHome service-type suffix stripped.
    device_name: str
    # Decoded ``package_import_url`` TXT value.
    package_import_url: str
    # Decoded ``project_name`` TXT value.
    project_name: str
    # Decoded ``project_version`` TXT value.
    project_version: str
    # Decoded ``network`` TXT value.
    network: str
|
|
|
|
|
|
class DashboardBrowser(AsyncServiceBrowser):
    """Service browser used to look for ESPHome nodes.

    Adds nothing to ``AsyncServiceBrowser``; it only gives the browser a
    dedicated type name.
    """
|
|
|
|
|
|
class DashboardImportDiscovery:
    """Track devices that advertise dashboard-import metadata over mDNS.

    ``import_state`` maps the full mDNS service name to the
    ``DiscoveredImport`` built from its TXT records. When ``on_update`` is
    set it is called with ``(name, DiscoveredImport)`` the first time a
    service is recorded and with ``(name, None)`` when a recorded service is
    removed.
    """

    def __init__(
        self, on_update: Callable[[str, DiscoveredImport | None], None] | None = None
    ) -> None:
        """Start with an empty import state and an optional update callback."""
        self.import_state: dict[str, DiscoveredImport] = {}
        self.on_update = on_update

    def browser_callback(
        self,
        zeroconf: Zeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Handle a service state change reported by a zeroconf browser.

        Removals drop the entry from ``import_state``. Additions, and updates
        for services that are already tracked, are resolved from the zeroconf
        cache when possible and otherwise by a background request.
        """
        _LOGGER.debug(
            "service_update: type=%s name=%s state_change=%s",
            service_type,
            name,
            state_change,
        )
        if state_change == ServiceStateChange.Removed:
            removed = self.import_state.pop(name, None)
            if removed and self.on_update:
                self.on_update(name, None)
            return

        if state_change == ServiceStateChange.Updated and name not in self.import_state:
            # Ignore updates for devices that are not in the import state
            return

        info = AsyncServiceInfo(
            service_type,
            name,
        )
        if info.load_from_cache(zeroconf):
            self._process_service_info(name, info)
            return
        task = asyncio.create_task(
            self._async_process_service_info(zeroconf, info, service_type, name)
        )
        # Hold a reference until the task finishes; the done callback drops it.
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)

    async def _async_process_service_info(
        self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str
    ) -> None:
        """Request ``info`` and process it if the request succeeds.

        The request is given ``DEFAULT_TIMEOUT_MS`` as its timeout.
        """
        if await info.async_request(zeroconf, timeout=DEFAULT_TIMEOUT_MS):
            self._process_service_info(name, info)

    def _process_service_info(self, name: str, info: ServiceInfo) -> None:
        """Record ``info`` as a dashboard import, or sync its version to storage.

        A service that carries a value for every required import TXT record
        is stored in ``import_state``; ``on_update`` fires only when the name
        was not tracked before. Any other service is treated as a regular
        device, and its advertised ``version`` (if any) is passed to
        ``update_device_mdns``.
        """
        _LOGGER.debug("-> resolved info: %s", info)
        if info is None:
            return
        # Strip ".<service type>" from the end of the service name.
        node_name = name[: -len(ESPHOME_SERVICE_TYPE) - 1]
        properties = info.properties
        required_keys = (
            TXT_RECORD_PACKAGE_IMPORT_URL,
            TXT_RECORD_PROJECT_NAME,
            TXT_RECORD_PROJECT_VERSION,
        )
        # zeroconf types property values as ``bytes | None``, so a key can be
        # present with no value. Treat that the same as a missing key instead
        # of failing on ``None.decode()`` below.
        if any(properties.get(key) is None for key in required_keys):
            # Not a dashboard import device
            version = properties.get(TXT_RECORD_VERSION)
            if version is not None:
                self.update_device_mdns(node_name, version.decode())
            return

        import_url = properties[TXT_RECORD_PACKAGE_IMPORT_URL].decode()
        project_name = properties[TXT_RECORD_PROJECT_NAME].decode()
        project_version = properties[TXT_RECORD_PROJECT_VERSION].decode()
        # Fall back to "wifi" when the key is absent or carries no value.
        raw_network = properties.get(TXT_RECORD_NETWORK)
        network = "wifi" if raw_network is None else raw_network.decode()
        friendly_name = properties.get(TXT_RECORD_FRIENDLY_NAME)
        if friendly_name is not None:
            friendly_name = friendly_name.decode()

        discovered = DiscoveredImport(
            friendly_name=friendly_name,
            device_name=node_name,
            package_import_url=import_url,
            project_name=project_name,
            project_version=project_version,
            network=network,
        )
        is_new = name not in self.import_state
        self.import_state[name] = discovered
        if is_new and self.on_update:
            self.on_update(name, discovered)

    def update_device_mdns(self, node_name: str, version: str) -> None:
        """Write the mDNS-advertised version into the device's storage file.

        Loads the storage JSON for ``<node_name>.yaml``. Nothing happens when
        no storage is loaded or when the stored version already matches.
        """
        storage_path = ext_storage_path(node_name + ".yaml")
        storage_json = StorageJSON.load(storage_path)

        if storage_json is not None:
            storage_version = storage_json.esphome_version
            if version != storage_version:
                storage_json.esphome_version = version
                storage_json.save(storage_path)
                _LOGGER.info(
                    "Updated %s with mdns version %s (was %s)",
                    node_name,
                    version,
                    storage_version,
                )
|
|
|
|
|
|
class EsphomeZeroconf(Zeroconf):
    """``Zeroconf`` with a blocking helper for resolving one host."""

    def resolve_host(
        self, host: str, timeout: float = DEFAULT_TIMEOUT
    ) -> list[str] | None:
        """Resolve ``host`` to its addresses, or ``None`` if none were found.

        Only the part of ``host`` before the first dot is used, queried as
        ``<short>.local.``. The cache is tried first; a request is sent only
        on a cache miss and only when ``timeout`` is truthy.
        """
        short_name, _, _ = host.partition(".")
        resolver = AddressResolver(f"{short_name}.local.")

        have_records = resolver.load_from_cache(self)
        if not have_records and timeout:
            have_records = resolver.request(self, timeout * 1000)
        if not have_records:
            return None

        addresses = resolver.parsed_scoped_addresses(IPVersion.All)
        if not addresses:
            return None
        return addresses
|
|
|
|
|
|
async def async_resolve_hosts(
    zeroconf: Zeroconf, hosts: list[str], timeout: float = DEFAULT_TIMEOUT
) -> dict[str, list[str]]:
    """Resolve every entry of ``hosts`` through one shared ``Zeroconf``.

    Each host is reduced to the part before its first dot and looked up as
    ``<short>.local.``. The cache is consulted first, so hosts it already
    holds cost no network round-trip. The cache misses are then requested
    together with ``asyncio.gather``, provided ``timeout`` is truthy. A
    request that raises is logged at debug level and does not abort the rest.

    Returns:
        A dict mapping each given host to its list of addresses, which is
        empty when the host could not be resolved.
    """
    by_host: dict[str, AddressResolver] = {}
    cache_misses: list[str] = []
    for hostname in hosts:
        short_name, _, _ = hostname.partition(".")
        lookup = AddressResolver(f"{short_name}.local.")
        by_host[hostname] = lookup
        if lookup.load_from_cache(zeroconf):
            continue
        cache_misses.append(hostname)

    if cache_misses and timeout:
        timeout_ms = timeout * 1000
        requests = [
            by_host[hostname].async_request(zeroconf, timeout_ms)
            for hostname in cache_misses
        ]
        outcomes = await asyncio.gather(*requests, return_exceptions=True)
        for hostname, outcome in zip(cache_misses, outcomes):
            if isinstance(outcome, BaseException):
                _LOGGER.debug("Failed to resolve %s: %s", hostname, outcome)

    resolved: dict[str, list[str]] = {}
    for hostname, lookup in by_host.items():
        resolved[hostname] = lookup.parsed_scoped_addresses(IPVersion.All)
    return resolved
|
|
|
|
|
|
class AsyncEsphomeZeroconf(AsyncZeroconf):
    """``AsyncZeroconf`` with an async helper for resolving one host."""

    async def async_resolve_host(
        self, host: str, timeout: float = DEFAULT_TIMEOUT
    ) -> list[str] | None:
        """Resolve ``host`` to its addresses, or ``None`` if none were found.

        Delegates to ``async_resolve_hosts`` on this instance's ``zeroconf``.
        """
        resolved = await async_resolve_hosts(self.zeroconf, [host], timeout)
        found = resolved[host]
        if found:
            return found
        return None
|
|
|
|
|
|
def _is_mac_suffix_match(device_name: str, prefix: str) -> bool:
    """Tell whether ``device_name`` is ``prefix`` plus a MAC suffix.

    The remainder after ``prefix`` must be exactly ``MAC_SUFFIX_LEN``
    characters long, all of them drawn from ``_HEX_CHARS``.
    """
    split_at = len(prefix)
    if device_name[:split_at] != prefix:
        return False
    tail = device_name[split_at:]
    if len(tail) != MAC_SUFFIX_LEN:
        return False
    return _HEX_CHARS.issuperset(tail)
|
|
|
|
|
|
async def async_discover_mdns_devices(
    base_name: str, timeout: float = 5.0
) -> dict[str, list[str]]:
    """Find ESPHome devices named ``<base_name>-<6-hex-mac>`` via mDNS.

    Devices built with ``name_add_mac_suffix`` advertise themselves as
    ``<base_name>-<6-hex-mac>.local``. One ``AsyncEsphomeZeroconf`` instance
    is used for the whole call: it browses for ``timeout`` seconds, then
    resolves the addresses of every match, so no second Zeroconf client is
    opened.

    Args:
        base_name: The base device name (without MAC suffix).
        timeout: How long to browse for mDNS responses (default 5 seconds).

    Returns:
        Mapping of ``<device>.local`` hostnames to their resolved IP
        addresses, ordered by hostname. The address list is empty for a
        device that was seen but could not be resolved. The mapping is empty
        when Zeroconf could not be initialized.
    """
    prefix = f"{base_name}-"
    # Keyed by hostname, so a device announced more than once is kept once.
    found: dict[str, list[str]] = {}
    relevant_states = (ServiceStateChange.Added, ServiceStateChange.Updated)

    def on_service_state_change(
        zeroconf: Zeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Record a newly seen service whose name matches ``prefix``."""
        if state_change not in relevant_states:
            return
        candidate, _, _ = name.partition(".")
        if _is_mac_suffix_match(candidate, prefix):
            hostname = f"{candidate}.local"
            if hostname not in found:
                found[hostname] = []
                _LOGGER.debug("Discovered %s (%s)", hostname, state_change.name)
            return
        _LOGGER.debug(
            "Ignoring %s (%s): does not match '%s<6-hex>'",
            candidate,
            state_change.name,
            prefix,
        )

    _LOGGER.debug(
        "Starting mDNS discovery for '%s<mac>.local' (timeout=%.1fs)",
        prefix,
        timeout,
    )
    try:
        aiozc = AsyncEsphomeZeroconf()
    except Exception as err:  # pylint: disable=broad-except
        # Zeroconf init can raise OSError, NonUniqueNameException, etc.
        # Whatever the failure, discovery is impossible, so log it and give up.
        _LOGGER.warning("mDNS discovery failed to initialize: %s", err)
        return {}

    try:
        browser = AsyncServiceBrowser(
            aiozc.zeroconf,
            ESPHOME_SERVICE_TYPE,
            handlers=[on_service_state_change],
        )
        try:
            # The browse window: the handler fills ``found`` while we sleep.
            await asyncio.sleep(timeout)
        finally:
            await browser.async_cancel()
        _LOGGER.debug(
            "Browse finished: %d device(s) matched '%s<mac>'",
            len(found),
            prefix,
        )

        # Resolve on the SAME Zeroconf instance rather than opening another
        # client. ``async_resolve_hosts`` checks the cache first (the browse
        # usually primes it) and only sends requests for the misses.
        # NOTE(review): no timeout is passed here, so cache misses use
        # ``async_resolve_hosts``' default (DEFAULT_TIMEOUT), not ``timeout``.
        lookups = await async_resolve_hosts(aiozc.zeroconf, [*found])
        for hostname, addresses in lookups.items():
            if not addresses:
                _LOGGER.debug("No addresses returned for %s", hostname)
                continue
            found[hostname] = addresses
            _LOGGER.debug("Resolved %s -> %s", hostname, addresses)
    finally:
        await aiozc.async_close()

    return {hostname: found[hostname] for hostname in sorted(found)}
|
|
|
|
|
|
def _await_discovery(
    runner: AsyncThreadRunner[dict[str, list[str]]], timeout: float
) -> dict[str, list[str]]:
    """Block until ``runner`` is done and hand back what it discovered.

    Kept separate from :func:`discover_mdns_devices` so the timeout branch
    can be tested without patching ``asyncio`` or ``threading`` internals: a
    test passes a stub whose ``event.wait`` returns ``False``.

    Returns:
        The runner's result, or an empty dict when the wait timed out, the
        runner recorded an exception, or the result is empty.
    """
    # Wait two seconds longer than the browse timeout so the resolution and
    # cleanup pass has time to finish.
    finished = runner.event.wait(timeout=timeout + 2.0)
    if not finished:
        _LOGGER.warning("mDNS discovery timed out after %.1fs", timeout)
        return {}

    error = runner.exception
    if error is not None:
        _LOGGER.warning("mDNS discovery failed: %s", error)
        return {}

    outcome = runner.result
    return outcome or {}
|
|
|
|
|
|
def discover_mdns_devices(base_name: str, timeout: float = 5.0) -> dict[str, list[str]]:
    """Run :func:`async_discover_mdns_devices` from synchronous code.

    The coroutine is handed to an ``AsyncThreadRunner``, which is started
    and then waited on through :func:`_await_discovery`.
    """

    def _start_discovery():
        return async_discover_mdns_devices(base_name, timeout=timeout)

    worker = AsyncThreadRunner(_start_discovery)
    worker.start()
    return _await_discovery(worker, timeout)
|