From 03ff1f639a7070cc37e9da3e2420fe23d196204a Mon Sep 17 00:00:00 2001 From: Samer Afach Date: Mon, 17 Oct 2016 23:24:04 +0200 Subject: [PATCH] Cleaned size checks. Added functions: zeros, ones, trace, elementWiseProduct and elementWiseProduct_inplace renamed functions to be with _inplace when they replace "this". --- include/internal/Matrix.h | 216 ++++++++++++++++++++------------------ src/Matrix.cpp | 1 + tests/tests.cpp | 3 + tests/tests.h | 2 +- 4 files changed, 120 insertions(+), 102 deletions(-) diff --git a/include/internal/Matrix.h b/include/internal/Matrix.h index 97f7d56..563165c 100644 --- a/include/internal/Matrix.h +++ b/include/internal/Matrix.h @@ -116,7 +116,9 @@ void MultiplyMatrices(const Matrix &matrixLHS, const Matrix &matrixRHS template void MultiplyMatrices(const Matrix& lhs, const Matrix& rhs, Matrix &to_add_then_result, T alpha, T beta); template -Matrix IdentityMatrix(const typename Poly::Matrix::size_type& size); +Matrix IdentityMatrix(const typename Matrix::size_type& size); +template +void _check_equal_matrices_sizes(const Matrix&, const Matrix&); template class Matrix @@ -138,6 +140,8 @@ private: inline size_type SizeToReserve(const size_type &rows, const size_type &columns) const; void invert_manual(); + void _check_square_matrix(); + public: Matrix(); @@ -203,16 +207,20 @@ public: inline typename std::vector::iterator end() noexcept; inline typename std::vector::const_iterator end() const noexcept; - Matrix& operator+=(const Matrix& rhs); Matrix& operator-=(const Matrix& rhs); Matrix& operator*=(const Matrix& rhs); Matrix& operator/=(const Matrix& rhs); - void applyInverse(); - Matrix getInverse(); + void elementWiseProduct_inplace(const Matrix& rhs); + Matrix elementWiseProduct(const Matrix& rhs); + T trace(); + void zeros(); + void ones(); + void inverse_inplace(); + Matrix inverse(); Matrix getExp(); - void applyConjugate(); - void applyConjugateTranspose(); + void conjugate_inplace(); + void conjugateTranspose_inplace(); std::vector vectorize(const VECTORIZATION_MODE& mode = VECMODE_FULL) const; const T& at(size_type row, size_type column) const; T& at(size_type row, size_type column); @@ -224,7 +232,7 @@ public: const size_type& rows() const; const size_type& columns() const; void clear(); - void applyTranspose(); + void transpose_inplace(); T getDeterminant(); Matrix getSVD(); std::string asString(int precision = 7, char open = '[', char close = ']', char sep = ','); @@ -420,13 +428,13 @@ bool Matrix::operator!=(const Matrix& rhs) } template -void Matrix::applyTranspose() +void Matrix::transpose_inplace() { *this = Transpose(*this); } template -void Matrix::applyConjugateTranspose() +void Matrix::conjugateTranspose_inplace() { *this = ConjugateTranspose(*this); } @@ -509,12 +517,7 @@ Matrix,M> operator*(const Matrix& lhs, const std::complex Matrix operator+(const Matrix& lhs, const Matrix& rhs) { -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for addition"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); Matrix result(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -526,12 +529,7 @@ template Matrix operator-(const Matrix& lhs, const Matrix& rhs) { Matrix result; -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for subtraction"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); result.resize(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -542,12 +540,7 @@ Matrix operator-(const Matrix& lhs, const Matrix& rhs) template Matrix,M> operator+(const Matrix,M>& lhs, const Matrix& rhs) { -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for addition"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); Matrix,M> result(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -558,12 +551,7 @@ Matrix,M> operator+(const Matrix,M>& lhs, const template Matrix,M> operator+(const Matrix& lhs, const Matrix,M>& rhs) { -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for addition"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); Matrix,M> result(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -573,12 +561,7 @@ Matrix,M> operator+(const Matrix& lhs, const Matrix Matrix,M> operator-(const Matrix,M>& lhs, const Matrix& rhs) { -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for addition"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); Matrix,M> result(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -589,12 +572,7 @@ Matrix,M> operator-(const Matrix,M>& lhs, const template Matrix,M> operator-(const Matrix& lhs, const Matrix,M>& rhs) { -#ifdef POLYMATH_DEBUG - if((lhs.columns() != rhs.columns()) || (lhs.rows() != rhs.rows())) - { - throw std::length_error("Invalid matrix sizes for addition"); - } -#endif + _check_equal_matrices_sizes(lhs,rhs); Matrix,M> result(rhs.rows(),rhs.columns()); std::transform( lhs.begin(), lhs.end(), rhs.begin(), result.begin(), @@ -624,8 +602,52 @@ Matrix& Matrix::operator*=(const Matrix& rhs) } template -void Matrix::applyInverse() +void Matrix::elementWiseProduct_inplace(const Matrix &rhs) { + _check_equal_matrices_sizes(*this,rhs); + std::transform( this->begin(), this->end(), + rhs.begin(), this->begin(), + std::multiplies()); +} + +template +Matrix Matrix::elementWiseProduct(const Matrix &rhs) +{ + Matrix result(this->rows(),this->columns()); + std::transform( this->begin(), this->end(), + rhs.begin(), result.begin(), + std::multiplies()); + return result; +} + +template +T Matrix::trace() +{ + _check_square_matrix(); + T result = T(0); + for(size_type i = 0; i < this->rows(); i++) + { + result += this->at(i,i); + } + return result; +} + +template +void Matrix::zeros() +{ + std::fill(this->begin(), this->end(), T(0)); +} + +template +void Matrix::ones() +{ + std::fill(this->begin(), this->end(), T(1)); +} + +template +void Matrix::inverse_inplace() +{ + //FIXME: Make size error functions all in one place if(rows() != columns()) { throw std::length_error("Matrix inverse is only for square matrices."); @@ -714,6 +736,38 @@ void Matrix::invert_manual() (*this) = sideMat; } +template +void Matrix::_check_square_matrix() +{ +#ifdef POLYMATH_DEBUG + if(this->columns() != this->rows()) + { + throw std::length_error("This operation requires square matrices."); + } +#endif +} + +template +void _check_equal_matrices_sizes(const Matrix& +#ifdef POLYMATH_DEBUG + mat1 +#endif + , + const Matrix& +#ifdef POLYMATH_DEBUG + mat2 +#endif + ) +{ +#ifdef POLYMATH_DEBUG + if((mat1.columns() != mat2.columns()) || (mat1.rows() != mat2.rows())) + { + throw std::length_error("Matrices sizes should be equal"); + } +#endif +} + + template void Matrix::copyFrom(const Matrix &rhs) { @@ -755,10 +809,10 @@ void Matrix::initialize(const std::initializer_list &list, Matrix:: } template -Matrix Matrix::getInverse() +Matrix Matrix::inverse() { Matrix tmpMat = *this; - tmpMat.applyInverse(); + tmpMat.inverse_inplace(); return tmpMat; } @@ -768,7 +822,7 @@ Matrix Matrix::getExp() } template -void Matrix::applyConjugate() +void Matrix::conjugate_inplace() { transform(_matEls.begin(),_matEls.end(), _matEls.begin(), [](std::complex& c) -> std::complex { return std::conj(c); }); } @@ -786,12 +840,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_UPPER_TRIANGULAR_NO_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()-1))/2); typename std::vector::const_iterator it = _matEls.begin() + this->columns(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -804,12 +853,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_UPPER_TRIANGULAR_WITH_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()+1))/2); typename std::vector::const_iterator it = _matEls.begin(); long size_to_copy = 0; @@ -822,12 +866,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_LOWER_TRIANGULAR_NO_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()-1))/2); typename std::vector::const_iterator it = _matEls.begin(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -841,12 +880,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_LOWER_TRIANGULAR_WITH_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()+1))/2); typename std::vector::const_iterator it = _matEls.begin(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -870,12 +904,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_LOWER_TRIANGULAR_NO_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()-1))/2); typename std::vector::const_iterator it = _matEls.begin() + this->columns(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -888,12 +917,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_LOWER_TRIANGULAR_WITH_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()+1))/2); typename std::vector::const_iterator it = _matEls.begin(); long size_to_copy = 0; @@ -906,12 +930,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_UPPER_TRIANGULAR_NO_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()-1))/2); typename std::vector::const_iterator it = _matEls.begin(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -925,12 +944,7 @@ std::vector Matrix::vectorize(const VECTORIZATION_MODE &mode) const } else if(mode == Poly::VECMODE_UPPER_TRIANGULAR_WITH_DIAGONAL) { -#ifdef POLYMATH_DEBUG - if(this->columns() != this->rows()) - { - throw std::length_error("You cannot vectorize a non-square matrix to upper/lower triangulars."); - } -#endif + _check_square_matrix(); output.resize((this->columns()*(this->columns()+1))/2); typename std::vector::const_iterator it = _matEls.begin(); //iterator at rows to skip in every step, starts at second column long size_to_copy = 0; @@ -1556,7 +1570,7 @@ int _Lapack_EigenVals_hermitian_divconq_double(Poly::Matrix& eigenValues, Pol iworkspace.get(), &liwork, &info); if(M != ColMaj) { - input.applyConjugateTranspose(); + input.conjugateTranspose_inplace(); } if(info > 0) { @@ -1609,7 +1623,7 @@ int _Lapack_EigenVals_hermitian_divconq_float(Poly::Matrix& eigenValues, Poly iworkspace.get(), &liwork, &info); if(M != ColMaj) { - input.applyConjugateTranspose(); + input.conjugateTranspose_inplace(); } if(info > 0) { @@ -1646,7 +1660,7 @@ int _Lapack_EigenVals_hermitian_double(Poly::Matrix& eigenValues, Poly::Matri zheev_(&jobz, &uplo, &n, (CX_TP*)&input.front(), &lda, (TP*)&eigenValues.front(), (CX_TP*)workspace.get(), &lwork, rwork.get(), &info); if(M != ColMaj) { - input.applyConjugateTranspose(); + input.conjugateTranspose_inplace(); } if(info > 0) { @@ -1683,7 +1697,7 @@ int _Lapack_EigenVals_hermitian_float(Poly::Matrix& eigenValues, Poly::Matrix cheev_(&jobz, &uplo, &n, (CX_TP*)&input.front(), &lda, (TP*)&eigenValues.front(), (CX_TP*)workspace.get(), &lwork, rwork.get(), &info); if(M != ColMaj) { - input.applyConjugateTranspose(); + input.conjugateTranspose_inplace(); } if(info > 0) { diff --git a/src/Matrix.cpp b/src/Matrix.cpp index 025e440..e3871bd 100644 --- a/src/Matrix.cpp +++ b/src/Matrix.cpp @@ -26,4 +26,5 @@ std::string _unsupported_type(const std::string& function_name) { return std::string("The type you is not supported for the function " + function_name + "."); } + } diff --git a/tests/tests.cpp b/tests/tests.cpp index 4284ab9..3827a25 100644 --- a/tests/tests.cpp +++ b/tests/tests.cpp @@ -44,6 +44,9 @@ int RunTests() int main() { + int len = 3; + Poly::Matrix mat_d = Poly::RandomMatrix(len,len,0,10,std::random_device{}()); + Poly::Matrix mat_e = Poly::RandomMatrix(len,len,0,10,std::random_device{}()); RunTests(); std::cout<<"Tests program exited with no errors."<(); - auto mat_d_i = mat_d.getInverse(); + auto mat_d_i = mat_d.inverse(); PyObject *main = PyImport_AddModule("__main__"); PyRun_SimpleString(std::string("data={}").c_str()); PyRun_SimpleString(std::string("data['a']=np.matrix(" + mat_d.asString(32,'[',']',',') + ",dtype="+python_type_per_type()+")").c_str());