diff --git a/NP_Implementation/tests/test_RBF_np.py b/NP_Implementation/tests/test_RBF_np.py index 0f47f6a..c1e04c9 100644 --- a/NP_Implementation/tests/test_RBF_np.py +++ b/NP_Implementation/tests/test_RBF_np.py @@ -30,7 +30,23 @@ class TestRBFNetwork(unittest.TestCase): output_after = self.rbf_network.predict(self.x) self.assertIsInstance(output_after, float) self.assertNotEqual(output_before, output_after) - + + def test_train(self): + """Test the training function.""" + target = 1.0 + initial_weights = self.rbf_network.weights.copy() + output_before = np.dot(np.array([self.rbf_network.gaussian(self.x, center) + for center in self.rbf_network.centers]), initial_weights) + + self.rbf_network.train(self.x, target) + + self.assertFalse(np.array_equal(initial_weights, self.rbf_network.weights)) + + output_after = self.rbf_network.predict(self.x) + + self.assertNotEqual(output_before, output_after) + if not abs(target - output_after) < abs(target - output_before): + print("Output did not move closer to the target after training.") if __name__ == "__main__": unittest.main()