From 7951a895848c187c6ab640d42958a92b91aa0490 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:24:56 -0700 Subject: [PATCH] Added test for training. --- NP_Implementation/tests/test_RBF_np.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/NP_Implementation/tests/test_RBF_np.py b/NP_Implementation/tests/test_RBF_np.py index 0f47f6a..c1e04c9 100644 --- a/NP_Implementation/tests/test_RBF_np.py +++ b/NP_Implementation/tests/test_RBF_np.py @@ -30,7 +30,23 @@ class TestRBFNetwork(unittest.TestCase): output_after = self.rbf_network.predict(self.x) self.assertIsInstance(output_after, float) self.assertNotEqual(output_before, output_after) - + + def test_train(self): + """Test the training function.""" + target = 1.0 + initial_weights = self.rbf_network.weights.copy() + output_before = np.dot(np.array([self.rbf_network.gaussian(self.x, center) + for center in self.rbf_network.centers]), initial_weights) + + self.rbf_network.train(self.x, target) + + self.assertFalse(np.array_equal(initial_weights, self.rbf_network.weights)) + + output_after = self.rbf_network.predict(self.x) + + self.assertNotEqual(output_before, output_after) + if not abs(target - output_after) < abs(target - output_before): + print("Output did not move closer to the target after training.") if __name__ == "__main__": unittest.main()