diff --git a/esphome/components/micro_wake_word/__init__.py b/esphome/components/micro_wake_word/__init__.py index ff27dec6df..de95e4961b 100644 --- a/esphome/components/micro_wake_word/__init__.py +++ b/esphome/components/micro_wake_word/__init__.py @@ -451,6 +451,8 @@ async def to_code(config): ota.request_ota_state_listeners() esp32.add_idf_component(name="espressif/esp-tflite-micro", ref="1.3.3~1") + # Pin esp-nn for stable future builds (esp-tflite-micro depends on esp-nn) + esp32.add_idf_component(name="espressif/esp-nn", ref="1.2.1") cg.add_build_flag("-DTF_LITE_STATIC_MEMORY") cg.add_build_flag("-DTF_LITE_DISABLE_X86_NEON") diff --git a/esphome/components/micro_wake_word/streaming_model.cpp b/esphome/components/micro_wake_word/streaming_model.cpp index 0ab6cd3772..e761e4866f 100644 --- a/esphome/components/micro_wake_word/streaming_model.cpp +++ b/esphome/components/micro_wake_word/streaming_model.cpp @@ -29,14 +29,6 @@ void VADModel::log_model_config() { bool StreamingModel::load_model_() { RAMAllocator arena_allocator; - if (this->tensor_arena_ == nullptr) { - this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); - if (this->tensor_arena_ == nullptr) { - ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena."); - return false; - } - } - if (this->var_arena_ == nullptr) { this->var_arena_ = arena_allocator.allocate(STREAMING_MODEL_VARIABLE_ARENA_SIZE); if (this->var_arena_ == nullptr) { @@ -53,6 +45,26 @@ bool StreamingModel::load_model_() { return false; } + // Probe for the actual required tensor arena size if not yet determined + if (!this->tensor_arena_size_probed_) { + size_t probed_size = this->probe_arena_size_(); + if (probed_size > 0) { + ESP_LOGD(TAG, "Probed tensor arena size: %zu bytes", probed_size); + this->tensor_arena_size_ = probed_size; + } else { + ESP_LOGW(TAG, "Arena size probe failed, using manifest size: %zu bytes", this->tensor_arena_size_); + } + this->tensor_arena_size_probed_ = true; + } + + if (this->tensor_arena_ == nullptr) { + this->tensor_arena_ = arena_allocator.allocate(this->tensor_arena_size_); + if (this->tensor_arena_ == nullptr) { + ESP_LOGE(TAG, "Could not allocate the streaming model's tensor arena."); + return false; + } + } + if (this->interpreter_ == nullptr) { this->interpreter_ = make_unique(tflite::GetModel(this->model_start_), this->streaming_op_resolver_, @@ -94,6 +106,70 @@ bool StreamingModel::load_model_() { return true; } +size_t StreamingModel::probe_arena_size_() { + RAMAllocator arena_allocator; + + // Try with the manifest size first, then escalates to 1.5, then 2x if it fails. Different platforms and different + // versions of the esp-nn library require different amounts of memory, so the manifest size may not always be correct, + // and probing allows us to find the actual required size for the current build and platform. Aligns test sizes to 16 + // bytes. + size_t attempt_sizes[] = {(this->tensor_arena_size_ + 15) & ~15, (this->tensor_arena_size_ * 3 / 2 + 15) & ~15, + (this->tensor_arena_size_ * 2 + 15) & ~15}; + + for (size_t attempt_size : attempt_sizes) { + uint8_t *probe_arena = arena_allocator.allocate(attempt_size); + if (probe_arena == nullptr) { + continue; + } + + // Verify the model works at all with this arena size + auto probe_interpreter = make_unique( + tflite::GetModel(this->model_start_), this->streaming_op_resolver_, probe_arena, attempt_size, this->mrv_); + + if (probe_interpreter->AllocateTensors() != kTfLiteOk) { + probe_interpreter.reset(); + arena_allocator.deallocate(probe_arena, attempt_size); + this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20); + continue; + } + + // Try to shrink the arena. Start with arena_used_bytes() + 16 (rounded to 16-byte alignment). + // If that works, use it. Otherwise, try midpoints between that and the full size until one succeeds. + size_t lower = (probe_interpreter->arena_used_bytes() + 16 + 15) & ~15; + probe_interpreter.reset(); + this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20); + + size_t upper = attempt_size; + + while (lower < upper) { + auto test_interpreter = make_unique( + tflite::GetModel(this->model_start_), this->streaming_op_resolver_, probe_arena, lower, this->mrv_); + + bool ok = test_interpreter->AllocateTensors() == kTfLiteOk; + + test_interpreter.reset(); + this->ma_ = tflite::MicroAllocator::Create(this->var_arena_, STREAMING_MODEL_VARIABLE_ARENA_SIZE); + this->mrv_ = tflite::MicroResourceVariables::Create(this->ma_, 20); + + if (ok) { + // Found a working size smaller than the full arena + upper = lower + 16; // Pad by 16 bytes to be safe for future allocations + break; + } + + // Try the midpoint between current attempt and full size + lower = ((lower + upper) / 2 + 15) & ~15; + } + + arena_allocator.deallocate(probe_arena, attempt_size); + return upper; + } + + return 0; +} + void StreamingModel::unload_model() { this->interpreter_.reset(); diff --git a/esphome/components/micro_wake_word/streaming_model.h b/esphome/components/micro_wake_word/streaming_model.h index 0811bfb19b..fc9eeb5e2d 100644 --- a/esphome/components/micro_wake_word/streaming_model.h +++ b/esphome/components/micro_wake_word/streaming_model.h @@ -63,6 +63,10 @@ class StreamingModel { /// @brief Allocates tensor and variable arenas and sets up the model interpreter /// @return True if successful, false otherwise bool load_model_(); + /// @brief Probes the actual required tensor arena size by trial allocation. + /// Tries the manifest size first, then 2x if that fails. + /// @return The required arena size rounded up to 16-byte alignment, or 0 on failure. + size_t probe_arena_size_(); /// @brief Returns true if successfully registered the streaming model's TensorFlow operations bool register_streaming_ops_(tflite::MicroMutableOpResolver<20> &op_resolver); @@ -70,6 +74,7 @@ class StreamingModel { bool loaded_{false}; bool enabled_{true}; + bool tensor_arena_size_probed_{false}; bool unprocessed_probability_status_{false}; uint8_t current_stride_step_{0}; int16_t ignore_windows_{-MIN_SLICES_BEFORE_DETECTION};