diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index 7d83c59..7910eb0 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -43,6 +43,15 @@ class TestRBFAdaptiveModel(unittest.TestCase): self.assertEqual(output.shape, (3, 1)) - + def test_train(self): + """ Test the train method.""" + initial_weights = self.model.rbf_layer.centers.numpy().copy() + errors = np.random.normal(size=(100, self.input_dim)) # create 100 samples of input errors + control_signals = np.random.normal(size=(100, 1)) # create 100 matching control signals + train_rbf_adaptive(self.model, errors, control_signals, 10) + + new_weights = self.model.rbf_layer.centers.numpy() + self.assertFalse(np.array_equal(initial_weights, new_weights)) + if __name__ == '__main__': unittest.main()