Analytic inverse implementation (#122)

* Add analytic 2x2 matrix inverse

* Add analytical 3x3 matrix inverse
This commit is contained in:
kritz
2020-03-04 09:14:04 +01:00
committed by GitHub
parent 649c837b6b
commit 4873dc1c1e
2 changed files with 79 additions and 1 deletions
+41
View File
@@ -435,6 +435,47 @@ bool inv(const SquareMatrix<Type, M> & A, SquareMatrix<Type, M> & inv)
return true;
}
template<typename Type>
bool inv(const SquareMatrix<Type, 2> & A, SquareMatrix<Type, 2> & inv)
{
Type det = A(0, 0) * A(1, 1) - A(1, 0) * A(0, 1);
if(fabs(static_cast<float>(det)) < FLT_EPSILON || !is_finite(det)) {
return false;
}
inv(0, 0) = A(1, 1);
inv(1, 0) = -A(1, 0);
inv(0, 1) = -A(0, 1);
inv(1, 1) = A(0, 0);
inv /= det;
return true;
}
template<typename Type>
bool inv(const SquareMatrix<Type, 3> & A, SquareMatrix<Type, 3> & inv)
{
Type det = A(0, 0) * (A(1, 1) * A(2, 2) - A(2, 1) * A(1, 2)) -
A(0, 1) * (A(1, 0) * A(2, 2) - A(1, 2) * A(2, 0)) +
A(0, 2) * (A(1, 0) * A(2, 1) - A(1, 1) * A(2, 0));
if(fabs(static_cast<float>(det)) < FLT_EPSILON || !is_finite(det)) {
return false;
}
inv(0, 0) = A(1, 1) * A(2, 2) - A(2, 1) * A(1, 2);
inv(0, 1) = A(0, 2) * A(2, 1) - A(0, 1) * A(2, 2);
inv(0, 2) = A(0, 1) * A(1, 2) - A(0, 2) * A(1, 1);
inv(1, 0) = A(1, 2) * A(2, 0) - A(1, 0) * A(2, 2);
inv(1, 1) = A(0, 0) * A(2, 2) - A(0, 2) * A(2, 0);
inv(1, 2) = A(1, 0) * A(0, 2) - A(0, 0) * A(1, 2);
inv(2, 0) = A(1, 0) * A(2, 1) - A(2, 0) * A(1, 1);
inv(2, 1) = A(2, 0) * A(0, 1) - A(0, 0) * A(2, 1);
inv(2, 2) = A(0, 0) * A(1, 1) - A(1, 0) * A(0, 1);
inv /= det;
return true;
}
/**
* inverse based on LU factorization with partial pivotting
*/
+38 -1
View File
@@ -22,6 +22,27 @@ int main()
SquareMatrix<float, 3> A_I_check(data_check);
TEST((A_I - A_I_check).abs().max() < 1e-6f);
float data_2x2[4] = {12, 2,
-7, 5
};
float data_2x2_check[4] = {
0.0675675675f, -0.02702702f,
0.0945945945f, 0.162162162f
};
SquareMatrix<float, 2> A2x2(data_2x2);
SquareMatrix<float, 2> A2x2_I = inv(A2x2);
SquareMatrix<float, 2> A2x2_I_check(data_2x2_check);
TEST(isEqual(A2x2_I, A2x2_I_check));
SquareMatrix<float, 2> A2x2_sing = ones<float, 2, 2>();
SquareMatrix<float, 2> A2x2_sing_I;
TEST(inv(A2x2_sing, A2x2_sing_I) == false);
SquareMatrix<float, 3> A3x3_sing = ones<float, 3, 3>();
SquareMatrix<float, 3> A3x3_sing_I;
TEST(inv(A3x3_sing, A3x3_sing_I) == false)
// stess test
SquareMatrix<float, n_large> A_large;
A_large.setIdentity();
@@ -34,7 +55,7 @@ int main()
}
SquareMatrix<float, 3> zero_test = zeros<float, 3, 3>();
inv(zero_test);
TEST(isEqual(inv(zero_test), zeros<float, 3, 3>()));
// test pivotting
float data2[81] = {
@@ -64,6 +85,7 @@ int main()
SquareMatrix<float, 9> A2_I = inv(A2);
SquareMatrix<float, 9> A2_I_check(data2_check);
TEST((A2_I - A2_I_check).abs().max() < 1e-3f);
float data3[9] = {
0, 1, 2,
3, 4, 5,
@@ -93,6 +115,16 @@ int main()
TEST(isEqual(A3_I, Z3));
TEST(isEqual(A3.I(), Z3));
for(size_t i = 0; i < 9; i++) {
A2(0, i) = 0;
}
A2_I = inv(A2);
SquareMatrix<float, 9> Z9 = zeros<float, 9, 9>();
TEST(!A2.I(A2_I));
TEST(!Z9.I(A2_I));
TEST(isEqual(A2_I, Z9));
TEST(isEqual(A2.I(), Z9));
// cover NaN
A3(0, 0) = NAN;
A3(0, 1) = 0;
@@ -101,6 +133,11 @@ int main()
TEST(isEqual(A3_I, Z3));
TEST(isEqual(A3.I(), Z3));
A2(0, 0) = NAN;
A2_I = inv(A2);
TEST(isEqual(A2_I, Z9));
TEST(isEqual(A2.I(), Z9));
float data4[9] = {
1.33471626f, 0.74946721f, -0.0531679f,
0.74946721f, 1.07519593f, 0.08036323f,