diff --git a/RBF_tf.py b/RBF_tf.py index 6557def..63615b9 100644 --- a/RBF_tf.py +++ b/RBF_tf.py @@ -99,3 +99,20 @@ class RBFAdaptiveModel(tf.keras.Model): rbf_output = self.rbf_layer(inputs) control_signal = self.output_layer(rbf_output) return control_signal + +def train_rbf_adaptive(model, errors, control_signals, epochs=100): + """ Training method for the RBF adaptive model. + + Parameters + ---------- + model : RBFAdaptiveModel + A RBF Adaptive Model instance. + errors : ndarray + Multi-sample training data, [error, derivative, integral]. + control_signals : ndarray + Control signal target values. + epochs : int + Number of epochs to train for. + """ + model.compile(optimizer="adam", loss="mean_squared_error") + model.fit(errors, control_signals, epochs=epochs, verbose=1)