// // matrix.hpp // math // // Created by Sam Jaffe on 5/28/16. // #pragma once #include "vector/vector.hpp" #include "expect/expect.hpp" namespace math { namespace matrix { template struct is_matrix { static const constexpr bool value = false; }; template class matrix; template struct is_matrix > { static const constexpr bool value = true; }; namespace concat_strategy { struct {} horizonal; using horizontal_concat_t = decltype(horizonal); struct {} vertical; using vertical_concat_t = decltype(vertical); struct {} diagonal; using diagonal_concat_t = decltype(diagonal); }; #define MATRIX_DISABLE_IF_MATRIX(_type, t, r, c) \ typename std::enable_if::value, matrix >::type #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) for (size_t i = 0; i < i_max; ++i) for (size_t j = 0; j < j_max; ++j) #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C) template class matrix { public: using value_type = T; template using mul_t = decltype(std::declval()*std::declval()); template using div_t = decltype(std::declval()/std::declval()); public: template class row_reference { public: row_reference(S ( & h )[C]) : _handle(h) {} row_reference(row_reference const &) = delete; row_reference(row_reference &&) = default; row_reference & operator=(row_reference const & other) { return operator=(other); } template row_reference & operator=(row_reference const & other) { VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; } return *this; } S const & operator[](std::size_t col) const { return _handle[col]; } S & operator[](std::size_t col) { return _handle[col]; } S const & at(std::size_t col) const { expects(col < C, std::out_of_range, "column index out of range"); return operator[](col); } S & at(std::size_t col) { expects(col < C, std::out_of_range, "column index out of range"); return operator[](col); } private: S ( & _handle )[C]; }; matrix() = default; matrix(std::array, R> const & init) { MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; } } template matrix(vector::vector::type, N> const & other) { VECTOR_FOR_EACH(i) { _data[i][0] = other[i]; } } matrix(matrix const& other) { *this = other; } matrix(matrix && other) { *this = std::move(other); } matrix & operator=(matrix const& other) { MATRIX_FOR_EACH(i, j) { _data[i][j] = other._data[i][j]; } return *this; } matrix & operator=(matrix && other) { MATRIX_FOR_EACH(i, j) { _data[i][j] = std::move(other._data[i][j]); } return *this; } template matrix(matrix const & other) { MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) { _data[i][j] = other(i, j); } } matrix transpose() const { matrix out; MATRIX_FOR_EACH(i, j) { out(j,i) = _data[i][j]; } return out; } template matrix concat(matrix const & other, concat_strategy::horizontal_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); } return accum; } template matrix concat(matrix const & other, concat_strategy::vertical_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); } return accum; } template matrix concat(matrix const & other, concat_strategy::diagonal_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); } return accum; } T const & operator()(std::size_t row, std::size_t col) const { return _data[row][col]; } T & operator()(std::size_t row, std::size_t col) { return _data[row][col]; } row_reference operator[](std::size_t row) const { return { _data[row] }; } row_reference operator[](std::size_t row) { return { _data[row] }; } row_reference at(std::size_t row) const { expects(row >= R, std::out_of_range, "row index out of range"); return operator[](row); } row_reference at(std::size_t row) { expects(row >= R, std::out_of_range, "row index out of range"); return operator[](row); } value_type const & at(std::size_t row, std::size_t col) const { expects(row < R && col < C, std::out_of_range, "coordinates out of range"); return _data[row][col]; } value_type & at(std::size_t row, std::size_t col) { expects(row < R && col < C, std::out_of_range, "coordinates out of range"); return _data[row][col]; } matrix& operator+=(matrix const & other) { MATRIX_FOR_EACH(i, j) { _data[i][j] += other[i][j]; } return *this; } matrix operator+(matrix const & other) const { return matrix{*this} += other; } matrix& operator-=(matrix const & other) { MATRIX_FOR_EACH(i, j) { _data[i][j] -= other[i][j]; } return *this; } matrix operator-(matrix const & other) const { return matrix{*this} -= other; } vector::vector operator*(vector::vector const & vec) const { vector::vector rval; MATRIX_FOR_EACH(i, j) { rval[i] += _data[i][j] * vec[j]; } return rval; } template matrix operator*(matrix const & other) const { matrix rval; MATRIX_FOR_EACH(i, j) { for (size_t k = 0; k < C2; ++k) { rval[i][k] += _data[i][j] * other[j][k]; } } return rval; } matrix& operator*=(T c) { MATRIX_FOR_EACH(i, j) { _data[i][j] *= c; } return *this; } template MATRIX_DISABLE_IF_MATRIX(M, mul_t, R, C) operator*(M c) const { return matrix, R, C>{*this} *= c; } template friend MATRIX_DISABLE_IF_MATRIX(M, mul_t, R, C) operator*(M c, matrix const& matr) { return matrix, R, C>{matr} *= c; } template matrix, R, C>& operator/=(M c) { MATRIX_FOR_EACH(i, j) { _data[i][j] /= c; } return *this; } template matrix, R, C> operator/(M c) const { return matrix, R, C>{*this} /= c; } bool operator==(matrix const & other) const { MATRIX_FOR_EACH(i, j) { if (_data[i][j] != other._data[i][j]) { return false; } } return true; } bool operator!=(matrix const & other) const { return !operator==(other); } private: value_type _data[R][C] = {value_type()}; }; } }