diff --git a/TF_Implementation/test/test_RBF_tf.py b/TF_Implementation/test/test_RBF_tf.py index 74a5401..f9e63eb 100644 --- a/TF_Implementation/test/test_RBF_tf.py +++ b/TF_Implementation/test/test_RBF_tf.py @@ -23,7 +23,19 @@ class TestRBFLayer(unittest.TestCase): self.assertEqual(output.shape, (3, self.n_centers)) self.assertTrue(tf.reduce_all(output >= 0)) - +class TestRBFAdaptiveModel(unittest.TestCase): + def setUp(self): + """ Set up a RBFAdaptiveModel class instance.""" + self.n_centers = 5 + self.input_dim = 3 + self.model = RBFAdaptiveModel(self.n_centers, self.input_dim) + + def test_initialization(self): + """ Test the initialization of the model.""" + self.assertIsInstance(self.model.rbf_layer, RBFLayer) + self.assertEqual(self.model.output_layer.units, 1) + + if __name__ == '__main__': unittest.main()