From b2efa77059adbf2f895dd6072d3d2c49be357f36 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:15:04 -0700 Subject: [PATCH] Created RBF network tests for numpy. Test for gaussian function. --- NP_Implementation/tests/test_RBF_np.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 NP_Implementation/tests/test_RBF_np.py diff --git a/NP_Implementation/tests/test_RBF_np.py b/NP_Implementation/tests/test_RBF_np.py new file mode 100644 index 0000000..d35e568 --- /dev/null +++ b/NP_Implementation/tests/test_RBF_np.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np + +from RBF_numpy import RBFNetwork + +class TestRBFNetwork(unittest.TestCase): + def setUp(self): + """Set up an RBFNetwork instance for testing.""" + self.input_dim = 3 + self.n_centers = 5 + self.x = np.array([6.0, 0.5, 0.2]) + self.rbf_network = RBFNetwork(self.input_dim, self.n_centers) + + def test_gaussian(self): + """Test the Gaussian function.""" + center = np.random.rand(1, self.input_dim) + expected_output = np.exp(-np.linalg.norm(self.x - center) ** 2 / (2 * self.rbf_network.sigma ** 2)) + output = self.rbf_network.gaussian(self.x, center) + self.assertAlmostEqual(output, expected_output, places=5) + +if __name__ == "__main__": + unittest.main()