Add test for layer call.

This commit is contained in:
Andru Liu
2024-10-09 19:35:49 -07:00
parent 0e2908525d
commit effb96a9a2
+8
View File
@@ -16,6 +16,14 @@ class TestRBFLayer(unittest.TestCase):
self.assertEqual(self.rbf_layer.centers.shape, (self.n_centers, self.input_dim))
self.assertEqual(self.rbf_layer.sigmas.shape, (self.n_centers,))
def test_call(self):
""" Test the call method."""
inputs = tf.random.normal((3, self.input_dim))
output = self.rbf_layer(inputs)
self.assertEqual(output.shape, (3, self.n_centers))
self.assertTrue(tf.reduce_all(output >= 0))
if __name__ == '__main__':
unittest.main()