From 4e11dd8c76543fe071df214ae2880a62003eaabb Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:15:46 -0700 Subject: [PATCH] Added test for predict function. --- NP_Implementation/tests/test_RBF_np.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/NP_Implementation/tests/test_RBF_np.py b/NP_Implementation/tests/test_RBF_np.py index d35e568..0f47f6a 100644 --- a/NP_Implementation/tests/test_RBF_np.py +++ b/NP_Implementation/tests/test_RBF_np.py @@ -18,5 +18,19 @@ class TestRBFNetwork(unittest.TestCase): output = self.rbf_network.gaussian(self.x, center) self.assertAlmostEqual(output, expected_output, places=5) + def test_predict(self): + """Test the predict function.""" + output_before = self.rbf_network.predict(self.x) + self.assertIsInstance(output_before, float) + [self.assertNotAlmostEqual(self.rbf_network.weights[i], self.rbf_network.weights[i+1]) + for i in range(len(self.rbf_network.weights)-1)] + + target = 1.0 + self.rbf_network.train(self.x, target) + output_after = self.rbf_network.predict(self.x) + self.assertIsInstance(output_after, float) + self.assertNotEqual(output_before, output_after) + + if __name__ == "__main__": unittest.main()