diff --git a/CPP_Implementation/rbf_model.cpp b/CPP_Implementation/rbf_model.cpp index 6b6b11f..c142114 100644 --- a/CPP_Implementation/rbf_model.cpp +++ b/CPP_Implementation/rbf_model.cpp @@ -73,6 +73,20 @@ void RBFModel::adapt(double error, double learning_rate, const double* input) { } } +/** + * @brief Train the RBF model using recorded data. + */ +void RBFModel::train(const double* inputs, const double* targets, int n_samples, int epochs, double learning_rate) { + for (int iter = 0; iter < epochs; ++iter) { + for (int sample = 0; sample < n_samples; ++sample) { + double output = predict(&inputs[sample * input_dim]); + double error = targets[sample] - output; + + adapt(error, learning_rate, &inputs[sample * input_dim]); // Adapt weights based on the error + } + } +} + /** * @brief Get the weight at a specific index. */