diff --git a/CPP_Implementation/rbf_model.h b/CPP_Implementation/rbf_model.h index 2b77469..4515cba 100644 --- a/CPP_Implementation/rbf_model.h +++ b/CPP_Implementation/rbf_model.h @@ -50,6 +50,17 @@ public: */ void adapt(double error, double learning_rate, const double* input); + /** + * @brief Train the RBF model using recorded data. + * + * @param inputs A pointer to an array of input samples (n_samples x input_dim). + * @param targets A pointer to an array of target outputs (n_samples). + * @param n_samples The number of samples to train on. + * @param epochs The number of training epochs to perform. + * @param learning_rate The learning rate for weight adaptation. + */ + void train(const double* inputs, const double* targets, int n_samples, int epochs, double learning_rate); + /** * @brief Get the weight at a specific index. *