diff --git a/NP_Implementation/tests/test_RBF_np.py b/NP_Implementation/tests/test_RBF_np.py index d35e568..0f47f6a 100644 --- a/NP_Implementation/tests/test_RBF_np.py +++ b/NP_Implementation/tests/test_RBF_np.py @@ -18,5 +18,19 @@ class TestRBFNetwork(unittest.TestCase): output = self.rbf_network.gaussian(self.x, center) self.assertAlmostEqual(output, expected_output, places=5) + def test_predict(self): + """Test the predict function.""" + output_before = self.rbf_network.predict(self.x) + self.assertIsInstance(output_before, float) + [self.assertNotAlmostEqual(self.rbf_network.weights[i], self.rbf_network.weights[i+1]) + for i in range(len(self.rbf_network.weights)-1)] + + target = 1.0 + self.rbf_network.train(self.x, target) + output_after = self.rbf_network.predict(self.x) + self.assertIsInstance(output_after, float) + self.assertNotEqual(output_before, output_after) + + if __name__ == "__main__": unittest.main()