diff --git a/esphome/platformio/runner.py b/esphome/platformio/runner.py index 976979dc57b..caab47dcc24 100644 --- a/esphome/platformio/runner.py +++ b/esphome/platformio/runner.py @@ -51,9 +51,13 @@ def patch_file_downloader() -> None: """Retry PlatformIO package downloads with exponential backoff. PlatformIO's ``FileDownloader`` uses an ``HTTPSession`` without built-in - retry for 502/503 errors. We wrap ``__init__`` to retry on - ``PackageException`` and close the session between attempts so a new - TCP connection can route to a different CDN edge node. + retry. We wrap ``__init__`` to retry on transient failures and close the + session between attempts so a new TCP connection can route to a different + CDN edge node. We catch both ``PackageException`` (raised when the server + returns a non-200 status such as 502/503) and ``OSError`` -- which covers + ``requests.exceptions.ConnectionError``, ``ReadTimeout``, and + ``ChunkedEncodingError`` (all subclasses of ``OSError``) that get raised + when the connection is aborted before a response is parsed. """ from platformio.package.download import FileDownloader from platformio.package.exception import PackageException @@ -70,7 +74,7 @@ def patch_file_downloader() -> None: try: original_init(self, *args, **kwargs) return - except PackageException as e: + except (PackageException, OSError) as e: if attempt < max_retries - 1: delay = 2 ** (attempt + 1) _LOGGER.warning( diff --git a/tests/unit_tests/test_platformio_toolchain.py b/tests/unit_tests/test_platformio_toolchain.py index f771437dd47..c1d16530cbf 100644 --- a/tests/unit_tests/test_platformio_toolchain.py +++ b/tests/unit_tests/test_platformio_toolchain.py @@ -2,10 +2,13 @@ # pylint: disable=protected-access +from contextlib import contextmanager +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer import json import os from pathlib import Path import shutil +import threading from types import SimpleNamespace from unittest.mock import MagicMock, Mock, call, patch @@ -867,6 +870,56 @@ def test_patch_file_downloader_closes_session_and_response_between_retries() -> mock_session.close.assert_called_once() +def test_patch_file_downloader_retries_on_connection_error() -> None: + """Test patch_file_downloader retries on transport-layer errors (OSError subclasses). + + ``requests.exceptions.ConnectionError`` and ``ReadTimeout`` subclass + ``OSError`` and are raised when the connection is aborted before any HTTP + response is parsed -- e.g. ``RemoteDisconnected`` mid-download. These must + retry too, not just ``PackageException``. + """ + mock_exception_cls = type("PackageException", (Exception,), {}) + call_count = 0 + + def failing_init(self, *args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ConnectionError( + f"Connection aborted attempt {call_count}: RemoteDisconnected" + ) + + with ( + patch.dict( + "sys.modules", + { + "platformio": MagicMock(), + "platformio.package": MagicMock(), + "platformio.package.download": SimpleNamespace( + FileDownloader=type( + "FileDownloader", (), {"__init__": failing_init} + ) + ), + "platformio.package.exception": SimpleNamespace( + PackageException=mock_exception_cls + ), + }, + ), + patch("time.sleep") as mock_sleep, + ): + runner.patch_file_downloader() + + from platformio.package.download import FileDownloader + + instance = object.__new__(FileDownloader) + FileDownloader.__init__(instance, "http://example.com/file.zip") + + assert call_count == 3 + assert mock_sleep.call_count == 2 + mock_sleep.assert_any_call(2) + mock_sleep.assert_any_call(4) + + def test_patch_file_downloader_idempotent() -> None: """Test patch_file_downloader does not stack wrappers when called multiple times.""" mock_exception_cls = type("PackageException", (Exception,), {}) @@ -903,6 +956,74 @@ def test_patch_file_downloader_idempotent() -> None: assert call_count == 1 +@contextmanager +def _flaky_http_server(fail_first_n: int, fail_mode: str): + """Local HTTP server that fails the first ``fail_first_n`` requests. + + ``fail_mode="drop"`` closes the TCP connection without responding, so + the client raises ``RemoteDisconnected`` -- the exact CI failure mode. + ``fail_mode="502"`` returns an HTTP 502, triggering ``PackageException``. + """ + state = {"hits": 0} + + class _Handler(BaseHTTPRequestHandler): + def handle_one_request(self) -> None: + state["hits"] += 1 + if state["hits"] <= fail_first_n and fail_mode == "drop": + return # Skip read+respond → kernel sends FIN → RemoteDisconnected + super().handle_one_request() + + def do_GET(self) -> None: # noqa: N802 + if state["hits"] <= fail_first_n and fail_mode == "502": + self.send_error(502) + return + body = b"esphome-test-payload" + self.send_response(200) + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def log_message(self, format: str, *args: object) -> None: # noqa: A002 + pass # silence default stderr logging + + server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + try: + yield server.server_address[1], state + finally: + server.shutdown() + server.server_close() + thread.join(timeout=2) + + +@pytest.mark.parametrize("fail_mode", ["drop", "502"]) +def test_patch_file_downloader_recovers_against_real_server( + tmp_path: Path, fail_mode: str +) -> None: + """End-to-end: real PlatformIO ``FileDownloader`` against a local server + that fails twice then succeeds. Exercises the real + requests/urllib3/http.client stack for both failure modes: + + - ``drop``: TCP close mid-request → ``RemoteDisconnected`` → caught as + ``OSError`` by the retry patch (the CI failure path). + - ``502``: HTTP error response → ``PackageException`` (the original path). + """ + runner.patch_file_downloader() + from platformio.package.download import FileDownloader + + with ( + _flaky_http_server(fail_first_n=2, fail_mode=fail_mode) as (port, state), + patch("time.sleep"), + ): + fd = FileDownloader(f"http://127.0.0.1:{port}/payload.bin") + fd.set_destination(str(tmp_path / "out.bin")) + fd.start(with_progress=False, silent=True) + + assert state["hits"] == 3 # 2 failures + 1 success + assert (tmp_path / "out.bin").read_bytes() == b"esphome-test-payload" + + def _filter_through_redirect(line: str) -> str: """Write a line through RedirectText with FILTER_PLATFORMIO_LINES and return what passes.""" import io