diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index 9989eb2..74a5401 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -16,6 +16,14 @@ class TestRBFLayer(unittest.TestCase): self.assertEqual(self.rbf_layer.centers.shape, (self.n_centers, self.input_dim)) self.assertEqual(self.rbf_layer.sigmas.shape, (self.n_centers,)) + def test_call(self): + """ Test the call method.""" + inputs = tf.random.normal((3, self.input_dim)) + output = self.rbf_layer(inputs) + + self.assertEqual(output.shape, (3, self.n_centers)) + self.assertTrue(tf.reduce_all(output >= 0)) + if __name__ == '__main__': unittest.main()