mirror of
https://github.com/WallabyLester/RBF-aPID-Controller.git
synced 2026-05-27 08:55:36 +08:00
Test for training function.
This commit is contained in:
@@ -43,6 +43,15 @@ class TestRBFAdaptiveModel(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output.shape, (3, 1))
|
self.assertEqual(output.shape, (3, 1))
|
||||||
|
|
||||||
|
def test_train(self):
|
||||||
|
""" Test the train method."""
|
||||||
|
initial_weights = self.model.rbf_layer.centers.numpy().copy()
|
||||||
|
errors = np.random.normal(size=(100, self.input_dim)) # create 100 samples of input errors
|
||||||
|
control_signals = np.random.normal(size=(100, 1)) # create 100 matching control signals
|
||||||
|
train_rbf_adaptive(self.model, errors, control_signals, 10)
|
||||||
|
|
||||||
|
new_weights = self.model.rbf_layer.centers.numpy()
|
||||||
|
self.assertFalse(np.array_equal(initial_weights, new_weights))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user