From 9ac342f003f59fe1ec91799cb2e9887f114d2a19 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Mon, 7 Oct 2024 19:24:52 -0700 Subject: [PATCH] Added training function. --- RBF_tf.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/RBF_tf.py b/RBF_tf.py index 6557def..63615b9 100644 --- a/RBF_tf.py +++ b/RBF_tf.py @@ -99,3 +99,20 @@ class RBFAdaptiveModel(tf.keras.Model): rbf_output = self.rbf_layer(inputs) control_signal = self.output_layer(rbf_output) return control_signal + +def train_rbf_adaptive(model, errors, control_signals, epochs=100): + """ Training method for the RBF adaptive model. + + Parameters + ---------- + model : RBFAdaptiveModel + A RBF Adaptive Model instance. + errors : ndarray + Multi-sample training data, [error, derivative, integral]. + control_signals : ndarray + Control signal target values. + epochs : int + Number of epochs to train for. + """ + model.compile(optimizer="adam", loss="mean_squared_error") + model.fit(errors, control_signals, epochs=epochs, verbose=1)