From 363a5ffe5db8f16f432ad707d913554da906d49c Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Wed, 9 Oct 2024 19:57:27 -0700 Subject: [PATCH] Test for training function. --- TF_Implementation/test/test_RBF_tf.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index 7d83c59..7910eb0 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -43,6 +43,15 @@ class TestRBFAdaptiveModel(unittest.TestCase): self.assertEqual(output.shape, (3, 1)) - + def test_train(self): + """ Test the train method.""" + initial_weights = self.model.rbf_layer.centers.numpy().copy() + errors = np.random.normal(size=(100, self.input_dim)) # create 100 samples of input errors + control_signals = np.random.normal(size=(100, 1)) # create 100 matching control signals + train_rbf_adaptive(self.model, errors, control_signals, 10) + + new_weights = self.model.rbf_layer.centers.numpy() + self.assertFalse(np.array_equal(initial_weights, new_weights)) + if __name__ == '__main__': unittest.main()