Added test for training.

This commit is contained in:
Andru Liu
2024-10-09 15:24:56 -07:00
parent 4e11dd8c76
commit 7951a89584
+17 -1
View File
@@ -30,7 +30,23 @@ class TestRBFNetwork(unittest.TestCase):
output_after = self.rbf_network.predict(self.x)
self.assertIsInstance(output_after, float)
self.assertNotEqual(output_before, output_after)
def test_train(self):
"""Test the training function."""
target = 1.0
initial_weights = self.rbf_network.weights.copy()
output_before = np.dot(np.array([self.rbf_network.gaussian(self.x, center)
for center in self.rbf_network.centers]), initial_weights)
self.rbf_network.train(self.x, target)
self.assertFalse(np.array_equal(initial_weights, self.rbf_network.weights))
output_after = self.rbf_network.predict(self.x)
self.assertNotEqual(output_before, output_after)
if not abs(target - output_after) < abs(target - output_before):
print("Output did not move closer to the target after training.")
if __name__ == "__main__":
unittest.main()