mirror of
https://github.com/WallabyLester/RBF-aPID-Controller.git
synced 2026-05-29 02:27:22 +08:00
Added test for training.
This commit is contained in:
@@ -30,7 +30,23 @@ class TestRBFNetwork(unittest.TestCase):
|
|||||||
output_after = self.rbf_network.predict(self.x)
|
output_after = self.rbf_network.predict(self.x)
|
||||||
self.assertIsInstance(output_after, float)
|
self.assertIsInstance(output_after, float)
|
||||||
self.assertNotEqual(output_before, output_after)
|
self.assertNotEqual(output_before, output_after)
|
||||||
|
|
||||||
|
def test_train(self):
|
||||||
|
"""Test the training function."""
|
||||||
|
target = 1.0
|
||||||
|
initial_weights = self.rbf_network.weights.copy()
|
||||||
|
output_before = np.dot(np.array([self.rbf_network.gaussian(self.x, center)
|
||||||
|
for center in self.rbf_network.centers]), initial_weights)
|
||||||
|
|
||||||
|
self.rbf_network.train(self.x, target)
|
||||||
|
|
||||||
|
self.assertFalse(np.array_equal(initial_weights, self.rbf_network.weights))
|
||||||
|
|
||||||
|
output_after = self.rbf_network.predict(self.x)
|
||||||
|
|
||||||
|
self.assertNotEqual(output_before, output_after)
|
||||||
|
if not abs(target - output_after) < abs(target - output_before):
|
||||||
|
print("Output did not move closer to the target after training.")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user