Added training function.

This commit is contained in:
Andru Liu
2024-10-07 19:24:52 -07:00
parent 8199266dd4
commit 9ac342f003
+17
View File
@@ -99,3 +99,20 @@ class RBFAdaptiveModel(tf.keras.Model):
rbf_output = self.rbf_layer(inputs)
control_signal = self.output_layer(rbf_output)
return control_signal
def train_rbf_adaptive(model, errors, control_signals, epochs=100):
""" Training method for the RBF adaptive model.
Parameters
----------
model : RBFAdaptiveModel
A RBF Adaptive Model instance.
errors : ndarray
Multi-sample training data, [error, derivative, integral].
control_signals : ndarray
Control signal target values.
epochs : int
Number of epochs to train for.
"""
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(errors, control_signals, epochs=epochs, verbose=1)