Implemented training function.

This commit is contained in:
Andru Liu
2024-10-08 21:28:09 -07:00
parent 478c944f8a
commit abe184b807
+14
View File
@@ -73,6 +73,20 @@ void RBFModel::adapt(double error, double learning_rate, const double* input) {
}
}
/**
* @brief Train the RBF model using recorded data.
*/
void RBFModel::train(const double* inputs, const double* targets, int n_samples, int epochs, double learning_rate) {
for (int iter = 0; iter < epochs; ++iter) {
for (int sample = 0; sample < n_samples; ++sample) {
double output = predict(&inputs[sample * input_dim]);
double error = targets[sample] - output;
adapt(error, learning_rate, &inputs[sample * input_dim]); // Adapt weights based on the error
}
}
}
/**
* @brief Get the weight at a specific index.
*/