diff --git a/test/least_squares.cpp b/test/least_squares.cpp index f9cae89cd69..573715caa24 100644 --- a/test/least_squares.cpp +++ b/test/least_squares.cpp @@ -4,7 +4,7 @@ using namespace matrix; int test_4x3(void); -int test_4x4(void); +template int test_4x4(void); int test_4x4_type_double(void); int test_div_zero(void); @@ -12,14 +12,15 @@ int main() { int ret; - ret = test_4x4(); + ret = test_4x4(); + if (ret != 0) return ret; + + ret = test_4x4(); if (ret != 0) return ret; ret = test_4x3(); if (ret != 0) return ret; - ret = test_4x4_type_double(); - ret = test_div_zero(); if (ret != 0) return ret; @@ -35,7 +36,7 @@ int test_4x3() { }; Matrix A(data); - float b_data[4] = {2.0, 3.0, 4.0, 5.0}; + float b_data[4] = {2.0f, 3.0f, 4.0f, 5.0f}; Vector b(b_data); float x_check_data[3] = {-0.69168233f, @@ -51,54 +52,29 @@ int test_4x3() { return 0; } +template int test_4x4() { // Start with an (m x n) A matrix - float data[16] = { 20.f, -10.f, -13.f, 21.f, - 17.f, 16.f, -18.f, -14.f, - 0.7f, -0.8f, 0.9f, -0.5f, - -1.f, -1.1f, -1.2f, -1.3f - }; - Matrix A(data); + const Type data[16] = { 20.f, -10.f, -13.f, 21.f, + 17.f, 16.f, -18.f, -14.f, + 0.7f, -0.8f, 0.9f, -0.5f, + -1.f, -1.1f, -1.2f, -1.3f + }; + Matrix A(data); - float b_data[4] = {2.0, 3.0, 4.0, 5.0}; - Vector b(b_data); + Type b_data[4] = {2.0f, 3.0f, 4.0f, 5.0f}; + Vector b(b_data); - float x_check_data[4] = { 0.97893433f, - -2.80798701f, - -0.03175765f, - -2.19387649f - }; - Vector x_check(x_check_data); + Type x_check_data[4] = { 0.97893433f, + -2.80798701f, + -0.03175765f, + -2.19387649f + }; + Vector x_check(x_check_data); - LeastSquaresSolver qrd = LeastSquaresSolver(A); + LeastSquaresSolver qrd = LeastSquaresSolver(A); - Vector x = qrd.solve(b); - TEST(isEqual(x, x_check)); - return 0; -} - -int test_4x4_type_double() { - // Start with an (m x n) A matrix - double data[16] = { 20., -10., -13., 21., - 17., 16., -18., -14., - 0.7, -0.8, 0.9, -0.5, - -1., -1.1, -1.2, -1.3 - }; - Matrix A(data); - - double b_data[4] = {2.0, 3.0, 4.0, 5.0}; - Vector b(b_data); - - double x_check_data[4] = { 0.97893433, - -2.80798701, - -0.03175765, - -2.19387649 - }; - Vector x_check(x_check_data); - - LeastSquaresSolver qrd = LeastSquaresSolver(A); - - Vector x = qrd.solve(b); + Vector x = qrd.solve(b); TEST(isEqual(x, x_check)); return 0; } @@ -107,7 +83,7 @@ int test_div_zero() { float data[4] = {0.0f, 0.0f, 0.0f, 0.0f}; Matrix A(data); - float b_data[2] = {1.0, 1.0}; + float b_data[2] = {1.0f, 1.0f}; Vector b(b_data); // Implement such that x returns zeros if it reaches div by zero