From 478c944f8a6502fcbf41627ac2f9df7ec96798cf Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:27:39 -0700 Subject: [PATCH] Adding training function. --- CPP_Implementation/rbf_model.h | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CPP_Implementation/rbf_model.h b/CPP_Implementation/rbf_model.h index 2b77469..4515cba 100644 --- a/CPP_Implementation/rbf_model.h +++ b/CPP_Implementation/rbf_model.h @@ -50,6 +50,17 @@ public: */ void adapt(double error, double learning_rate, const double* input); + /** + * @brief Train the RBF model using recorded data. + * + * @param inputs A pointer to an array of input samples (n_samples x input_dim). + * @param targets A pointer to an array of target outputs (n_samples). + * @param n_samples The number of samples to train on. + * @param epochs The number of training epochs to perform. + * @param learning_rate The learning rate for weight adaptation. + */ + void train(const double* inputs, const double* targets, int n_samples, int epochs, double learning_rate); + /** * @brief Get the weight at a specific index. *