From 8ceada8d04a3b2fa9a428ce3ad5c7e68e7311404 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 28 Apr 2026 21:32:30 -0500 Subject: [PATCH] [core] Download external_files in parallel (#16021) --- esphome/components/audio_file/__init__.py | 18 +- .../speaker/media_player/__init__.py | 24 +-- esphome/external_files.py | 98 ++++++++- tests/unit_tests/test_external_files.py | 187 +++++++++++++++++- 4 files changed, 295 insertions(+), 32 deletions(-) diff --git a/esphome/components/audio_file/__init__.py b/esphome/components/audio_file/__init__.py index bb1ce257db..88be6db168 100644 --- a/esphome/components/audio_file/__init__.py +++ b/esphome/components/audio_file/__init__.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from functools import partial import hashlib import logging from pathlib import Path @@ -19,7 +20,7 @@ from esphome.const import ( ) from esphome.core import CORE, ID, HexInt from esphome.cpp_generator import MockObj -from esphome.external_files import download_content +from esphome.external_files import download_web_files_in_config from esphome.types import ConfigType _LOGGER = logging.getLogger(__name__) @@ -63,15 +64,6 @@ def _compute_local_file_path(value: ConfigType) -> Path: return base_dir / key -def _download_web_file(value: ConfigType) -> ConfigType: - url = value[CONF_URL] - path = _compute_local_file_path(value) - - download_content(url, path) - _LOGGER.debug("download_web_file: path=%s", path) - return value - - def _file_schema(value: ConfigType | str) -> ConfigType: if isinstance(value, str): return _validate_file_shorthand(value) @@ -142,11 +134,10 @@ LOCAL_SCHEMA = cv.Schema( } ) -WEB_SCHEMA = cv.All( +WEB_SCHEMA = cv.Schema( { cv.Required(CONF_URL): cv.url, - }, - _download_web_file, + } ) @@ -209,6 +200,7 @@ def _validate_supported_local_file(config: list[ConfigType]) -> list[ConfigType] CONFIG_SCHEMA = cv.All( cv.only_on_esp32, cv.ensure_list(MEDIA_FILE_TYPE_SCHEMA), + partial(download_web_files_in_config, path_for=_compute_local_file_path), _validate_supported_local_file, ) diff --git a/esphome/components/speaker/media_player/__init__.py b/esphome/components/speaker/media_player/__init__.py index abfd599808..fbc83ef12f 100644 --- a/esphome/components/speaker/media_player/__init__.py +++ b/esphome/components/speaker/media_player/__init__.py @@ -1,5 +1,6 @@ """Speaker Media Player Setup.""" +from functools import partial import hashlib import logging from pathlib import Path @@ -32,7 +33,7 @@ from esphome.const import ( CONF_URL, ) from esphome.core import CORE, HexInt -from esphome.external_files import download_content +from esphome.external_files import download_web_files_in_config _LOGGER = logging.getLogger(__name__) @@ -92,15 +93,6 @@ def _compute_local_file_path(value: dict) -> Path: return base_dir / key -def _download_web_file(value): - url = value[CONF_URL] - path = _compute_local_file_path(value) - - download_content(url, path) - _LOGGER.debug("download_web_file: path=%s", path) - return value - - _PURPOSE_MAP = { "MEDIA": media_player.MEDIA_PLAYER_FORMAT_PURPOSE_ENUM["default"], "ANNOUNCEMENT": media_player.MEDIA_PLAYER_FORMAT_PURPOSE_ENUM["announcement"], @@ -229,11 +221,10 @@ LOCAL_SCHEMA = cv.Schema( } ) -WEB_SCHEMA = cv.All( +WEB_SCHEMA = cv.Schema( { cv.Required(CONF_URL): cv.url, - }, - _download_web_file, + } ) @@ -285,7 +276,12 @@ CONFIG_SCHEMA = cv.All( ), # Remove before 2026.10.0 cv.Optional(CONF_CODEC_SUPPORT_ENABLED): cv.Any(cv.boolean, cv.string), - cv.Optional(CONF_FILES): cv.ensure_list(MEDIA_FILE_TYPE_SCHEMA), + cv.Optional(CONF_FILES): cv.All( + cv.ensure_list(MEDIA_FILE_TYPE_SCHEMA), + partial( + download_web_files_in_config, path_for=_compute_local_file_path + ), + ), cv.Optional(CONF_TASK_STACK_IN_PSRAM): cv.All( cv.boolean, cv.requires_component(psram.DOMAIN) ), diff --git a/esphome/external_files.py b/esphome/external_files.py index bd29dc93b1..fbc261f8e0 100644 --- a/esphome/external_files.py +++ b/esphome/external_files.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import Callable, Iterable +from concurrent.futures import ThreadPoolExecutor import contextlib from datetime import UTC, datetime import logging @@ -9,9 +11,10 @@ from pathlib import Path import requests import esphome.config_validation as cv -from esphome.const import __version__ +from esphome.const import CONF_FILE, CONF_TYPE, CONF_URL, __version__ from esphome.core import CORE, EsphomeError, TimePeriodSeconds from esphome.helpers import write_file +from esphome.types import ConfigType _LOGGER = logging.getLogger(__name__) CODEOWNERS = ["@landonr"] @@ -85,7 +88,9 @@ def _write_etag(local_file_path: Path, etag: str | None) -> None: ) -def has_remote_file_changed(url: str, local_file_path: Path) -> bool: +def has_remote_file_changed( + url: str, local_file_path: Path, timeout: int = NETWORK_TIMEOUT +) -> bool: if local_file_path.exists(): _LOGGER.debug("has_remote_file_changed: File exists at %s", local_file_path) try: @@ -101,7 +106,7 @@ def has_remote_file_changed(url: str, local_file_path: Path) -> bool: if etag := _read_etag(local_file_path): headers[IF_NONE_MATCH] = etag response = requests.head( - url, headers=headers, timeout=NETWORK_TIMEOUT, allow_redirects=True + url, headers=headers, timeout=timeout, allow_redirects=True ) _LOGGER.debug( @@ -153,7 +158,7 @@ def download_content(url: str, path: Path, timeout: int = NETWORK_TIMEOUT) -> by if CORE.skip_external_update and path.exists(): _LOGGER.debug("Skipping update for %s (refresh disabled)", url) return path.read_bytes() - if not has_remote_file_changed(url, path): + if not has_remote_file_changed(url, path, timeout): _LOGGER.debug("Remote file has not changed %s", url) return path.read_bytes() @@ -184,3 +189,88 @@ def download_content(url: str, path: Path, timeout: int = NETWORK_TIMEOUT) -> by write_file(path, data) _write_etag(path, req.headers.get(ETAG)) return data + + +# Cap concurrent connections so a config with hundreds of remote files doesn't +# open hundreds of sockets at once. 8 matches the requests connection-pool +# default and the per-host connection limit browsers use, which keeps us +# polite to the upstream host while still cutting wall time roughly 8x for +# typical configs (a couple dozen files). +DEFAULT_DOWNLOAD_WORKERS = 8 + + +def download_content_many( + items: Iterable[tuple[str, Path]], + timeout: int = NETWORK_TIMEOUT, + max_workers: int = DEFAULT_DOWNLOAD_WORKERS, +) -> None: + """Run `download_content` for each (url, path) pair concurrently. + + Wall time drops from `sum(latency)` to roughly `max(latency)` for cached + files where the HEAD round-trip dominates. All workers run to + completion before this returns; every `cv.Invalid` raised by a worker + is collected and surfaced together as `cv.MultipleInvalid` so the user + sees every broken file in a single validation pass instead of fixing + them one round-trip at a time. + + Items are de-duplicated by `path` -- two callers asking for the same + cache file (e.g. the same URL referenced twice in a config) would + otherwise race on `download_content`'s non-atomic write. When the + same `path` appears more than once, the last URL wins (standard dict + comprehension semantics); in practice duplicate paths only arise when + the URL is duplicated, so the choice doesn't matter. + """ + seen: dict[Path, str] = {path: url for url, path in items} + if not seen: + return + if len(seen) == 1: + path, url = next(iter(seen.items())) + download_content(url, path, timeout) + return + + def _download_one(path_url: tuple[Path, str]) -> None: + # `seen` stores entries as (path, url) so the dict can dedupe by + # path; flip them back to download_content's (url, path) order. + path, url = path_url + download_content(url, path, timeout) + + workers = max(1, min(max_workers, len(seen))) + errors: list[cv.Invalid] = [] + with ThreadPoolExecutor(max_workers=workers) as ex: + futures = [ex.submit(_download_one, item) for item in seen.items()] + for future in futures: + try: + future.result() + except cv.Invalid as e: + errors.append(e) + if not errors: + return + if len(errors) == 1: + raise errors[0] + raise cv.MultipleInvalid(errors) + + +# Each component that uses external_files defines its own local +# `TYPE_WEB = "web"`; the string is repeated here rather than imported +# because there is no canonical `TYPE_WEB` in `esphome.const` to share. +WEB_TYPE = "web" + + +def download_web_files_in_config( + config: list[ConfigType], + path_for: Callable[[ConfigType], Path], +) -> list[ConfigType]: + """Voluptuous-friendly validator that downloads any web-sourced files in + `config` in parallel. + + Each entry is expected to contain a `file` key whose value is a dict + that may be `{type: "web", url: ...}`; `path_for(file_dict)` returns + the cache path for that file. Returns `config` unchanged so it can be + slotted directly into a `cv.All(...)` chain. + """ + download_content_many( + (conf_file[CONF_URL], path_for(conf_file)) + for entry in config + if (conf_file := entry.get(CONF_FILE, {})).get(CONF_TYPE) == WEB_TYPE + ) + return config diff --git a/tests/unit_tests/test_external_files.py b/tests/unit_tests/test_external_files.py index f4d268abe0..c894f90666 100644 --- a/tests/unit_tests/test_external_files.py +++ b/tests/unit_tests/test_external_files.py @@ -9,7 +9,7 @@ import pytest import requests from esphome import external_files -from esphome.config_validation import Invalid +from esphome.config_validation import Invalid, MultipleInvalid from esphome.core import CORE, EsphomeError, TimePeriod @@ -60,6 +60,24 @@ def mock_write_file() -> MagicMock: yield m +@pytest.fixture +def mock_download_content() -> MagicMock: + """Patch `external_files.download_content` for tests that exercise the + parallel batch helper without doing real I/O. + """ + with patch("esphome.external_files.download_content") as m: + yield m + + +@pytest.fixture +def mock_download_content_many() -> MagicMock: + """Patch `external_files.download_content_many` for tests that exercise + the URL-collection helper without dispatching to the thread pool. + """ + with patch("esphome.external_files.download_content_many") as m: + yield m + + def test_compute_local_file_dir(setup_core: Path) -> None: """Test compute_local_file_dir creates and returns correct path.""" domain = "font" @@ -494,6 +512,173 @@ def test_download_content_skip_external_update_downloads_when_missing( assert test_file.read_bytes() == new_content +def test_download_content_many_empty_is_noop( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """Empty input shouldn't spin up a thread pool or call download_content.""" + external_files.download_content_many([]) + mock_download_content.assert_not_called() + + +def test_download_content_many_single_item_avoids_pool( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """A single item should be downloaded inline (no thread pool overhead).""" + item = ("https://example.com/file.txt", setup_core / "f.txt") + external_files.download_content_many([item]) + mock_download_content.assert_called_once_with( + item[0], item[1], external_files.NETWORK_TIMEOUT + ) + + +def test_download_content_many_runs_in_parallel( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """Multiple items should run concurrently — total wall time ≈ max latency.""" + import threading + + barrier = threading.Barrier(3) + + def slow_download(url: str, path: Path, timeout: int) -> bytes: + # If calls were serial this would deadlock (third caller never arrives + # while the first is blocked at the barrier). + barrier.wait(timeout=2.0) + return b"" + + mock_download_content.side_effect = slow_download + items = [ + ("https://example.com/a", setup_core / "a"), + ("https://example.com/b", setup_core / "b"), + ("https://example.com/c", setup_core / "c"), + ] + external_files.download_content_many(items, max_workers=4) + assert mock_download_content.call_count == 3 + + +def test_download_content_many_propagates_single_error( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """A single failing worker should raise its `Invalid` directly, not wrap + it in a `MultipleInvalid` that the caller would have to unpack. + """ + + def fake_download(url: str, path: Path, timeout: int) -> bytes: + if url.endswith("bad"): + raise Invalid(f"could not download {url}") + return b"" + + mock_download_content.side_effect = fake_download + items = [ + ("https://example.com/ok", setup_core / "ok"), + ("https://example.com/bad", setup_core / "bad"), + ] + with pytest.raises(Invalid, match="could not download") as exc_info: + external_files.download_content_many(items) + assert not isinstance(exc_info.value, MultipleInvalid) + + +def test_download_content_many_aggregates_multiple_errors( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """Every failing worker should be reported in a single MultipleInvalid so + the user sees all broken URLs in one validation pass instead of fixing + them one network round-trip at a time. + """ + + def fake_download(url: str, path: Path, timeout: int) -> bytes: + if url.endswith("ok"): + return b"" + raise Invalid(f"could not download {url}") + + mock_download_content.side_effect = fake_download + items = [ + ("https://example.com/ok", setup_core / "ok"), + ("https://example.com/bad1", setup_core / "bad1"), + ("https://example.com/bad2", setup_core / "bad2"), + ] + with pytest.raises(MultipleInvalid) as exc_info: + external_files.download_content_many(items) + messages = {str(e) for e in exc_info.value.errors} + assert messages == { + "could not download https://example.com/bad1", + "could not download https://example.com/bad2", + } + + +def test_download_content_many_dedupes_by_path( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """Two items pointing at the same cache path must collapse to one + download -- otherwise concurrent writes race on the same file. Which + URL wins doesn't matter (in practice duplicate paths only arise when + the URL is duplicated), so we only assert the call count and path. + """ + path = setup_core / "shared" + items = [ + ("https://example.com/a", path), + ("https://example.com/b", path), + ("https://example.com/a", path), + ] + external_files.download_content_many(items) + assert mock_download_content.call_count == 1 + args, _ = mock_download_content.call_args + assert args[1] == path + + +def test_download_content_many_clamps_invalid_max_workers( + mock_download_content: MagicMock, setup_core: Path +) -> None: + """`max_workers <= 0` must not raise from ThreadPoolExecutor; it should + be clamped up to at least 1 worker. + """ + items = [ + ("https://example.com/a", setup_core / "a"), + ("https://example.com/b", setup_core / "b"), + ] + external_files.download_content_many(items, max_workers=0) + assert mock_download_content.call_count == 2 + + +def test_download_web_files_in_config_filters_and_dispatches( + mock_download_content_many: MagicMock, setup_core: Path +) -> None: + """Only `file.type == "web"` entries should be forwarded to + download_content_many, and the unmodified config should be returned so + the helper can sit in a `cv.All(...)` chain. + """ + + def path_for(file_dict: dict) -> Path: + return setup_core / file_dict["url"].rsplit("/", 1)[-1] + + config = [ + {"file": {"type": "web", "url": "https://example.com/a"}}, + {"file": {"type": "local", "path": "/tmp/b"}}, + {"file": {"type": "web", "url": "https://example.com/c"}}, + {}, # no `file` key at all + ] + result = external_files.download_web_files_in_config(config, path_for) + + assert result is config + mock_download_content_many.assert_called_once() + assert list(mock_download_content_many.call_args[0][0]) == [ + ("https://example.com/a", setup_core / "a"), + ("https://example.com/c", setup_core / "c"), + ] + + +def test_download_web_files_in_config_no_web_entries( + mock_download_content_many: MagicMock, setup_core: Path +) -> None: + """A config with no web entries should still call through to + download_content_many (which is itself a no-op for empty input) so the + behavior stays consistent. + """ + config = [{"file": {"type": "local", "path": "/tmp/a"}}] + external_files.download_web_files_in_config(config, lambda _: setup_core / "x") + mock_download_content_many.assert_called_once() + assert list(mock_download_content_many.call_args[0][0]) == [] + + def test_download_content_saves_etag( mock_has_remote_file_changed: MagicMock, mock_requests_get: MagicMock,