From f30ad588ea421db25f0784b91365bb44f5991751 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 5 May 2026 18:25:53 -0500 Subject: [PATCH] [cli] Add --ota-platform flag to pick web_server or native API OTA (#16207) --- esphome/__main__.py | 166 +++++- esphome/web_server_ota.py | 202 +++++++ tests/unit_tests/test_main.py | 283 ++++++++++ tests/unit_tests/test_web_server_ota.py | 670 ++++++++++++++++++++++++ 4 files changed, 1307 insertions(+), 14 deletions(-) create mode 100644 esphome/web_server_ota.py create mode 100644 tests/unit_tests/test_web_server_ota.py diff --git a/esphome/__main__.py b/esphome/__main__.py index 222e299f6d..f4a276b74c 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -28,6 +28,7 @@ from esphome.const import ( ALLOWED_NAME_CHARS, ARGUMENT_HELP_DEVICE, CONF_API, + CONF_AUTH, CONF_BAUD_RATE, CONF_BROKER, CONF_DEASSERT_RTS_DTR, @@ -47,6 +48,8 @@ from esphome.const import ( CONF_PORT, CONF_SUBSTITUTIONS, CONF_TOPIC, + CONF_USERNAME, + CONF_WEB_SERVER, ENV_NOGITIGNORE, KEY_CORE, KEY_NATIVE_IDF, @@ -349,6 +352,17 @@ def choose_upload_log_host( elif bootsel.permission_error: bootsel_permission_error = True + # Annotate the OTA chooser entry only in the non-default case: when the + # config has web_server OTA but no native API OTA, the upload will fall + # through to the HTTP path and the user benefits from seeing that + # explicitly. The native-API path is the default and gets a plain label + # to avoid noise on the most common scenario. For LOGGING the OTA + # transport doesn't apply, so always leave the label plain. + if purpose == Purpose.UPLOADING and not has_native_ota() and has_web_server_ota(): + ota_suffix = " via web_server" + else: + ota_suffix = "" + def add_ota_options() -> None: """Add OTA options, using mDNS discovery if name_add_mac_suffix is enabled.""" if (discovered := _discover_mac_suffix_devices()) is not None: @@ -356,11 +370,11 @@ def choose_upload_log_host( # intentionally skip the base-name fallback since with # name_add_mac_suffix on, the base name doesn't exist on the net. for host in discovered: - options.append((f"Over The Air ({host})", host)) + options.append((f"Over The Air{ota_suffix} ({host})", host)) elif has_resolvable_address(): - options.append((f"Over The Air ({CORE.address})", CORE.address)) + options.append((f"Over The Air{ota_suffix} ({CORE.address})", CORE.address)) if has_mqtt_ip_lookup(): - options.append(("Over The Air (MQTT IP lookup)", "MQTTIP")) + options.append((f"Over The Air{ota_suffix} (MQTT IP lookup)", "MQTTIP")) if purpose == Purpose.LOGGING: if has_mqtt_logging(): @@ -429,7 +443,19 @@ def has_api() -> bool: def has_ota() -> bool: - """Check if OTA upload is available (requires platform: esphome).""" + """Check if any network OTA upload is available. + + True if the config exposes either ``platform: esphome`` (native API + OTA) or ``platform: web_server`` (HTTP OTA). Both reach the device + over the same network stack, so the OTA discovery path treats them + interchangeably; ``upload_program`` picks the actual transport based + on ``--ota-platform`` and what's configured. + """ + return has_native_ota() or has_web_server_ota() + + +def has_native_ota() -> bool: + """Check if native API OTA upload is available (``platform: esphome``).""" if CONF_OTA not in CORE.config: return False return any( @@ -438,6 +464,16 @@ def has_ota() -> bool: ) +def has_web_server_ota() -> bool: + """Check if web_server OTA upload is available (``platform: web_server``).""" + if CONF_OTA not in CORE.config: + return False + return any( + ota_item.get(CONF_PLATFORM) == CONF_WEB_SERVER + for ota_item in CORE.config[CONF_OTA] + ) + + def has_mqtt_ip_lookup() -> bool: """Check if MQTT is available and IP lookup is supported.""" from esphome.components.mqtt import CONF_DISCOVER_IP @@ -1115,25 +1151,83 @@ def upload_program( return exit_code, host if exit_code == 0 else None - ota_conf = {} + requested_platform = getattr(args, "ota_platform", None) + chosen_platform = _choose_ota_platform(config, requested_platform) + + # Resolve MQTT magic strings to actual IP addresses + network_devices = _resolve_network_devices(devices, config, args) + + if chosen_platform == CONF_WEB_SERVER: + if getattr(args, "partition_table", False): + raise EsphomeError( + "--partition-table is only supported with the esphome OTA platform; " + "the web_server OTA path can only update the firmware image." + ) + binary = CORE.firmware_bin + if getattr(args, "file", None) is not None: + binary = Path(args.file) + return _upload_via_web_server(config, network_devices, binary) + + return _upload_via_native_api(config, network_devices, args) + + +def _choose_ota_platform(config: ConfigType, requested: str | None) -> str: + """Pick the OTA platform to use, optionally honoring ``--ota-platform``. + + Default behavior prefers ``esphome`` (native API) when it is configured. + The native API uses challenge-response auth with MD5/SHA256 hashing of a + server-issued nonce, so the password is never sent over the wire; the + ``web_server`` path uses HTTP Basic auth which transmits credentials in + cleartext over the LAN. (The native path also supports gzip compression + on ESP8266, where flash space is tight; on ESP32/RP2040/LibreTiny the + backend reports ``supports_compression() == false`` and the firmware is + sent uncompressed regardless of which platform is used.) Falls back to + ``web_server`` only when that is the only available platform. + """ + # Use a dict (insertion-ordered) instead of a list so error messages and + # membership checks see one entry per platform even if the user has + # multiple ``ota:`` items of the same platform; the web_server OTA + # platform's final-validate hook merges duplicates anyway. + available: dict[str, None] = {} for ota_item in config.get(CONF_OTA, []): - if ota_item[CONF_PLATFORM] == CONF_ESPHOME: + platform = ota_item.get(CONF_PLATFORM) + if platform in (CONF_ESPHOME, CONF_WEB_SERVER): + available[platform] = None + + if not available: + raise EsphomeError( + f"Cannot upload Over the Air as the {CONF_OTA} configuration is not " + f"present or does not include {CONF_PLATFORM}: {CONF_ESPHOME} or " + f"{CONF_PLATFORM}: {CONF_WEB_SERVER}" + ) + + if requested is not None: + if requested not in available: + raise EsphomeError( + f"--ota-platform {requested} was requested but the configuration " + f"only provides: {', '.join(available)}" + ) + return requested + + if CONF_ESPHOME in available: + return CONF_ESPHOME + return CONF_WEB_SERVER + + +def _upload_via_native_api( + config: ConfigType, network_devices: list[str], args: ArgsProtocol +) -> tuple[int, str | None]: + ota_conf: ConfigType = {} + for ota_item in config.get(CONF_OTA, []): + if ota_item.get(CONF_PLATFORM) == CONF_ESPHOME: ota_conf = ota_item break - if not ota_conf: - raise EsphomeError( - f"Cannot upload Over the Air as the {CONF_OTA} configuration is not present or does not include {CONF_PLATFORM}: {CONF_ESPHOME}" - ) - from esphome import espota2 remote_port = int(ota_conf[CONF_PORT]) password = ota_conf.get(CONF_PASSWORD) - # Resolve MQTT magic strings to actual IP addresses - network_devices = _resolve_network_devices(devices, config, args) - binary = CORE.firmware_bin ota_type = espota2.OTA_TYPE_UPDATE_APP if getattr(args, "partition_table", False): @@ -1157,6 +1251,28 @@ def upload_program( return espota2.run_ota(network_devices, remote_port, password, binary, ota_type) +def _upload_via_web_server( + config: ConfigType, network_devices: list[str], binary: Path +) -> tuple[int, str | None]: + web_conf = config.get(CONF_WEB_SERVER) + if not web_conf: + raise EsphomeError( + f"Cannot upload via web_server OTA: the {CONF_WEB_SERVER} component " + f"is not configured." + ) + + remote_port = int(web_conf[CONF_PORT]) + auth = web_conf.get(CONF_AUTH) or {} + username = auth.get(CONF_USERNAME) + password = auth.get(CONF_PASSWORD) + + from esphome import web_server_ota + + return web_server_ota.run_ota( + network_devices, remote_port, username, password, binary + ) + + # Layout of esp_partition_info_t on flash. Each entry is 32 bytes, leading with a # 16-bit little-endian magic. ESP-IDF defines ESP_PARTITION_MAGIC = 0x50AA (stored as # bytes 0xAA, 0x50) for partition entries and ESP_PARTITION_MAGIC_MD5 = 0xEBEB for the @@ -1877,6 +1993,17 @@ def parse_args(argv): "--file", help="Manually specify the binary file to upload.", ) + parser_upload.add_argument( + "--ota-platform", + choices=[CONF_ESPHOME, CONF_WEB_SERVER], + help=( + "OTA platform to use for network uploads. Defaults to " + f"'{CONF_ESPHOME}' (native API) when configured because it uses " + "challenge-response auth so the password is never sent in " + f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' " + "(HTTP Basic auth) when that is the only configured platform." + ), + ) parser_upload.add_argument( "--partition-table", help="Upload as partition table (OTA).", @@ -1951,6 +2078,17 @@ def parse_args(argv): help="Build with native ESP-IDF instead of PlatformIO (ESP32 esp-idf framework only).", action="store_true", ) + parser_run.add_argument( + "--ota-platform", + choices=[CONF_ESPHOME, CONF_WEB_SERVER], + help=( + "OTA platform to use for network uploads. Defaults to " + f"'{CONF_ESPHOME}' (native API) when configured because it uses " + "challenge-response auth so the password is never sent in " + f"cleartext on the wire. Falls back to '{CONF_WEB_SERVER}' " + "(HTTP Basic auth) when that is the only configured platform." + ), + ) parser_clean = subparsers.add_parser( "clean-mqtt", diff --git a/esphome/web_server_ota.py b/esphome/web_server_ota.py new file mode 100644 index 0000000000..a49f46b270 --- /dev/null +++ b/esphome/web_server_ota.py @@ -0,0 +1,202 @@ +"""HTTP-based OTA upload via the ``web_server`` component's ``/update`` endpoint. + +This is the alternative to ``espota2`` (the native API OTA path). Useful when +a device only has ``platform: web_server`` configured under ``ota:``, or when +the user has lost the native OTA password but still has ``web_server`` basic +auth credentials. +""" + +from __future__ import annotations + +import io +import logging +from pathlib import Path +import secrets +import socket +from typing import BinaryIO + +import requests +from requests.auth import HTTPBasicAuth + +from esphome.core import EsphomeError +from esphome.helpers import ProgressBar, resolve_ip_address + +_LOGGER = logging.getLogger(__name__) + +OTA_PATH = "/update" +FORM_FIELD = "update" +# (connect_timeout, read_timeout). The device reboots after a successful +# upload so the read side must allow for a slow flash + response. +TIMEOUT = (20.0, 120.0) + + +class WebServerOTAError(EsphomeError): + pass + + +class _MultipartStreamer: + """Stream a single-file multipart/form-data body during transmission. + + ``requests.post(files=...)`` materializes the entire body in memory before + sending, so a progress callback wired into the file-like fires during + encoding instead of during the network send. Pass this via ``data=`` + (with ``__len__`` so urllib3 sets ``Content-Length`` instead of using + chunked transfer encoding); urllib3 then calls ``read(blocksize)`` + repeatedly during the POST and the progress bar tracks bytes leaving the + host. + """ + + def __init__(self, file: BinaryIO, file_size: int, filename: str) -> None: + self.boundary = f"esphomeOTA{secrets.token_hex(16)}" + prefix = ( + f"--{self.boundary}\r\n" + f'Content-Disposition: form-data; name="{FORM_FIELD}"; ' + f'filename="{filename}"\r\n' + f"Content-Type: application/octet-stream\r\n\r\n" + ).encode() + suffix = f"\r\n--{self.boundary}--\r\n".encode() + # Walked in order; ``read()`` advances to the next source on EOF. + self._sources: list[BinaryIO] = [io.BytesIO(prefix), file, io.BytesIO(suffix)] + self._idx = 0 + self._total = len(prefix) + file_size + len(suffix) + self._sent = 0 + self.progress = ProgressBar() + + def __len__(self) -> int: + return self._total + + @property + def content_type(self) -> str: + return f"multipart/form-data; boundary={self.boundary}" + + def read(self, size: int = -1) -> bytes: + remaining = self._total if size is None or size < 0 else size + out = bytearray() + while remaining > 0 and self._idx < len(self._sources): + chunk = self._sources[self._idx].read(remaining) + if not chunk: + self._idx += 1 + continue + out += chunk + remaining -= len(chunk) + if out: + self._sent += len(out) + self.progress.update(self._sent / self._total) + return bytes(out) + + +def _try_upload( + host: str, + port: int, + username: str | None, + password: str | None, + filename: Path, +) -> tuple[int, str | None]: + from esphome.core import CORE + + try: + addr_infos = resolve_ip_address(host, port, address_cache=CORE.address_cache) + except EsphomeError as err: + _LOGGER.error( + "Error resolving IP address of %s. Is it connected to WiFi?", host + ) + if not CORE.dashboard: + _LOGGER.error("(If you know the IP, try --device )") + raise WebServerOTAError(err) from err + + if not addr_infos: + _LOGGER.error("Could not resolve %s", host) + return 1, None + + file_size = filename.stat().st_size + _LOGGER.info("Uploading %s (%s bytes) via web_server OTA", filename, file_size) + auth = HTTPBasicAuth(username, password) if username and password else None + + # Iterate resolved IPs (IPv4 + IPv6 candidates) just like espota2 does. + for af, _socktype, _, _, sa in addr_infos: + ip = sa[0] + # IPv6 literals must be wrapped in brackets in URLs; link-local + # addresses need a percent-encoded zone index per RFC 6874. + if af == socket.AF_INET6: + scope = sa[3] if len(sa) >= 4 else 0 + host_part = f"[{ip}%25{scope}]" if scope else f"[{ip}]" + else: + host_part = ip + url = f"http://{host_part}:{port}{OTA_PATH}" + _LOGGER.info("Connecting to %s port %s...", ip, port) + + try: + with open(filename, "rb") as fh: + streamer = _MultipartStreamer(fh, file_size, filename.name) + try: + response = requests.post( + url, + data=streamer, + auth=auth, + timeout=TIMEOUT, + headers={ + "Content-Type": streamer.content_type, + "Connection": "close", + }, + ) + finally: + streamer.progress.done() + except requests.RequestException as err: + _LOGGER.error("OTA upload to %s port %s failed: %s", ip, port, err) + continue + + if response.status_code == 401: + raise WebServerOTAError( + "Authentication failed (HTTP 401). Check the 'web_server' " + "'auth' username and password." + ) + if response.status_code != 200: + detail = response.text.strip() or response.reason or "no response body" + raise WebServerOTAError( + f"Unexpected HTTP {response.status_code} response from device: {detail}" + ) + + # The endpoint returns HTTP 200 for both success and failure; the + # body is what tells us which (see ota_web_server.cpp handleRequest). + body = response.text.strip() + if "Successful" in body: + _LOGGER.info("Device response: %s", body) + _LOGGER.info("OTA successful") + return 0, ip + + raise WebServerOTAError( + f"Device reported OTA failure: {body or 'no response body'}" + ) + + return 1, None + + +def run_ota( + remote_hosts: str | list[str], + remote_port: int, + username: str | None, + password: str | None, + filename: Path, +) -> tuple[int, str | None]: + """Upload ``filename`` to the first reachable host via ``web_server`` OTA. + + Mirrors :func:`esphome.espota2.run_ota` so callers can swap between the + two paths with the same return contract: ``(0, host)`` on success or + ``(1, None)`` on failure. + """ + hosts = [remote_hosts] if isinstance(remote_hosts, str) else list(remote_hosts) + for host in hosts: + try: + exit_code, used_host = _try_upload( + host, remote_port, username, password, filename + ) + except WebServerOTAError as err: + _LOGGER.error("%s", err) + continue + if exit_code == 0: + return 0, used_host + # Reached only when every attempt failed; per-attempt errors were + # already logged. This summary line gives the user an unambiguous + # "stop reading, nothing worked" marker. + _LOGGER.error("OTA upload failed.") + return 1, None diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index f8a3ea888e..0b96000a57 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -43,6 +43,7 @@ from esphome.__main__ import ( has_non_ip_address, has_ota, has_resolvable_address, + has_web_server_ota, mqtt_get_ip, run_esphome, run_miniterm, @@ -58,6 +59,7 @@ from esphome.components import esp32 from esphome.components.esp32 import KEY_ESP32, KEY_VARIANT, VARIANT_ESP32 from esphome.const import ( CONF_API, + CONF_AUTH, CONF_BAUD_RATE, CONF_BROKER, CONF_DISABLED, @@ -76,6 +78,8 @@ from esphome.const import ( CONF_SUBSTITUTIONS, CONF_TOPIC, CONF_USE_ADDRESS, + CONF_USERNAME, + CONF_WEB_SERVER, CONF_WIFI, KEY_CORE, KEY_TARGET_PLATFORM, @@ -213,6 +217,13 @@ def mock_run_ota() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_run_web_server_ota() -> Generator[Mock]: + """Mock web_server_ota.run_ota for testing.""" + with patch("esphome.web_server_ota.run_ota") as mock: + yield mock + + @pytest.fixture def mock_is_ip_address() -> Generator[Mock]: """Mock is_ip_address for testing.""" @@ -1114,6 +1125,7 @@ class MockArgs: reset: bool = False list_only: bool = False output: str | None = None + ota_platform: str | None = None partition_table: bool = False @@ -1878,6 +1890,277 @@ def test_upload_program_ota_no_config( upload_program(config, args, devices) +def test_has_web_server_ota_detects_platform() -> None: + """has_web_server_ota returns True when web_server OTA platform is configured.""" + setup_core( + config={ + CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}], + } + ) + assert has_web_server_ota() is True + assert has_ota() is True + + +def test_has_web_server_ota_returns_false_without_config() -> None: + """has_web_server_ota returns False when only native OTA is configured.""" + setup_core( + config={ + CONF_OTA: [{CONF_PLATFORM: CONF_ESPHOME}], + } + ) + assert has_web_server_ota() is False + assert has_ota() is True + + +def test_upload_program_web_server_only_auto_dispatches( + mock_run_web_server_ota: Mock, + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """When only web_server OTA is configured, upload_program picks it automatically.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + mock_run_web_server_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}], + CONF_WEB_SERVER: { + CONF_PORT: 80, + CONF_AUTH: {CONF_USERNAME: "admin", CONF_PASSWORD: "pw"}, + }, + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + expected_firmware = ( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_web_server_ota.assert_called_once_with( + ["192.168.1.100"], 80, "admin", "pw", expected_firmware + ) + mock_run_ota.assert_not_called() + + +def test_upload_program_web_server_no_auth( + mock_run_web_server_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """web_server OTA works without an auth block (passes None for credentials).""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + mock_run_web_server_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}], + CONF_WEB_SERVER: {CONF_PORT: 8080}, + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + expected_firmware = ( + tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" + ) + mock_run_web_server_ota.assert_called_once_with( + ["192.168.1.100"], 8080, None, None, expected_firmware + ) + + +def test_upload_program_both_platforms_default_prefers_native( + mock_run_ota: Mock, + mock_run_web_server_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """When both OTA platforms are configured, default selection is native API.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + }, + {CONF_PLATFORM: CONF_WEB_SERVER}, + ], + CONF_WEB_SERVER: {CONF_PORT: 80}, + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_run_ota.assert_called_once() + mock_run_web_server_ota.assert_not_called() + + +def test_upload_program_ota_platform_override_to_web_server( + mock_run_ota: Mock, + mock_run_web_server_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """--ota-platform web_server forces web_server OTA even when native is configured.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + mock_run_web_server_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + }, + {CONF_PLATFORM: CONF_WEB_SERVER}, + ], + CONF_WEB_SERVER: {CONF_PORT: 80}, + } + args = MockArgs(ota_platform=CONF_WEB_SERVER) + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_run_ota.assert_not_called() + mock_run_web_server_ota.assert_called_once() + + +def test_upload_program_ota_platform_unavailable( + mock_get_port_type: Mock, +) -> None: + """--ota-platform must reference a platform that is actually configured.""" + setup_core(platform=PLATFORM_ESP32) + mock_get_port_type.return_value = "NETWORK" + + config = { + CONF_OTA: [ + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + } + ], + } + args = MockArgs(ota_platform=CONF_WEB_SERVER) + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError, match="--ota-platform web_server"): + upload_program(config, args, devices) + + +def test_upload_program_web_server_missing_component( + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """web_server OTA without a web_server component fails with a clear error.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + + config = { + CONF_OTA: [{CONF_PLATFORM: CONF_WEB_SERVER}], + # No CONF_WEB_SERVER + } + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError, match="web_server.*not configured"): + upload_program(config, args, devices) + + +def test_upload_program_unrelated_ota_platform_ignored( + mock_run_ota: Mock, + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """OTA list entries that are neither esphome nor web_server are ignored. + + Covers the false branch in _choose_ota_platform's filter loop and the + no-match branch in _upload_via_native_api's lookup loop. + """ + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + mock_run_ota.return_value = (0, "192.168.1.100") + + config = { + CONF_OTA: [ + {CONF_PLATFORM: "http_request"}, # unrelated platform; ignored + { + CONF_PLATFORM: CONF_ESPHOME, + CONF_PORT: 3232, + CONF_PASSWORD: "secret", + }, + ], + } + args = MockArgs() + devices = ["192.168.1.100"] + + exit_code, host = upload_program(config, args, devices) + + assert exit_code == 0 + assert host == "192.168.1.100" + mock_run_ota.assert_called_once() + + +def test_upload_program_duplicate_platform_dedup_in_error( + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Duplicate same-platform OTA entries don't repeat in --ota-platform errors.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + + config = { + CONF_OTA: [ + {CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3232}, + {CONF_PLATFORM: CONF_ESPHOME, CONF_PORT: 3233}, + ], + } + args = MockArgs(ota_platform=CONF_WEB_SERVER) + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError) as excinfo: + upload_program(config, args, devices) + + # Error mentions esphome once in the platform list, not "esphome, esphome". + msg = str(excinfo.value) + assert "esphome, esphome" not in msg + assert msg.endswith(": esphome") + + +def test_upload_program_only_unrelated_ota_platforms( + mock_get_port_type: Mock, + tmp_path: Path, +) -> None: + """Only unrelated OTA platforms configured -> raises like missing OTA.""" + setup_core(platform=PLATFORM_ESP32, tmp_path=tmp_path) + mock_get_port_type.return_value = "NETWORK" + + config = { + CONF_OTA: [{CONF_PLATFORM: "http_request"}], + } + args = MockArgs() + devices = ["192.168.1.100"] + + with pytest.raises(EsphomeError, match="Cannot upload Over the Air"): + upload_program(config, args, devices) + + def test_upload_program_ota_with_mqtt_resolution( mock_mqtt_get_ip: Mock, mock_is_ip_address: Mock, diff --git a/tests/unit_tests/test_web_server_ota.py b/tests/unit_tests/test_web_server_ota.py new file mode 100644 index 0000000000..606905e36e --- /dev/null +++ b/tests/unit_tests/test_web_server_ota.py @@ -0,0 +1,670 @@ +"""Unit tests for esphome.web_server_ota module.""" + +from __future__ import annotations + +import io +import logging +from pathlib import Path +import socket +from unittest.mock import MagicMock, patch + +import pytest +import requests +from requests.auth import HTTPBasicAuth + +from esphome.core import CORE, EsphomeError +from esphome.helpers import ProgressBar +from esphome.web_server_ota import ( + OTA_PATH, + WebServerOTAError, + _MultipartStreamer, + run_ota, +) + + +@pytest.fixture +def firmware(tmp_path: Path) -> Path: + binary = tmp_path / "firmware.bin" + binary.write_bytes(b"\x00\x01\x02FIRMWARE\xff" * 64) + return binary + + +def _make_response(status: int, body: str) -> MagicMock: + response = MagicMock(spec=requests.Response) + response.status_code = status + response.text = body + response.reason = "" + return response + + +def _patch_resolve( + monkeypatch: pytest.MonkeyPatch, hosts: list[tuple[str, int]] +) -> None: + """Replace resolve_ip_address so tests don't actually do DNS.""" + addr_infos = [ + (socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, port)) + for host, port in hosts + ] + monkeypatch.setattr( + "esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos + ) + + +# --------------------------------------------------------------------------- +# _MultipartStreamer +# --------------------------------------------------------------------------- + + +def test_multipart_streamer_emits_full_body() -> None: + """Streaming the whole body in one call yields prefix + file + suffix.""" + data = b"abcdef" * 100 + streamer = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin") + + body = streamer.read() + while True: + chunk = streamer.read() + if not chunk: + break + body += chunk + + assert body.startswith(f"--{streamer.boundary}\r\n".encode()) + assert b'name="update"' in body + assert b'filename="fw.bin"' in body + assert data in body + assert body.endswith(f"\r\n--{streamer.boundary}--\r\n".encode()) + + +def test_multipart_streamer_chunked_read_matches_full_read() -> None: + """Chunked reads (urllib3 calls read(8192) repeatedly) yield the same body.""" + data = b"abcdef" * 1000 # 6000 bytes + full = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin").read() + + streamed = bytearray() + s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin") + # Same boundary lengths -> identical total length. + while True: + chunk = s.read(64) + if not chunk: + break + streamed += chunk + # Boundaries are random per instance, so compare lengths and structure. + assert len(streamed) == len(full) + assert streamed.startswith(f"--{s.boundary}\r\n".encode()) + assert streamed.endswith(f"\r\n--{s.boundary}--\r\n".encode()) + + +def test_multipart_streamer_len_matches_emitted_bytes() -> None: + """``__len__`` is what urllib3 uses to set Content-Length, so it must + equal the total bytes emitted by ``read``.""" + data = b"x" * 12345 + s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin") + declared = len(s) + + emitted = 0 + while True: + chunk = s.read(1024) + if not chunk: + break + emitted += len(chunk) + + assert emitted == declared + + +def test_multipart_streamer_progress_ticks_during_read() -> None: + """Each read advances the progress bar (this is the whole point of + streaming via ``data=``: progress reflects bytes leaving the host).""" + data = b"x" * 1000 + s = _MultipartStreamer(io.BytesIO(data), len(data), "fw.bin") + + updates: list[float] = [] + s.progress.update = updates.append # type: ignore[method-assign] + + while True: + chunk = s.read(128) + if not chunk: + break + + assert updates, "progress.update was never called" + # Strictly non-decreasing. + assert updates == sorted(updates) + # Final update reaches (within FP) 1.0 because all bytes were read. + assert updates[-1] == pytest.approx(1.0, abs=1e-9) + + +def test_multipart_streamer_content_type_includes_boundary() -> None: + s = _MultipartStreamer(io.BytesIO(b""), 0, "fw.bin") + assert s.content_type == f"multipart/form-data; boundary={s.boundary}" + + +def test_multipart_streamer_zero_size_file() -> None: + """A zero-byte file still produces a well-formed body and progress is + skipped (avoiding a divide-by-zero on the empty file segment).""" + s = _MultipartStreamer(io.BytesIO(b""), 0, "empty.bin") + body = b"" + while True: + chunk = s.read(64) + if not chunk: + break + body += chunk + assert body.startswith(f"--{s.boundary}".encode()) + assert body.endswith(f"--{s.boundary}--\r\n".encode()) + + +def test_multipart_streamer_unique_boundary_per_instance() -> None: + a = _MultipartStreamer(io.BytesIO(b""), 0, "a") + b = _MultipartStreamer(io.BytesIO(b""), 0, "a") + assert a.boundary != b.boundary + + +def test_multipart_streamer_zero_size_read_returns_empty() -> None: + """``read(0)`` short-circuits without touching state.""" + s = _MultipartStreamer(io.BytesIO(b"x" * 10), 10, "fw.bin") + assert s.read(0) == b"" + # No bytes consumed. + assert s._sent == 0 + + +# --------------------------------------------------------------------------- +# run_ota +# --------------------------------------------------------------------------- + + +def test_run_ota_success(monkeypatch: pytest.MonkeyPatch, firmware: Path) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + exit_code, host = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 0 + assert host == "192.168.1.50" + post.assert_called_once() + args, kwargs = post.call_args + assert args == (f"http://192.168.1.50:80{OTA_PATH}",) + assert kwargs["auth"] is None + # Streaming body, not files=, so progress fires during transmission. + assert "files" not in kwargs + assert isinstance(kwargs["data"], _MultipartStreamer) + assert kwargs["headers"]["Content-Type"] == kwargs["data"].content_type + assert kwargs["headers"]["Connection"] == "close" + + +def test_run_ota_logs_device_response_body( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """The device's HTTP response body is surfaced on success.""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + caplog.set_level(logging.INFO, logger="esphome.web_server_ota") + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ): + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert "Device response: Update Successful!" in caplog.text + assert "OTA successful" in caplog.text + + +def test_run_ota_log_says_via_web_server( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """The upload-start log line names the transport explicitly.""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + caplog.set_level(logging.INFO, logger="esphome.web_server_ota") + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ): + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert "via web_server OTA" in caplog.text + + +def test_run_ota_sends_basic_auth( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + exit_code, _ = run_ota(["192.168.1.50"], 80, "admin", "secret", firmware) + + assert exit_code == 0 + auth = post.call_args.kwargs["auth"] + assert isinstance(auth, HTTPBasicAuth) + assert auth.username == "admin" + assert auth.password == "secret" + + +def test_run_ota_skips_auth_when_no_credentials( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert post.call_args.kwargs["auth"] is None + + +def test_run_ota_skips_auth_when_only_username( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """Both username and password are required to send Basic auth.""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + run_ota(["192.168.1.50"], 80, "admin", None, firmware) + + assert post.call_args.kwargs["auth"] is None + + +def test_run_ota_uses_update_url( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 8080)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + run_ota(["192.168.1.50"], 8080, None, None, firmware) + + url = post.call_args.args[0] + assert url == f"http://192.168.1.50:8080{OTA_PATH}" + assert OTA_PATH == "/update" + + +def test_run_ota_failure_response( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Failed!"), + ): + exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + assert "OTA failure" in caplog.text + + +def test_run_ota_failure_response_empty_body( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, ""), + ): + exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + assert "no response body" in caplog.text + + +def test_run_ota_auth_failed( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(401, "Unauthorized"), + ): + exit_code, host = run_ota(["192.168.1.50"], 80, "user", "wrong", firmware) + + assert exit_code == 1 + assert host is None + assert "Authentication failed" in caplog.text + + +def test_run_ota_unexpected_status_code( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(500, "Internal Error"), + ): + exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + assert "Unexpected HTTP 500" in caplog.text + + +def test_run_ota_unexpected_status_empty_body_falls_back( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Empty response body uses response.reason / a fallback in the error.""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + response = _make_response(503, "") + response.reason = "Service Unavailable" + + with patch( + "esphome.web_server_ota.requests.post", + return_value=response, + ): + exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + assert "Service Unavailable" in caplog.text + + +def test_run_ota_unexpected_status_no_body_no_reason( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Empty body and empty reason still produce a usable error message.""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + response = _make_response(599, "") + response.reason = "" + + with patch( + "esphome.web_server_ota.requests.post", + return_value=response, + ): + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert "no response body" in caplog.text + + +def test_run_ota_connection_error_then_success( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """First resolved address fails to connect, second succeeds.""" + _patch_resolve( + monkeypatch, + [("192.168.1.10", 80), ("192.168.1.50", 80)], + ) + + with patch( + "esphome.web_server_ota.requests.post", + side_effect=[ + requests.ConnectionError("refused"), + _make_response(200, "Update Successful!"), + ], + ) as post: + exit_code, host = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 0 + assert host == "192.168.1.50" + assert post.call_count == 2 + + +def test_run_ota_request_exception_falls_through( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """A non-ConnectionError RequestException (e.g. timeout) falls through too.""" + _patch_resolve( + monkeypatch, + [("192.168.1.10", 80), ("192.168.1.50", 80)], + ) + + with patch( + "esphome.web_server_ota.requests.post", + side_effect=[ + requests.Timeout("read timeout"), + _make_response(200, "Update Successful!"), + ], + ): + exit_code, host = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 0 + assert host == "192.168.1.50" + + +def test_run_ota_all_addresses_unreachable( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """When every resolved address fails to connect, run_ota returns failure.""" + _patch_resolve( + monkeypatch, + [("192.168.1.10", 80), ("192.168.1.20", 80)], + ) + + with patch( + "esphome.web_server_ota.requests.post", + side_effect=requests.ConnectionError("refused"), + ): + exit_code, host = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + # Per-address failure is logged for each attempt; final summary follows. + assert caplog.text.count("OTA upload to ") >= 2 + assert "OTA upload failed." in caplog.text + + +def test_run_ota_no_resolved_addresses( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """If resolve_ip_address returns no candidates, log and return failure.""" + _patch_resolve(monkeypatch, []) + + exit_code, host = run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + assert "Could not resolve 192.168.1.50" in caplog.text + + +def test_run_ota_resolution_failure( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + def _raise(*_args, **_kwargs): + raise EsphomeError("dns failed") + + monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise) + + exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + + +def test_run_ota_resolution_failure_dashboard_mode( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Dashboard mode skips the '--device ' tip on resolution failure.""" + + def _raise(*_args, **_kwargs): + raise EsphomeError("dns failed") + + monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _raise) + monkeypatch.setattr(CORE, "dashboard", True) + try: + exit_code, host = run_ota(["does.not.exist"], 80, None, None, firmware) + finally: + monkeypatch.setattr(CORE, "dashboard", False) + + assert exit_code == 1 + assert host is None + assert "--device " not in caplog.text + + +def test_run_ota_empty_hosts(firmware: Path) -> None: + exit_code, host = run_ota([], 80, None, None, firmware) + assert exit_code == 1 + assert host is None + + +def test_run_ota_string_host_accepted( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """A bare string is accepted in addition to a list of hosts.""" + _patch_resolve(monkeypatch, [("10.0.0.5", 80)]) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ): + exit_code, host = run_ota("10.0.0.5", 80, None, None, firmware) + + assert exit_code == 0 + assert host == "10.0.0.5" + + +def test_run_ota_multiple_hosts_first_fails( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """Multi-host fallthrough: first host's addresses all fail, second host wins.""" + addr_lookup = { + "primary.local": [ + (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.10", 80)), + ], + "secondary.local": [ + (socket.AF_INET, socket.SOCK_STREAM, 0, "", ("192.168.1.50", 80)), + ], + } + + def _resolve(host, port, address_cache=None): # noqa: ARG001 + return addr_lookup[host] + + monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve) + + with patch( + "esphome.web_server_ota.requests.post", + side_effect=[ + requests.ConnectionError("refused"), + _make_response(200, "Update Successful!"), + ], + ): + exit_code, host = run_ota( + ["primary.local", "secondary.local"], 80, None, None, firmware + ) + + assert exit_code == 0 + assert host == "192.168.1.50" + + +def test_run_ota_all_hosts_return_failure_no_exception( + monkeypatch: pytest.MonkeyPatch, firmware: Path, caplog: pytest.LogCaptureFixture +) -> None: + """All hosts resolve to no addresses; run_ota cleanly returns failure.""" + addr_lookup = { + "a.local": [], + "b.local": [], + } + + def _resolve(host, port, address_cache=None): # noqa: ARG001 + return addr_lookup[host] + + monkeypatch.setattr("esphome.web_server_ota.resolve_ip_address", _resolve) + + exit_code, host = run_ota(["a.local", "b.local"], 80, None, None, firmware) + + assert exit_code == 1 + assert host is None + # Each host gets its own "Could not resolve" log line + final summary. + assert caplog.text.count("Could not resolve") == 2 + assert "OTA upload failed." in caplog.text + + +def test_web_server_ota_error_is_esphome_error() -> None: + assert issubclass(WebServerOTAError, EsphomeError) + + +def test_run_ota_finalizes_progress_bar_on_success( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """progress.done() fires on the success path (finally block).""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + done_called: list[bool] = [] + + with ( + patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ), + patch.object(ProgressBar, "done", lambda self: done_called.append(True)), + ): + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert done_called + + +def test_run_ota_finalizes_progress_bar_on_failure( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """progress.done() fires when the request itself raises (finally block).""" + _patch_resolve(monkeypatch, [("192.168.1.50", 80)]) + + done_called: list[bool] = [] + + with ( + patch( + "esphome.web_server_ota.requests.post", + side_effect=requests.ConnectionError("boom"), + ), + patch.object(ProgressBar, "done", lambda self: done_called.append(True)), + ): + run_ota(["192.168.1.50"], 80, None, None, firmware) + + assert done_called + + +def test_run_ota_ipv6_url_brackets_host( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """IPv6 candidates are bracketed in the URL so the port parses correctly.""" + addr_infos = [ + (socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("2001:db8::1", 80, 0, 0)), + ] + monkeypatch.setattr( + "esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos + ) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + exit_code, host = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 0 + assert host == "2001:db8::1" + url = post.call_args.args[0] + assert url == f"http://[2001:db8::1]:80{OTA_PATH}" + + +def test_run_ota_ipv6_link_local_includes_scope_id( + monkeypatch: pytest.MonkeyPatch, firmware: Path +) -> None: + """Link-local IPv6 candidates include the percent-encoded zone index.""" + addr_infos = [ + (socket.AF_INET6, socket.SOCK_STREAM, 0, "", ("fe80::1", 80, 0, 3)), + ] + monkeypatch.setattr( + "esphome.web_server_ota.resolve_ip_address", lambda *a, **kw: addr_infos + ) + + with patch( + "esphome.web_server_ota.requests.post", + return_value=_make_response(200, "Update Successful!"), + ) as post: + exit_code, _ = run_ota(["device.local"], 80, None, None, firmware) + + assert exit_code == 0 + url = post.call_args.args[0] + assert url == f"http://[fe80::1%253]:80{OTA_PATH}"