From abe184b8078bdd20684bcab74fc59cd1119046f9 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:28:09 -0700 Subject: [PATCH] Implemented training function. --- CPP_Implementation/rbf_model.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/CPP_Implementation/rbf_model.cpp b/CPP_Implementation/rbf_model.cpp index 6b6b11f..c142114 100644 --- a/CPP_Implementation/rbf_model.cpp +++ b/CPP_Implementation/rbf_model.cpp @@ -73,6 +73,20 @@ void RBFModel::adapt(double error, double learning_rate, const double* input) { } } +/** + * @brief Train the RBF model using recorded data. + */ +void RBFModel::train(const double* inputs, const double* targets, int n_samples, int epochs, double learning_rate) { + for (int iter = 0; iter < epochs; ++iter) { + for (int sample = 0; sample < n_samples; ++sample) { + double output = predict(&inputs[sample * input_dim]); + double error = targets[sample] - output; + + adapt(error, learning_rate, &inputs[sample * input_dim]); // Adapt weights based on the error + } + } +} + /** * @brief Get the weight at a specific index. */