Added test for output shape of the model.

This commit is contained in:
Andru Liu
2024-10-09 19:48:23 -07:00
parent 32dae03ad1
commit aaeb3a4507

View File

@@ -36,6 +36,13 @@ class TestRBFAdaptiveModel(unittest.TestCase):
self.assertIsInstance(self.model.rbf_layer, RBFLayer)
self.assertEqual(self.model.output_layer.units, 1)
def test_call_output_shape(self):
""" Test the call method for output shape."""
inputs = tf.random.normal((3, self.input_dim))
output = self.model(inputs)
self.assertEqual(output.shape, (3, 1))
if __name__ == '__main__':
unittest.main()