diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index 06415f1abd..fcb225d324 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -61,6 +61,18 @@ public: return *this; } + template + const Slice slice(size_t x0, size_t y0) const + { + return Slice(x0, y0, this); + } + + template + Slice slice(size_t x0, size_t y0) + { + return Slice(x0, y0, this); + } + // inverse alias inline SquareMatrix I() const { @@ -118,6 +130,119 @@ public: } return res; } + + // zero all offdiagonal elements and keep corresponding diagonal elements + template + void uncorrelateCovariance(size_t first) + { + SquareMatrix &self = *this; + Vector diag_elements = self.slice(first, first).diag(); + self.uncorrelateCovarianceSetVariance(first, diag_elements); + } + + template + void uncorrelateCovarianceSetVariance(size_t first, const Vector &vec) + { + SquareMatrix &self = *this; + // zero rows and columns + self.slice(first, 0) = 0; + self.slice(0, first) = 0; + + // set diagonals + unsigned vec_idx = 0; + for (size_t idx = first; idx < first+Width; idx++) { + self(idx,idx) = vec(vec_idx); + vec_idx ++; + } + } + + template + void uncorrelateCovarianceSetVariance(size_t first, Type val) + { + SquareMatrix &self = *this; + // zero rows and columns + self.slice(first, 0) = 0; + self.slice(0, first) = 0; + + // set diagonals + for (size_t idx = first; idx < first+Width; idx++) { + self(idx,idx) = val; + } + } + + // make block diagonal symmetric by taking the average of the two corresponding off diagonal values + template + void makeBlockSymmetric(size_t first) + { + SquareMatrix &self = *this; + if(Width>1) { + for (size_t row_idx = first+1; row_idx < first+Width; row_idx++) { + for (size_t col_idx = first; col_idx < row_idx; col_idx++) { + Type tmp = self(row_idx,col_idx) + (self(col_idx,row_idx) - self(row_idx,col_idx)) / 2; + self(row_idx,col_idx) = tmp; + self(col_idx,row_idx) = tmp; + } + } + } + } + + // make rows and columns symmetric by taking the average of the two corresponding off diagonal values + template + void makeRowColSymmetric(size_t first) + { + SquareMatrix &self = *this; + self.makeBlockSymmetric(first); + for (size_t row_idx = first; row_idx < first+Width; row_idx++) { + for (size_t col_idx = 0; col_idx < first; col_idx++) { + Type tmp = self(row_idx,col_idx) + (self(col_idx,row_idx) - self(row_idx,col_idx)) / 2; + self(row_idx,col_idx) = tmp; + self(col_idx,row_idx) = tmp; + } + for (size_t col_idx = first+Width; col_idx < M; col_idx++) { + Type tmp = self(row_idx,col_idx) + (self(col_idx,row_idx) - self(row_idx,col_idx)) / 2; + self(row_idx,col_idx) = tmp; + self(col_idx,row_idx) = tmp; + } + } + } + + // checks if block diagonal is symmetric + template + bool isBlockSymmetric(size_t first, const Type eps = 1e-8f) + { + SquareMatrix &self = *this; + if(Width>1) { + for (size_t row_idx = first+1; row_idx < first+Width; row_idx++) { + for (size_t col_idx = first; col_idx < row_idx; col_idx++) { + if(!isEqualF(self(row_idx,col_idx), self(col_idx,row_idx), eps)) { + return false; + } + } + } + } + return true; + } + + // checks if rows and columns are symmetric + template + bool isRowColSymmetric(size_t first, const Type eps = 1e-8f) + { + SquareMatrix &self = *this; + for (size_t row_idx = first; row_idx < first+Width; row_idx++) { + for (size_t col_idx = 0; col_idx < first; col_idx++) { + if(!isEqualF(self(row_idx,col_idx), self(col_idx,row_idx), eps)) { + return false; + } + } + for (size_t col_idx = first+Width; col_idx < M; col_idx++) { + if(!isEqualF(self(row_idx,col_idx), self(col_idx,row_idx), eps)) { + return false; + } + } + } + return self.isBlockSymmetric(first, eps); + } + }; typedef SquareMatrix SquareMatrix3f; diff --git a/test/squareMatrix.cpp b/test/squareMatrix.cpp index 3ad6b1a832..0c87debc0b 100644 --- a/test/squareMatrix.cpp +++ b/test/squareMatrix.cpp @@ -38,7 +38,110 @@ int main() TEST(isEqual(A_bottomright, bottomright_check)); TEST(isEqual(A_bottomright2, bottomright_check)); + // test diagonal functions + float data_4x4[16] = {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14,15, 16 + }; + SquareMatrix B(data_4x4); + B.uncorrelateCovariance<1>(1); + float data_B_check[16] = {1, 0, 3, 4, + 0, 6, 0, 0, + 9, 0, 11, 12, + 13, 0,15, 16 + }; + SquareMatrix B_check(data_B_check); + TEST(isEqual(B, B_check)) + + SquareMatrix C(data_4x4); + C.uncorrelateCovariance<2>(1); + float data_C_check[16] = {1, 0, 0, 4, + 0, 6, 0, 0, + 0, 0, 11, 0, + 13, 0,0, 16 + }; + SquareMatrix C_check(data_C_check); + TEST(isEqual(C, C_check)) + + SquareMatrix D(data_4x4); + D.uncorrelateCovarianceSetVariance<2>(0, Vector2f{20,21}); + float data_D_check[16] = {20, 0, 0, 0, + 0, 21, 0, 0, + 0, 0, 11, 12, + 0, 0,15, 16 + }; + SquareMatrix D_check(data_D_check); + TEST(isEqual(D, D_check)) + + SquareMatrix E(data_4x4); + E.uncorrelateCovarianceSetVariance<3>(1, 33); + float data_E_check[16] = {1, 0, 0, 0, + 0, 33, 0, 0, + 0, 0, 33, 0, + 0, 0,0, 33 + }; + SquareMatrix E_check(data_E_check); + TEST(isEqual(E, E_check)) + + // test symmetric functions + SquareMatrix F(data_4x4); + F.makeBlockSymmetric<2>(1); + float data_F_check[16] = {1, 2, 3, 4, + 5, 6, 8.5, 8, + 9, 8.5, 11, 12, + 13, 14,15, 16 + }; + SquareMatrix F_check(data_F_check); + TEST(isEqual(F, F_check)) + TEST(F.isBlockSymmetric<2>(1)); + TEST(!F.isRowColSymmetric<2>(1)); + + SquareMatrix G(data_4x4); + G.makeRowColSymmetric<2>(1); + float data_G_check[16] = {1, 3.5, 6, 4, + 3.5, 6, 8.5, 11, + 6, 8.5, 11, 13.5, + 13, 11,13.5, 16 + }; + SquareMatrix G_check(data_G_check); + TEST(isEqual(G, G_check)); + TEST(G.isBlockSymmetric<2>(1)); + TEST(G.isRowColSymmetric<2>(1)); + + SquareMatrix H(data_4x4); + H.makeBlockSymmetric<1>(1); + float data_H_check[16] = {1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14,15, 16 + }; + SquareMatrix H_check(data_H_check); + TEST(isEqual(H, H_check)) + TEST(H.isBlockSymmetric<1>(1)); + TEST(!H.isRowColSymmetric<1>(1)); + + SquareMatrix J(data_4x4); + J.makeRowColSymmetric<1>(1); + float data_J_check[16] = {1, 3.5, 3, 4, + 3.5, 6, 8.5, 11, + 9, 8.5, 11, 12, + 13, 11,15, 16 + }; + SquareMatrix J_check(data_J_check); + TEST(isEqual(J, J_check)); + TEST(J.isBlockSymmetric<1>(1)); + TEST(J.isRowColSymmetric<1>(1)); + TEST(!J.isBlockSymmetric<3>(1)); + + float data_K[16] = {1, 2, 3, 4, + 2, 3, 4, 11, + 3, 4, 11, 12, + 4, 11,15, 16 + }; + SquareMatrix K(data_K); + TEST(!K.isRowColSymmetric<1>(2)); return 0; }