From effb96a9a2ef089f55397f0c43160a1f91f89629 Mon Sep 17 00:00:00 2001 From: Andru Liu <90433630+WallabyLester@users.noreply.github.com> Date: Wed, 9 Oct 2024 19:35:49 -0700 Subject: [PATCH] Add test for layer call. --- TF_Implementation/test/test_RBF_tf.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index 9989eb2..74a5401 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -16,6 +16,14 @@ class TestRBFLayer(unittest.TestCase): self.assertEqual(self.rbf_layer.centers.shape, (self.n_centers, self.input_dim)) self.assertEqual(self.rbf_layer.sigmas.shape, (self.n_centers,)) + def test_call(self): + """ Test the call method.""" + inputs = tf.random.normal((3, self.input_dim)) + output = self.rbf_layer(inputs) + + self.assertEqual(output.shape, (3, self.n_centers)) + self.assertTrue(tf.reduce_all(output >= 0)) + if __name__ == '__main__': unittest.main()