diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index f9e63eb..7d83c59 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -36,6 +36,13 @@ class TestRBFAdaptiveModel(unittest.TestCase): self.assertIsInstance(self.model.rbf_layer, RBFLayer) self.assertEqual(self.model.output_layer.units, 1) + def test_call_output_shape(self): + """ Test the call method for output shape.""" + inputs = tf.random.normal((3, self.input_dim)) + output = self.model(inputs) + + self.assertEqual(output.shape, (3, 1)) + if __name__ == '__main__': unittest.main()