Test for training function.

This commit is contained in:
Andru Liu
2024-10-09 19:57:27 -07:00
parent aaeb3a4507
commit 363a5ffe5d
+10 -1
View File
@@ -43,6 +43,15 @@ class TestRBFAdaptiveModel(unittest.TestCase):
self.assertEqual(output.shape, (3, 1))
def test_train(self):
""" Test the train method."""
initial_weights = self.model.rbf_layer.centers.numpy().copy()
errors = np.random.normal(size=(100, self.input_dim)) # create 100 samples of input errors
control_signals = np.random.normal(size=(100, 1)) # create 100 matching control signals
train_rbf_adaptive(self.model, errors, control_signals, 10)
new_weights = self.model.rbf_layer.centers.numpy()
self.assertFalse(np.array_equal(initial_weights, new_weights))
if __name__ == '__main__':
unittest.main()