diff --git a/CPP_Implementation/rbf_model_test.cpp b/CPP_Implementation/rbf_model_test.cpp index b21edca..700c9b2 100644 --- a/CPP_Implementation/rbf_model_test.cpp +++ b/CPP_Implementation/rbf_model_test.cpp @@ -5,8 +5,8 @@ class RBFModelTest : public ::testing::Test { protected: void SetUp() override { - // Set up an RBF model with 3 centers, 3D input, and random centers - n_centers = 3; + // Set up an RBF model with 5 centers, 3D input, and random centers + n_centers = 5; input_dim = 3; rbf = new RBFModel(n_centers, input_dim); } @@ -27,6 +27,18 @@ TEST_F(RBFModelTest, Constructor_And_Initial_Weights) { } } +// Test predict method +TEST_F(RBFModelTest, Predict_Output) { + // Set some weights and define centers + for (int i = 0; i < n_centers; ++i) { + rbf->set_weight(i, 1.0); + } + + double input[] = {0.5, 0.5, 0.5}; + double output = rbf->predict(input); + EXPECT_GT(output, 0.0); +} + // Main function for running tests int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv);