diff --git a/esphome/__main__.py b/esphome/__main__.py index 781bcd6288..e7ce36ae2d 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -1125,15 +1125,16 @@ def upload_program( remote_port = int(ota_conf[CONF_PORT]) password = ota_conf.get(CONF_PASSWORD) - if getattr(args, "file", None) is not None: - binary = Path(args.file) - else: - binary = CORE.firmware_bin # Resolve MQTT magic strings to actual IP addresses network_devices = _resolve_network_devices(devices, config, args) - return espota2.run_ota(network_devices, remote_port, password, binary) + binary = CORE.firmware_bin + ota_type = espota2.OTA_TYPE_UPDATE_APP + if getattr(args, "file", None) is not None: + binary = Path(args.file) + + return espota2.run_ota(network_devices, remote_port, password, binary, ota_type) def show_logs(config: ConfigType, args: ArgsProtocol, devices: list[str]) -> int | None: diff --git a/esphome/components/esphome/ota/ota_esphome.cpp b/esphome/components/esphome/ota/ota_esphome.cpp index be771eb689..955b4dc96f 100644 --- a/esphome/components/esphome/ota/ota_esphome.cpp +++ b/esphome/components/esphome/ota/ota_esphome.cpp @@ -114,8 +114,10 @@ void ESPHomeOTAComponent::loop() { this->handle_handshake_(); } -static const uint8_t FEATURE_SUPPORTS_COMPRESSION = 0x01; -static const uint8_t FEATURE_SUPPORTS_SHA256_AUTH = 0x02; +static constexpr uint8_t CLIENT_FEATURE_SUPPORTS_COMPRESSION = 0x01; +static constexpr uint8_t CLIENT_FEATURE_SUPPORTS_SHA256_AUTH = 0x02; +static constexpr uint8_t CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL = 0x04; +static constexpr uint8_t SERVER_FEATURE_SUPPORTS_COMPRESSION = 0x01; void ESPHomeOTAComponent::handle_handshake_() { /// Handle the OTA handshake and authentication. @@ -201,16 +203,30 @@ void ESPHomeOTAComponent::handle_handshake_() { this->ota_features_ = this->handshake_buf_[0]; ESP_LOGV(TAG, "Features: 0x%02X", this->ota_features_); this->transition_ota_state_(OTAState::FEATURE_ACK); - this->handshake_buf_[0] = - ((this->ota_features_ & FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression()) - ? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION - : ota::OTA_RESPONSE_HEADER_OK; + + const bool supports_compression = + (this->ota_features_ & CLIENT_FEATURE_SUPPORTS_COMPRESSION) != 0 && this->backend_->supports_compression(); + + // Compose the feature-ack response. When the client negotiates the extended protocol we emit + // a 2-byte response (marker + server feature flags); otherwise we emit the single-byte + // legacy response. + this->extended_proto_ = (this->ota_features_ & CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL) != 0; + if (this->extended_proto_) { + static_assert(HANDSHAKE_BUF_SIZE >= 2, "handshake_buf_ must hold the 2-byte extended-protocol feature ack"); + this->handshake_buf_[0] = ota::OTA_RESPONSE_FEATURE_FLAGS; + this->handshake_buf_[1] = (supports_compression ? SERVER_FEATURE_SUPPORTS_COMPRESSION : 0); + } else { + this->handshake_buf_[0] = + supports_compression ? ota::OTA_RESPONSE_SUPPORTS_COMPRESSION : ota::OTA_RESPONSE_HEADER_OK; + } [[fallthrough]]; } case OTAState::FEATURE_ACK: { - // Acknowledge header - 1 byte - if (!this->try_write_(1, LOG_STR("ack feature"))) { + static constexpr size_t STANDARD_PROTO_ACK_SIZE = 1; + static constexpr size_t EXTENDED_PROTO_ACK_SIZE = 2; + const size_t ack_size = this->extended_proto_ ? EXTENDED_PROTO_ACK_SIZE : STANDARD_PROTO_ACK_SIZE; + if (!this->try_write_(ack_size, LOG_STR("ack feature"))) { return; } #ifdef USE_OTA_PASSWORD @@ -296,6 +312,7 @@ void ESPHomeOTAComponent::handle_data_() { uint8_t buf[OTA_BUFFER_SIZE]; char *sbuf = reinterpret_cast(buf); size_t ota_size; + ota::OTAType ota_type = ota::OTA_TYPE_UPDATE_APP; #if USE_OTA_VERSION == 2 size_t size_acknowledged = 0; #endif @@ -311,6 +328,16 @@ void ESPHomeOTAComponent::handle_data_() { // Acknowledge auth OK - 1 byte this->write_byte_(ota::OTA_RESPONSE_AUTH_OK); + if (this->extended_proto_) { + // Read ota type, 1 byte + if (!this->readall_(buf, 1)) { + this->log_read_error_(LOG_STR("OTA type")); + goto error; // NOLINT(cppcoreguidelines-avoid-goto) + } + ota_type = static_cast(buf[0]); + } + ESP_LOGV(TAG, "OTA type is 0x%02x", ota_type); + // Read size, 4 bytes MSB first if (!this->readall_(buf, 4)) { this->log_read_error_(LOG_STR("size")); @@ -320,6 +347,11 @@ void ESPHomeOTAComponent::handle_data_() { (static_cast(buf[2]) << 8) | buf[3]; ESP_LOGV(TAG, "Size is %u bytes", ota_size); + if (ota_type != ota::OTA_TYPE_UPDATE_APP) { + error_code = ota::OTA_RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE; + goto error; // NOLINT(cppcoreguidelines-avoid-goto) + } + // Now that we've passed authentication and are actually // starting the update, set the warning status and notify // listeners. This ensures that port scanners do not @@ -616,7 +648,7 @@ void ESPHomeOTAComponent::yield_and_feed_watchdog_() { void ESPHomeOTAComponent::log_auth_warning_(const LogString *msg) { ESP_LOGW(TAG, "Auth: %s", LOG_STR_ARG(msg)); } bool ESPHomeOTAComponent::select_auth_type_() { - bool client_supports_sha256 = (this->ota_features_ & FEATURE_SUPPORTS_SHA256_AUTH) != 0; + bool client_supports_sha256 = (this->ota_features_ & CLIENT_FEATURE_SUPPORTS_SHA256_AUTH) != 0; // Require SHA256 if (!client_supports_sha256) { diff --git a/esphome/components/esphome/ota/ota_esphome.h b/esphome/components/esphome/ota/ota_esphome.h index 53288fc000..5043bc33ef 100644 --- a/esphome/components/esphome/ota/ota_esphome.h +++ b/esphome/components/esphome/ota/ota_esphome.h @@ -97,8 +97,9 @@ class ESPHomeOTAComponent final : public ota::OTAComponent { ota::OTABackendPtr backend_; uint32_t client_connect_time_{0}; + static constexpr size_t HANDSHAKE_BUF_SIZE = 5; uint16_t port_; - uint8_t handshake_buf_[5]; + uint8_t handshake_buf_[HANDSHAKE_BUF_SIZE]; OTAState ota_state_{OTAState::IDLE}; uint8_t handshake_buf_pos_{0}; uint8_t ota_features_{0}; @@ -106,6 +107,7 @@ class ESPHomeOTAComponent final : public ota::OTAComponent { uint8_t auth_buf_pos_{0}; uint8_t auth_type_{0}; // Store auth type to know which hasher to use #endif // USE_OTA_PASSWORD + bool extended_proto_{false}; }; } // namespace esphome diff --git a/esphome/components/ota/ota_backend.h b/esphome/components/ota/ota_backend.h index bd9c481901..7e7b0f6523 100644 --- a/esphome/components/ota/ota_backend.h +++ b/esphome/components/ota/ota_backend.h @@ -4,6 +4,8 @@ #include "esphome/core/defines.h" #include "esphome/core/helpers.h" +#include + #ifdef USE_OTA_STATE_LISTENER #include #endif @@ -23,6 +25,7 @@ enum OTAResponseTypes { OTA_RESPONSE_UPDATE_END_OK = 0x45, OTA_RESPONSE_SUPPORTS_COMPRESSION = 0x46, OTA_RESPONSE_CHUNK_OK = 0x47, + OTA_RESPONSE_FEATURE_FLAGS = 0x48, OTA_RESPONSE_ERROR_MAGIC = 0x80, OTA_RESPONSE_ERROR_UPDATE_PREPARE = 0x81, @@ -38,6 +41,7 @@ enum OTAResponseTypes { OTA_RESPONSE_ERROR_MD5_MISMATCH = 0x8B, OTA_RESPONSE_ERROR_RP2040_NOT_ENOUGH_SPACE = 0x8C, OTA_RESPONSE_ERROR_SIGNATURE_INVALID = 0x8D, + OTA_RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE = 0x8E, OTA_RESPONSE_ERROR_UNKNOWN = 0xFF, }; @@ -49,6 +53,10 @@ enum OTAState { OTA_ERROR, }; +enum OTAType : uint8_t { + OTA_TYPE_UPDATE_APP = 0x00, +}; + /** Listener interface for OTA state changes. * * Components can implement this interface to receive OTA state updates diff --git a/esphome/espota2.py b/esphome/espota2.py index 39f51e02e9..f4c0c73589 100644 --- a/esphome/espota2.py +++ b/esphome/espota2.py @@ -15,6 +15,8 @@ from typing import Any from esphome.core import EsphomeError from esphome.helpers import ProgressBar, resolve_ip_address +OTA_TYPE_UPDATE_APP = 0x00 + RESPONSE_OK = 0x00 RESPONSE_REQUEST_AUTH = 0x01 RESPONSE_REQUEST_SHA256_AUTH = 0x02 @@ -27,6 +29,7 @@ RESPONSE_RECEIVE_OK = 0x44 RESPONSE_UPDATE_END_OK = 0x45 RESPONSE_SUPPORTS_COMPRESSION = 0x46 RESPONSE_CHUNK_OK = 0x47 +RESPONSE_FEATURE_FLAGS = 0x48 RESPONSE_ERROR_MAGIC = 0x80 RESPONSE_ERROR_UPDATE_PREPARE = 0x81 @@ -42,6 +45,7 @@ RESPONSE_ERROR_NO_UPDATE_PARTITION = 0x8A RESPONSE_ERROR_MD5_MISMATCH = 0x8B RESPONSE_ERROR_RP2040_NOT_ENOUGH_SPACE = 0x8C RESPONSE_ERROR_SIGNATURE_INVALID = 0x8D +RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE = 0x8E RESPONSE_ERROR_UNKNOWN = 0xFF OTA_VERSION_1_0 = 1 @@ -49,9 +53,16 @@ OTA_VERSION_2_0 = 2 MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45] -FEATURE_SUPPORTS_COMPRESSION = 0x01 -FEATURE_SUPPORTS_SHA256_AUTH = 0x02 +CLIENT_FEATURE_SUPPORTS_COMPRESSION = 0x01 +CLIENT_FEATURE_SUPPORTS_SHA256_AUTH = 0x02 +CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL = 0x04 +SERVER_FEATURE_SUPPORTS_COMPRESSION = 0x01 +SERVER_FEATURE_SUPPORTS_PARTITION_ACCESS = 0x02 +# OTA types this client knows how to send. Future PRs that add bootloader/partition +# updates extend this set. Anything outside the set is rejected up front so callers +# of perform_ota/run_ota get a clear error instead of a post-auth 0x8E from the device. +_SUPPORTED_OTA_TYPES: frozenset[int] = frozenset({OTA_TYPE_UPDATE_APP}) UPLOAD_BLOCK_SIZE = 8192 UPLOAD_BUFFER_SIZE = UPLOAD_BLOCK_SIZE * 8 @@ -64,6 +75,62 @@ _AUTH_METHODS: dict[int, tuple[Callable[..., Any], int, str]] = { RESPONSE_REQUEST_AUTH: (hashlib.md5, 32, "MD5"), } +# Error response code -> human-readable message (without the "Error: " prefix; check_error() +# prepends it uniformly). Looked up by check_error() to translate a single byte from the device +# into an OTAError. Add new error codes here rather than extending the if-chain in check_error(). +_ERROR_MESSAGES: dict[int, str] = { + RESPONSE_ERROR_MAGIC: "Invalid magic byte", + RESPONSE_ERROR_UPDATE_PREPARE: ( + "Couldn't prepare flash memory for update. Is the binary too big? " + "Please try restarting the ESP." + ), + RESPONSE_ERROR_AUTH_INVALID: "Authentication invalid. Is the password correct?", + RESPONSE_ERROR_WRITING_FLASH: ( + "Writing OTA data to flash memory failed. See USB logs for more information." + ), + RESPONSE_ERROR_UPDATE_END: ( + "Finishing update failed. See the MQTT/USB logs for more information." + ), + RESPONSE_ERROR_INVALID_BOOTSTRAPPING: ( + "Please press the reset button on the ESP. A manual reset is " + "required on the first OTA-Update after flashing via USB." + ), + RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG: ( + "ESP has been flashed with wrong flash size. Please choose the " + "correct 'board' option (esp01_1m always works) and then flash over USB." + ), + RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG: ( + "ESP does not have the requested flash size (wrong board). Please " + "choose the correct 'board' option (esp01_1m always works) and try " + "uploading again." + ), + RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE: ( + "ESP does not have enough space to store OTA file. Please try " + "flashing a minimal firmware (remove everything except ota)" + ), + RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE: ( + "The OTA partition on the ESP is too small. ESPHome needs to resize " + "this partition, please flash over USB." + ), + RESPONSE_ERROR_NO_UPDATE_PARTITION: ( + "The OTA partition on the ESP couldn't be found. ESPHome needs to " + "create this partition, please flash over USB." + ), + RESPONSE_ERROR_MD5_MISMATCH: ( + "Application MD5 code mismatch. Please try again " + "or flash over USB with a good quality cable." + ), + RESPONSE_ERROR_SIGNATURE_INVALID: ( + "Firmware signature verification failed. The firmware was not signed " + "with the correct key. Ensure the signing key matches the one used to build " + "the firmware currently running on the device." + ), + RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE: ( + "The requested OTA type is not supported by the device." + ), + RESPONSE_ERROR_UNKNOWN: "Unknown error from ESP", +} + class OTAError(EsphomeError): pass @@ -130,8 +197,10 @@ def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None :param expect: Expected response code(s), None to skip validation. :raises OTAError: If an error code is detected or response doesn't match expected. """ - if expect is None: - return + # Detect device errors and connection-closed cases regardless of `expect`. If we + # only ran these checks when expect was set, error bytes returned during + # accept-any-response reads (e.g. feature negotiation, auth nonces) would be + # silently passed through and surface later as cryptic decode/timeout failures. if not data: raise OTAError( "Error: Device closed connection without responding. " @@ -139,69 +208,11 @@ def check_error(data: list[int] | bytes, expect: int | list[int] | None) -> None "a network issue, or the connection was interrupted." ) dat = data[0] - if dat == RESPONSE_ERROR_MAGIC: - raise OTAError("Error: Invalid magic byte") - if dat == RESPONSE_ERROR_UPDATE_PREPARE: - raise OTAError( - "Error: Couldn't prepare flash memory for update. Is the binary too big? " - "Please try restarting the ESP." - ) - if dat == RESPONSE_ERROR_AUTH_INVALID: - raise OTAError("Error: Authentication invalid. Is the password correct?") - if dat == RESPONSE_ERROR_WRITING_FLASH: - raise OTAError( - "Error: Writing OTA data to flash memory failed. See USB logs for more " - "information." - ) - if dat == RESPONSE_ERROR_UPDATE_END: - raise OTAError( - "Error: Finishing update failed. See the MQTT/USB logs for more " - "information." - ) - if dat == RESPONSE_ERROR_INVALID_BOOTSTRAPPING: - raise OTAError( - "Error: Please press the reset button on the ESP. A manual reset is " - "required on the first OTA-Update after flashing via USB." - ) - if dat == RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG: - raise OTAError( - "Error: ESP has been flashed with wrong flash size. Please choose the " - "correct 'board' option (esp01_1m always works) and then flash over USB." - ) - if dat == RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG: - raise OTAError( - "Error: ESP does not have the requested flash size (wrong board). Please " - "choose the correct 'board' option (esp01_1m always works) and try " - "uploading again." - ) - if dat == RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE: - raise OTAError( - "Error: ESP does not have enough space to store OTA file. Please try " - "flashing a minimal firmware (remove everything except ota)" - ) - if dat == RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE: - raise OTAError( - "Error: The OTA partition on the ESP is too small. ESPHome needs to resize " - "this partition, please flash over USB." - ) - if dat == RESPONSE_ERROR_NO_UPDATE_PARTITION: - raise OTAError( - "Error: The OTA partition on the ESP couldn't be found. ESPHome needs to create " - "this partition, please flash over USB." - ) - if dat == RESPONSE_ERROR_MD5_MISMATCH: - raise OTAError( - "Error: Application MD5 code mismatch. Please try again " - "or flash over USB with a good quality cable." - ) - if dat == RESPONSE_ERROR_SIGNATURE_INVALID: - raise OTAError( - "Error: Firmware signature verification failed. The firmware was not signed " - "with the correct key. Ensure the signing key matches the one used to build " - "the firmware currently running on the device." - ) - if dat == RESPONSE_ERROR_UNKNOWN: - raise OTAError("Unknown error from ESP") + error_msg = _ERROR_MESSAGES.get(dat) + if error_msg is not None: + raise OTAError(f"Error: {error_msg}") + if expect is None: + return if not isinstance(expect, (list, tuple)): expect = [expect] if dat not in expect: @@ -232,8 +243,25 @@ def send_check( def perform_ota( - sock: socket.socket, password: str | None, file_handle: io.IOBase, filename: Path + sock: socket.socket, + password: str | None, + file_handle: io.IOBase, + filename: Path, + ota_type: int = OTA_TYPE_UPDATE_APP, ) -> None: + # Validate ota_type up front. It travels as a single byte on the wire, and + # passing an out-of-range value would only surface as a ValueError from + # bytes([ota_type]) deep inside send_check, bypassing OTAError handling. + if not isinstance(ota_type, int) or not 0 <= ota_type <= 0xFF: + raise OTAError( + f"Invalid ota_type {ota_type!r}; expected an integer in range 0-255" + ) + if ota_type not in _SUPPORTED_OTA_TYPES: + supported = ", ".join(f"0x{t:02X}" for t in sorted(_SUPPORTED_OTA_TYPES)) + raise OTAError( + f"Unsupported OTA type 0x{ota_type:02X}; this ESPHome supports: {supported}" + ) + file_contents = file_handle.read() file_size = len(file_contents) _LOGGER.info("Uploading %s (%s bytes)", filename, file_size) @@ -251,7 +279,11 @@ def perform_ota( ) # Features - send both compression and SHA256 auth support - features_to_send = FEATURE_SUPPORTS_COMPRESSION | FEATURE_SUPPORTS_SHA256_AUTH + features_to_send = ( + CLIENT_FEATURE_SUPPORTS_COMPRESSION + | CLIENT_FEATURE_SUPPORTS_SHA256_AUTH + | CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL + ) send_check(sock, features_to_send, "features") features = receive_exactly( sock, @@ -260,7 +292,36 @@ def perform_ota( None, # Accept any response )[0] - if features == RESPONSE_SUPPORTS_COMPRESSION: + extended_proto = False + if features == RESPONSE_FEATURE_FLAGS: + extended_proto = True + features = receive_exactly( + sock, + 1, + "feature flags", + None, # Accept any response + )[0] + elif features == RESPONSE_SUPPORTS_COMPRESSION: + features = SERVER_FEATURE_SUPPORTS_COMPRESSION + else: + features = 0 + + if ota_type != OTA_TYPE_UPDATE_APP: + # Any non-app OTA type requires the extended protocol and the + # partition-access server feature. Reject up front so the user gets + # a clear capability error instead of a post-auth 0x8E from the device. + if not extended_proto: + raise OTAError( + f"Device does not support extended OTA protocol; " + f"OTA type 0x{ota_type:02X} requires it" + ) + if not (features & SERVER_FEATURE_SUPPORTS_PARTITION_ACCESS): + raise OTAError( + f"Device does not support partition access; " + f"OTA type 0x{ota_type:02X} cannot be used" + ) + + if features & SERVER_FEATURE_SUPPORTS_COMPRESSION: upload_contents = gzip.compress(file_contents, compresslevel=9) _LOGGER.info("Compressed to %s bytes", len(upload_contents)) else: @@ -315,6 +376,9 @@ def perform_ota( # Timeout must match device-side OTA_SOCKET_TIMEOUT_DATA to prevent premature failures sock.settimeout(90.0) + if extended_proto: + send_check(sock, ota_type, "ota type") + upload_size = len(upload_contents) upload_size_encoded = [ (upload_size >> 24) & 0xFF, @@ -375,7 +439,11 @@ def perform_ota( def run_ota_impl_( - remote_host: str | list[str], remote_port: int, password: str | None, filename: Path + remote_host: str | list[str], + remote_port: int, + password: str | None, + filename: Path, + ota_type: int = OTA_TYPE_UPDATE_APP, ) -> tuple[int, str | None]: from esphome.core import CORE @@ -413,7 +481,7 @@ def run_ota_impl_( _LOGGER.info("Connected to %s", sa[0]) with open(filename, "rb") as file_handle: try: - perform_ota(sock, password, file_handle, filename) + perform_ota(sock, password, file_handle, filename, ota_type) except OTAError as err: _LOGGER.error(str(err)) return 1, None @@ -428,10 +496,14 @@ def run_ota_impl_( def run_ota( - remote_host: str | list[str], remote_port: int, password: str | None, filename: Path + remote_host: str | list[str], + remote_port: int, + password: str | None, + filename: Path, + ota_type: int = OTA_TYPE_UPDATE_APP, ) -> tuple[int, str | None]: try: - return run_ota_impl_(remote_host, remote_port, password, filename) + return run_ota_impl_(remote_host, remote_port, password, filename, ota_type) except OTAError as err: _LOGGER.error(err) return 1, None diff --git a/tests/unit_tests/test_espota2.py b/tests/unit_tests/test_espota2.py index 20ba4b1f76..b114f17e6c 100644 --- a/tests/unit_tests/test_espota2.py +++ b/tests/unit_tests/test_espota2.py @@ -185,6 +185,14 @@ def test_receive_exactly_socket_error(mock_socket: Mock) -> None: "Error: The OTA partition on the ESP couldn't be found", ), (espota2.RESPONSE_ERROR_MD5_MISMATCH, "Error: Application MD5 code mismatch"), + ( + espota2.RESPONSE_ERROR_SIGNATURE_INVALID, + "Error: Firmware signature verification failed", + ), + ( + espota2.RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE, + "Error: The requested OTA type is not supported by the device", + ), (espota2.RESPONSE_ERROR_UNKNOWN, "Unknown error from ESP"), ], ) @@ -270,12 +278,13 @@ def test_perform_ota_successful_md5_auth( # Verify magic bytes were sent assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES)) - # Verify features were sent (compression + SHA256 support) + # Verify features were sent (compression + SHA256 support + extended protocol) assert mock_socket.sendall.call_args_list[1] == call( bytes( [ - espota2.FEATURE_SUPPORTS_COMPRESSION - | espota2.FEATURE_SUPPORTS_SHA256_AUTH + espota2.CLIENT_FEATURE_SUPPORTS_COMPRESSION + | espota2.CLIENT_FEATURE_SUPPORTS_SHA256_AUTH + | espota2.CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL ] ) ) @@ -640,12 +649,13 @@ def test_perform_ota_successful_sha256_auth( # Verify magic bytes were sent assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES)) - # Verify features were sent (compression + SHA256 support) + # Verify features were sent (compression + SHA256 support + extended protocol) assert mock_socket.sendall.call_args_list[1] == call( bytes( [ - espota2.FEATURE_SUPPORTS_COMPRESSION - | espota2.FEATURE_SUPPORTS_SHA256_AUTH + espota2.CLIENT_FEATURE_SUPPORTS_COMPRESSION + | espota2.CLIENT_FEATURE_SUPPORTS_SHA256_AUTH + | espota2.CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL ] ) ) @@ -699,8 +709,9 @@ def test_perform_ota_sha256_fallback_to_md5( assert mock_socket.sendall.call_args_list[1] == call( bytes( [ - espota2.FEATURE_SUPPORTS_COMPRESSION - | espota2.FEATURE_SUPPORTS_SHA256_AUTH + espota2.CLIENT_FEATURE_SUPPORTS_COMPRESSION + | espota2.CLIENT_FEATURE_SUPPORTS_SHA256_AUTH + | espota2.CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL ] ) ) @@ -765,3 +776,220 @@ def test_perform_ota_version_differences( # For v2.0, verify more recv calls due to chunk acknowledgments assert mock_socket.recv.call_count == 9 # v2.0 has 9 recv calls (includes chunk OK) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_extended_protocol_app( + mock_socket: Mock, mock_file: io.BytesIO +) -> None: + """Test OTA extended protocol app update.""" + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_FEATURE_FLAGS]), # Device supports extended protocol + bytes( + [ + espota2.SERVER_FEATURE_SUPPORTS_COMPRESSION + | espota2.SERVER_FEATURE_SUPPORTS_PARTITION_ACCESS + ] + ), # Device feature flags + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_UPDATE_PREPARE_OK]), # Binary size OK + bytes([espota2.RESPONSE_BIN_MD5_OK]), # MD5 checksum OK + bytes([espota2.RESPONSE_CHUNK_OK]), # Chunk OK + bytes([espota2.RESPONSE_RECEIVE_OK]), # Receive OK + bytes([espota2.RESPONSE_UPDATE_END_OK]), # Update end OK + ] + + mock_socket.recv.side_effect = recv_responses + + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + espota2.OTA_TYPE_UPDATE_APP, + ) + + # Verify magic bytes were sent + assert mock_socket.sendall.call_args_list[0] == call(bytes(espota2.MAGIC_BYTES)) + + # Verify features were sent (compression + SHA256 support + extended protocol) + assert mock_socket.sendall.call_args_list[1] == call( + bytes( + [ + espota2.CLIENT_FEATURE_SUPPORTS_COMPRESSION + | espota2.CLIENT_FEATURE_SUPPORTS_SHA256_AUTH + | espota2.CLIENT_FEATURE_SUPPORTS_EXTENDED_PROTOCOL + ] + ) + ) + + # Verify ota type was sent + assert mock_socket.sendall.call_args_list[2] == call( + bytes([espota2.OTA_TYPE_UPDATE_APP]) + ) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_device_rejects_with_unsupported_ota_type( + mock_socket: Mock, mock_file: io.BytesIO +) -> None: + """End-to-end: device returns 0x8E after the size byte; perform_ota must + surface the human-readable 'unsupported OTA type' error from the lookup + table in check_error().""" + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_FEATURE_FLAGS]), # Extended protocol marker + bytes( + [ + espota2.SERVER_FEATURE_SUPPORTS_COMPRESSION + | espota2.SERVER_FEATURE_SUPPORTS_PARTITION_ACCESS + ] + ), # Feature flags + bytes([espota2.RESPONSE_AUTH_OK]), # No auth required + bytes([espota2.RESPONSE_ERROR_UNSUPPORTED_OTA_TYPE]), # Reject at size step + ] + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises( + espota2.OTAError, + match="The requested OTA type is not supported by the device", + ): + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + espota2.OTA_TYPE_UPDATE_APP, + ) + + # Verify the client did send the OTA type byte before the size step + assert mock_socket.sendall.call_args_list[2] == call( + bytes([espota2.OTA_TYPE_UPDATE_APP]) + ) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_unsupported_type_rejected_early( + mock_socket: Mock, mock_file: io.BytesIO +) -> None: + """ota_type values not in _SUPPORTED_OTA_TYPES are rejected before any I/O.""" + with pytest.raises(espota2.OTAError, match="Unsupported OTA type 0xFF"): + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + 0xFF, + ) + # No bytes should have been transmitted to the device. + mock_socket.sendall.assert_not_called() + + +@pytest.mark.parametrize("bad_type", [-1, 256, 0x10000, "app", None, 1.5]) +def test_perform_ota_rejects_out_of_range_type( + mock_socket: Mock, mock_file: io.BytesIO, bad_type: object +) -> None: + """Out-of-range or non-int ota_type must raise OTAError, not ValueError.""" + with pytest.raises(espota2.OTAError, match="Invalid ota_type"): + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + bad_type, # type: ignore[arg-type] + ) + mock_socket.sendall.assert_not_called() + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_non_app_type_requires_extended_protocol( + mock_socket: Mock, mock_file: io.BytesIO, monkeypatch: pytest.MonkeyPatch +) -> None: + """Non-app OTA type must fail when device only supports the legacy protocol.""" + monkeypatch.setattr( + espota2, + "_SUPPORTED_OTA_TYPES", + frozenset({espota2.OTA_TYPE_UPDATE_APP, 0xFF}), + ) + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_HEADER_OK]), # Legacy single-byte feature ack + ] + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises( + espota2.OTAError, match="Device does not support extended OTA protocol" + ): + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + 0xFF, + ) + + +@pytest.mark.usefixtures("mock_time") +def test_perform_ota_non_app_type_requires_partition_access( + mock_socket: Mock, mock_file: io.BytesIO, monkeypatch: pytest.MonkeyPatch +) -> None: + """Non-app OTA type must fail when device advertises extended protocol but + not the partition-access feature.""" + monkeypatch.setattr( + espota2, + "_SUPPORTED_OTA_TYPES", + frozenset({espota2.OTA_TYPE_UPDATE_APP, 0xFF}), + ) + recv_responses = [ + bytes([espota2.RESPONSE_OK]), # First byte of version response + bytes([espota2.OTA_VERSION_2_0]), # Version number + bytes([espota2.RESPONSE_FEATURE_FLAGS]), # Extended protocol marker + bytes( + [espota2.SERVER_FEATURE_SUPPORTS_COMPRESSION] + ), # Compression only, no partition access + ] + + mock_socket.recv.side_effect = recv_responses + + with pytest.raises( + espota2.OTAError, match="Device does not support partition access" + ): + espota2.perform_ota( + mock_socket, + "testpass", + mock_file, + "test.bin", + 0xFF, + ) + + +def test_check_error_detects_errors_when_expect_is_none() -> None: + """check_error must surface device error bytes even when expect is None. + + Regression test: previously, receive_exactly(..., expect=None) calls (used + during feature negotiation and nonce reads) silently passed error bytes + through, turning clean device errors into confusing later failures. + """ + with pytest.raises(espota2.OTAError, match="Error: Authentication invalid"): + espota2.check_error([espota2.RESPONSE_ERROR_AUTH_INVALID], None) + + +def test_check_error_detects_empty_when_expect_is_none() -> None: + """Empty data with expect=None must still raise (connection closed).""" + with pytest.raises( + espota2.OTAError, match="Device closed connection without responding" + ): + espota2.check_error([], None) + + +def test_check_error_passes_non_error_when_expect_is_none() -> None: + """Non-error bytes with expect=None must pass through silently.""" + espota2.check_error([espota2.RESPONSE_OK], None) + espota2.check_error([espota2.RESPONSE_HEADER_OK], None) + espota2.check_error([espota2.RESPONSE_FEATURE_FLAGS], None) diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py index fb8f206a1d..186d8a9573 100644 --- a/tests/unit_tests/test_main.py +++ b/tests/unit_tests/test_main.py @@ -83,6 +83,7 @@ from esphome.const import ( PLATFORM_RP2040, ) from esphome.core import CORE, EsphomeError +from esphome.espota2 import OTA_TYPE_UPDATE_APP from esphome.util import BootselResult from esphome.zeroconf import _await_discovery, discover_mdns_devices @@ -1593,7 +1594,7 @@ def test_upload_program_ota_success( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.1.100"], 3232, "secret", expected_firmware + ["192.168.1.100"], 3232, "secret", expected_firmware, OTA_TYPE_UPDATE_APP ) @@ -1624,7 +1625,7 @@ def test_upload_program_ota_with_file_arg( assert exit_code == 0 assert host == "192.168.1.100" mock_run_ota.assert_called_once_with( - ["192.168.1.100"], 3232, None, Path("custom.bin") + ["192.168.1.100"], 3232, None, Path("custom.bin"), OTA_TYPE_UPDATE_APP ) @@ -1682,7 +1683,7 @@ def test_upload_program_ota_with_mqtt_resolution( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.1.100"], 3232, None, expected_firmware + ["192.168.1.100"], 3232, None, expected_firmware, OTA_TYPE_UPDATE_APP ) @@ -1730,7 +1731,7 @@ def test_upload_program_ota_with_mqtt_empty_broker( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.1.50"], 3232, None, expected_firmware + ["192.168.1.50"], 3232, None, expected_firmware, OTA_TYPE_UPDATE_APP ) # Verify warning was logged assert "MQTT IP discovery failed" in caplog.text @@ -3207,7 +3208,11 @@ def test_upload_program_ota_static_ip_with_mqttip( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.1.100", "192.168.2.50"], 3232, None, expected_firmware + ["192.168.1.100", "192.168.2.50"], + 3232, + None, + expected_firmware, + OTA_TYPE_UPDATE_APP, ) @@ -3250,7 +3255,11 @@ def test_upload_program_ota_multiple_mqttip_resolves_once( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.2.50", "192.168.2.51", "192.168.1.100"], 3232, None, expected_firmware + ["192.168.2.50", "192.168.2.51", "192.168.1.100"], + 3232, + None, + expected_firmware, + OTA_TYPE_UPDATE_APP, ) @@ -3415,7 +3424,7 @@ def test_upload_program_ota_mqtt_timeout_fallback( tmp_path / ".esphome" / "build" / "test" / ".pioenvs" / "test" / "firmware.bin" ) mock_run_ota.assert_called_once_with( - ["192.168.1.100"], 3232, None, expected_firmware + ["192.168.1.100"], 3232, None, expected_firmware, OTA_TYPE_UPDATE_APP )