Files
esphome/tests/integration/state_utils.py
2026-03-15 13:02:23 -10:00

349 lines
13 KiB
Python

"""Shared utilities for ESPHome integration tests - state handling."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
import logging
from typing import TypeVar
from aioesphomeapi import (
BinarySensorState,
ButtonInfo,
EntityInfo,
EntityState,
SensorState,
TextSensorState,
)
# Module-level logger, named after this module via __name__.
_LOGGER = logging.getLogger(__name__)
# Type variable bounded by EntityInfo; used in the annotations of
# find_entity() and require_entity() for their entity_type filter.
T = TypeVar("T", bound=EntityInfo)
def find_entity(
    entities: list[EntityInfo],
    object_id_substring: str,
    entity_type: type[T] | None = None,
) -> T | EntityInfo | None:
    """Look up the first entity whose object_id contains a substring.

    The substring comparison ignores case. When ``entity_type`` is given,
    only instances of that type are considered.

    Args:
        entities: Entity info objects to search, in priority order.
        object_id_substring: Text that must occur in the entity's object_id
            (compared case-insensitively).
        entity_type: Optional class the entity must be an instance of
            (e.g., BinarySensorInfo).

    Returns:
        The first entity satisfying both conditions, or None when there is
        no such entity.

    Example:
        binary_sensor = find_entity(entities, "test_binary_sensor", BinarySensorInfo)
        button = find_entity(entities, "set_true")  # Any entity type
    """
    needle = object_id_substring.lower()

    def _matches(candidate: EntityInfo) -> bool:
        # Substring test first, type test second.
        if needle not in candidate.object_id.lower():
            return False
        return entity_type is None or isinstance(candidate, entity_type)

    return next(filter(_matches, entities), None)
def require_entity(
    entities: list[EntityInfo],
    object_id_substring: str,
    entity_type: type[T] | None = None,
    description: str | None = None,
) -> T | EntityInfo:
    """Like find_entity(), but fail loudly when nothing matches.

    Args:
        entities: Entity info objects to search.
        object_id_substring: Text that must occur in the entity's object_id
            (compared case-insensitively).
        entity_type: Optional class the entity must be an instance of
            (e.g., BinarySensorInfo).
        description: Human-readable label used at the start of the error
            message instead of the generated one.

    Returns:
        The first matching entity.

    Raises:
        AssertionError: If no entity matches.

    Example:
        binary_sensor = require_entity(entities, "test_sensor", BinarySensorInfo)
        button = require_entity(entities, "set_true", description="Set True button")
    """
    match = find_entity(entities, object_id_substring, entity_type)
    if match is not None:
        return match

    # Nothing found: assemble the failure message from its parts.
    label = description or f"entity with '{object_id_substring}' in object_id"
    if entity_type:
        suffix = f" of type {entity_type.__name__}"
    else:
        suffix = ""
    raise AssertionError(f"{label}{suffix} not found in entities")
def build_key_to_entity_mapping(
    entities: list[EntityInfo], entity_names: list[str]
) -> dict[int, str]:
    """Build a mapping from entity keys to entity names.

    An entity matches a name when its whole object_id equals the name,
    ignoring case (this is not a substring match). The value stored in the
    result is the name exactly as the caller spelled it, so it can be used
    directly as a key into caller-side dictionaries.

    Args:
        entities: List of entity info objects from the API
        entity_names: List of entity names to match against object_ids

    Returns:
        Dictionary mapping entity keys to entity names. Entities whose
        object_id matches none of the names are omitted.
    """
    # Index the requested names by their lowercase form so each entity is
    # resolved with a single lookup instead of a scan over all names.
    # object_id is lowercased before comparison, so the names must be too;
    # otherwise a name containing an uppercase letter could never match.
    names_by_lower: dict[str, str] = {}
    for name in entity_names:
        lowered = name.lower()
        # An already-lowercase spelling always wins, so it keeps priority
        # if the list also contains a mixed-case variant. Otherwise the
        # first spelling seen is kept.
        if name == lowered or lowered not in names_by_lower:
            names_by_lower[lowered] = name

    key_to_entity: dict[int, str] = {}
    for entity in entities:
        matched_name = names_by_lower.get(entity.object_id.lower())
        if matched_name is not None:
            key_to_entity[entity.key] = matched_name
    return key_to_entity
class InitialStateHelper:
    """Helper to wait for initial states before processing test states.

    When an API client connects, ESPHome sends the current state of all entities.
    This helper wraps the user's state callback and swallows the first state for
    each entity, then forwards all subsequent states to the user callback.

    Usage:
        entities, services = await client.list_entities_services()
        helper = InitialStateHelper(entities)
        client.subscribe_states(helper.on_state_wrapper(user_callback))
        await helper.wait_for_initial_states()
        # Access initial states via helper.initial_states[key]
    """

    def __init__(self, entities: list[EntityInfo]) -> None:
        """Initialize the helper.

        Must be called while an event loop is running, because the
        completion future is created on the running loop.

        Args:
            entities: All entities from list_entities_services()
        """
        # Set of (device_id, key) tuples waiting for initial state
        # Buttons are stateless, so exclude them
        self._wait_initial_states = {
            (entity.device_id, entity.key)
            for entity in entities
            if not isinstance(entity, ButtonInfo)
        }
        # Keep entity info for debugging - use (device_id, key) tuple
        self._entities_by_id = {
            (entity.device_id, entity.key): entity for entity in entities
        }
        # Store initial states by key for test access
        self.initial_states: dict[int, EntityState] = {}

        # Log all entities
        _LOGGER.debug(
            "InitialStateHelper: Found %d total entities: %s",
            len(entities),
            [(type(e).__name__, e.object_id) for e in entities],
        )
        # Log which ones we're waiting for
        _LOGGER.debug(
            "InitialStateHelper: Waiting for %d entities (excluding ButtonInfo): %s",
            len(self._wait_initial_states),
            [self._entities_by_id[k].object_id for k in self._wait_initial_states],
        )
        # Log which ones we're NOT waiting for
        not_waiting = {
            (e.device_id, e.key) for e in entities
        } - self._wait_initial_states
        if not_waiting:
            not_waiting_info = [
                f"{type(self._entities_by_id[k]).__name__}:{self._entities_by_id[k].object_id}"
                for k in not_waiting
            ]
            _LOGGER.debug(
                "InitialStateHelper: NOT waiting for %d entities: %s",
                len(not_waiting),
                not_waiting_info,
            )

        # Create future in the running event loop
        self._initial_states_received = asyncio.get_running_loop().create_future()
        # If no entities to wait for, mark complete immediately
        if not self._wait_initial_states:
            self._initial_states_received.set_result(True)

    def on_state_wrapper(
        self, user_callback: Callable[[EntityState], None]
    ) -> Callable[[EntityState], None]:
        """Wrap a user callback to track initial states.

        Args:
            user_callback: The user's state callback function

        Returns:
            Wrapped callback that swallows first state per entity, forwards rest
        """

        def wrapper(state: EntityState) -> None:
            """Swallow initial state per entity, forward subsequent states."""
            # Create entity identifier tuple
            entity_id = (state.device_id, state.key)
            entity = self._entities_by_id.get(entity_id)

            # Log which entity is sending state
            if entity is not None:
                _LOGGER.debug(
                    "Received state for %s (type: %s, device_id: %s, key: %d)",
                    entity.object_id,
                    type(entity).__name__,
                    state.device_id,
                    state.key,
                )

            # If this entity is waiting for initial state
            if entity_id in self._wait_initial_states:
                # Store the initial state for test access
                self.initial_states[state.key] = state
                # Remove from waiting set
                self._wait_initial_states.discard(entity_id)
                _LOGGER.debug(
                    "Swallowed initial state for %s, %d entities remaining",
                    entity.object_id if entity is not None else entity_id,
                    len(self._wait_initial_states),
                )
                # Check if we've now seen all entities. The done() guard
                # keeps set_result() from raising if the future was already
                # completed or cancelled.
                if (
                    not self._wait_initial_states
                    and not self._initial_states_received.done()
                ):
                    _LOGGER.debug("All initial states received")
                    self._initial_states_received.set_result(True)
                # Don't forward initial state to user
                return

            # Forward subsequent states (and states of entities we were
            # never waiting for) to user callback
            _LOGGER.debug("Forwarding state to user callback")
            user_callback(state)

        return wrapper

    async def wait_for_initial_states(self, timeout: float = 5.0) -> None:
        """Wait for all initial states to be received.

        The wait can be retried after a timeout: the internal future is
        shielded, so a timed-out call does not cancel it.

        Args:
            timeout: Maximum time to wait in seconds

        Raises:
            asyncio.TimeoutError: If initial states aren't received within
                timeout. The message lists the object_ids still missing.
        """
        try:
            # wait_for() cancels whatever it awaits when the timeout fires.
            # shield() makes it cancel a throwaway wrapper instead of the
            # shared future, which the state callback may still complete.
            await asyncio.wait_for(
                asyncio.shield(self._initial_states_received), timeout=timeout
            )
        except asyncio.TimeoutError:
            missing = sorted(
                str(self._entities_by_id[entity_id].object_id)
                for entity_id in self._wait_initial_states
            )
            raise asyncio.TimeoutError(
                f"Timed out after {timeout}s waiting for initial states; "
                f"{len(missing)} entities never reported: {missing}"
            ) from None
class SensorStateCollector:
    """Collects sensor, binary sensor, and text sensor state updates with wait helpers.

    Usage:
        collector = SensorStateCollector(
            sensor_names=["moving_distance", "still_distance"],
            binary_sensor_names=["has_target"],
            text_sensor_names=["direction"],
        )
        # Use collector.on_state as the callback (or wrap it)
        client.subscribe_states(helper.on_state_wrapper(collector.on_state))
        # Wait for all sensors to have at least one value
        await collector.wait_for_all(timeout=3.0)
        # Access collected states
        assert collector.sensor_states["moving_distance"][0] == approx(100.0)
        assert collector.text_sensor_states["direction"][0] == "Approaching"
    """

    def __init__(
        self,
        sensor_names: list[str],
        binary_sensor_names: list[str] | None = None,
        text_sensor_names: list[str] | None = None,
        entities: list[EntityInfo] | None = None,
    ) -> None:
        """Create empty value lists for every tracked name.

        Args:
            sensor_names: Names of sensors to collect values for.
            binary_sensor_names: Names of binary sensors to collect, if any.
            text_sensor_names: Names of text sensors to collect, if any.
            entities: When given, the key-to-name mapping is built right
                away; otherwise call build_key_mapping() before states
                arrive, since states with unmapped keys are ignored.
        """
        self.sensor_states: dict[str, list[float]] = {name: [] for name in sensor_names}
        self.binary_states: dict[str, list[bool]] = {
            name: [] for name in (binary_sensor_names or [])
        }
        self.text_sensor_states: dict[str, list[str]] = {
            name: [] for name in (text_sensor_names or [])
        }
        self._key_to_sensor: dict[int, str] = {}
        # Pending (condition, future) pairs; entries are removed once their
        # future is resolved or cancelled.
        self._waiters: list[tuple[Callable[[], bool], asyncio.Future[bool]]] = []
        if entities is not None:
            self.build_key_mapping(entities)

    def build_key_mapping(self, entities: list[EntityInfo]) -> None:
        """Build key-to-name mapping from entities. Sorted by descending length."""
        all_names = (
            list(self.sensor_states.keys())
            + list(self.binary_states.keys())
            + list(self.text_sensor_states.keys())
        )
        all_names.sort(key=len, reverse=True)
        self._key_to_sensor = build_key_to_entity_mapping(entities, all_names)

    def on_state(self, state: EntityState) -> None:
        """Process a state update.

        Appends the state's value to the list of the matching tracked name
        and re-checks pending waiters. States of other types, states whose
        key is unmapped, and sensor/text sensor states flagged as missing
        are ignored.
        """
        if isinstance(state, SensorState) and not state.missing_state:
            self._record(self.sensor_states, state)
        elif isinstance(state, BinarySensorState):
            # NOTE(review): unlike the sensor and text sensor branches, this
            # one does not check missing_state - confirm that is intentional.
            self._record(self.binary_states, state)
        elif isinstance(state, TextSensorState) and not state.missing_state:
            self._record(self.text_sensor_states, state)

    def _record(self, store: dict[str, list], state: EntityState) -> None:
        """Append state.state to store under its mapped name, then check waiters."""
        sensor_name = self._key_to_sensor.get(state.key)
        if sensor_name and sensor_name in store:
            store[sensor_name].append(state.state)
            self._check_waiters()

    def _check_waiters(self) -> None:
        """Resolve pending waiters whose condition is met and drop finished ones."""
        # Iterate over a snapshot so the list can be rebuilt afterwards.
        for condition, future in list(self._waiters):
            if not future.done() and condition():
                future.set_result(True)
        # Forget waiters that are resolved or were cancelled (for example by
        # a caller's timeout) so the list does not grow without bound.
        self._waiters = [
            waiter for waiter in self._waiters if not waiter[1].done()
        ]

    def _all_have_values(self) -> bool:
        """Check if all sensor, binary sensor, and text sensor lists have at least one value."""
        return (
            all(len(v) >= 1 for v in self.sensor_states.values())
            and all(len(v) >= 1 for v in self.binary_states.values())
            and all(len(v) >= 1 for v in self.text_sensor_states.values())
        )

    async def wait_for_all(self, timeout: float = 3.0) -> None:
        """Wait until all sensors, binary sensors, and text sensors have a value.

        Returns immediately if every tracked name already has at least one
        collected value.

        Args:
            timeout: Maximum time to wait in seconds

        Raises:
            asyncio.TimeoutError: If some tracked name still has no value
                when the timeout expires.
        """
        if self._all_have_values():
            return
        future: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
        self._waiters.append((self._all_have_values, future))
        try:
            await asyncio.wait_for(future, timeout=timeout)
        finally:
            # On timeout or cancellation wait_for() leaves the future
            # cancelled; unregister it so no stale entry remains.
            self._waiters = [
                waiter for waiter in self._waiters if waiter[1] is not future
            ]

    def add_waiter(self, condition: Callable[[], bool]) -> asyncio.Future[bool]:
        """Add a custom waiter that resolves when condition returns True.

        The condition is evaluated once immediately, and afterwards each
        time a tracked state value is recorded.

        Args:
            condition: Zero-argument callable returning True once satisfied.

        Returns:
            A future that resolves when the condition is met.
        """
        future: asyncio.Future[bool] = asyncio.get_running_loop().create_future()
        if condition():
            future.set_result(True)
        else:
            self._waiters.append((condition, future))
        return future