Added test for predict function.

This commit is contained in:
Andru Liu
2024-10-09 15:15:46 -07:00
parent b2efa77059
commit 4e11dd8c76
+14
View File
@@ -18,5 +18,19 @@ class TestRBFNetwork(unittest.TestCase):
output = self.rbf_network.gaussian(self.x, center)
self.assertAlmostEqual(output, expected_output, places=5)
def test_predict(self):
"""Test the predict function."""
output_before = self.rbf_network.predict(self.x)
self.assertIsInstance(output_before, float)
[self.assertNotAlmostEqual(self.rbf_network.weights[i], self.rbf_network.weights[i+1])
for i in range(len(self.rbf_network.weights)-1)]
target = 1.0
self.rbf_network.train(self.x, target)
output_after = self.rbf_network.predict(self.x)
self.assertIsInstance(output_after, float)
self.assertNotEqual(output_before, output_after)
if __name__ == "__main__":
unittest.main()