mirror of
https://github.com/WallabyLester/RBF-aPID-Controller.git
synced 2026-06-02 04:47:00 +08:00
Add test for layer call.
This commit is contained in:
@@ -16,6 +16,14 @@ class TestRBFLayer(unittest.TestCase):
|
|||||||
self.assertEqual(self.rbf_layer.centers.shape, (self.n_centers, self.input_dim))
|
self.assertEqual(self.rbf_layer.centers.shape, (self.n_centers, self.input_dim))
|
||||||
self.assertEqual(self.rbf_layer.sigmas.shape, (self.n_centers,))
|
self.assertEqual(self.rbf_layer.sigmas.shape, (self.n_centers,))
|
||||||
|
|
||||||
|
def test_call(self):
|
||||||
|
""" Test the call method."""
|
||||||
|
inputs = tf.random.normal((3, self.input_dim))
|
||||||
|
output = self.rbf_layer(inputs)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (3, self.n_centers))
|
||||||
|
self.assertTrue(tf.reduce_all(output >= 0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user