diff --git a/esphome/platformio_api.py b/esphome/platformio_api.py index d42f89d0290..5d4065207f0 100644 --- a/esphome/platformio_api.py +++ b/esphome/platformio_api.py @@ -5,6 +5,7 @@ import os from pathlib import Path import re import subprocess +import time from typing import Any from esphome.const import CONF_COMPILE_PROCESS_LIMIT, CONF_ESPHOME, KEY_CORE @@ -44,31 +45,61 @@ def patch_structhash(): def patch_file_downloader(): - """Patch PlatformIO's FileDownloader to retry on PackageException errors.""" + """Patch PlatformIO's FileDownloader to retry on PackageException errors. + + PlatformIO's FileDownloader uses HTTPSession which lacks built-in retry + for 502/503 errors. We add retries with exponential backoff and close the + session between attempts to force a fresh TCP connection, which may route + to a different CDN edge node. + """ from platformio.package.download import FileDownloader from platformio.package.exception import PackageException + if getattr(FileDownloader.__init__, "_esphome_patched", False): + return + original_init = FileDownloader.__init__ def patched_init(self, *args: Any, **kwargs: Any) -> None: - max_retries = 3 + max_retries = 5 for attempt in range(max_retries): try: - return original_init(self, *args, **kwargs) + original_init(self, *args, **kwargs) + return except PackageException as e: if attempt < max_retries - 1: + # Exponential backoff: 2, 4, 8, 16 seconds + delay = 2 ** (attempt + 1) _LOGGER.warning( - "Package download failed: %s. Retrying... (attempt %d/%d)", + "Package download failed: %s. " + "Retrying in %d seconds... (attempt %d/%d)", str(e), + delay, attempt + 1, max_retries, ) + # Close the response and session to free resources + # and force a new TCP connection on retry, which may + # route to a different CDN edge node + # pylint: disable=protected-access,broad-except + try: + if ( + hasattr(self, "_http_response") + and self._http_response is not None + ): + self._http_response.close() + if hasattr(self, "_http_session"): + self._http_session.close() + except Exception: + pass + # pylint: enable=protected-access,broad-except + time.sleep(delay) else: # Final attempt - re-raise raise - return None + patched_init._esphome_patched = True # type: ignore[attr-defined] # pylint: disable=protected-access FileDownloader.__init__ = patched_init diff --git a/tests/unit_tests/test_platformio_api.py b/tests/unit_tests/test_platformio_api.py index 4d7b635e595..16861442777 100644 --- a/tests/unit_tests/test_platformio_api.py +++ b/tests/unit_tests/test_platformio_api.py @@ -6,7 +6,7 @@ import os from pathlib import Path import shutil from types import SimpleNamespace -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, call, patch import pytest @@ -673,6 +673,200 @@ def test_process_stacktrace_bad_alloc( assert state is False +def test_patch_file_downloader_succeeds_first_try() -> None: + """Test patch_file_downloader succeeds on first attempt.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + original_init = MagicMock() + + with patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type("FileDownloader", (), {"__init__": original_init}) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + original_init.assert_called_once() + + +def test_patch_file_downloader_retries_on_failure() -> None: + """Test patch_file_downloader retries with backoff on PackageException.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + call_count = 0 + + def failing_init(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise mock_exception_cls(f"502 error attempt {call_count}") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", (), {"__init__": failing_init} + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep") as mock_sleep, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should have been called 3 times (2 failures + 1 success) + assert call_count == 3 + + # Should have slept with exponential backoff: 2s, 4s + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(2) + mock_sleep.assert_any_call(4) + + +def test_patch_file_downloader_raises_after_max_retries() -> None: + """Test patch_file_downloader raises after exhausting all retries.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + + def always_failing_init(self, *args, **kwargs): + raise mock_exception_cls("502 error") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", (), {"__init__": always_failing_init} + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep") as mock_sleep, + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + with pytest.raises(mock_exception_cls, match="502 error"): + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should have slept 4 times (before attempts 2-5), not on final attempt + assert mock_sleep.call_count == 4 + mock_sleep.assert_has_calls([call(2), call(4), call(8), call(16)]) + + +def test_patch_file_downloader_closes_session_and_response_between_retries() -> None: + """Test patch_file_downloader closes HTTP session and response between retries.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + mock_session = MagicMock() + mock_response = MagicMock() + call_count = 0 + + def failing_init_with_session(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + self._http_session = mock_session + self._http_response = mock_response + if call_count < 2: + raise mock_exception_cls("502 error") + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", + (), + {"__init__": failing_init_with_session}, + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep"), + ): + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Both response and session should have been closed between retries + mock_response.close.assert_called_once() + mock_session.close.assert_called_once() + + +def test_patch_file_downloader_idempotent() -> None: + """Test patch_file_downloader does not stack wrappers when called multiple times.""" + mock_exception_cls = type("PackageException", (Exception,), {}) + call_count = 0 + + def counting_init(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + + with patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type("FileDownloader", (), {"__init__": counting_init}) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ): + # Patch multiple times + platformio_api.patch_file_downloader() + platformio_api.patch_file_downloader() + platformio_api.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + # Should only be called once, not 3 times from stacked wrappers + assert call_count == 1 + + def test_platformio_log_filter_allows_non_platformio_messages() -> None: """Test that non-platformio logger messages are allowed through.""" log_filter = platformio_api.PlatformioLogFilter()