// // matrix.hpp // math // // Created by Sam Jaffe on 5/28/16. // #pragma once #include #include #include "math/matrix/forward.h" #include "math/matrix/row_reference.hpp" #include "math/matrix/traits.hpp" #include "math/matrix/macro.h" namespace math::matrix { template class matrix { public: using value_type = T; template using mul_t = decltype(std::declval() * std::declval()); template using div_t = decltype(std::declval() / std::declval()); public: matrix() = default; MATRIX_CTOR_N_ARGS(1) {} MATRIX_CTOR_N_ARGS(2) {} MATRIX_CTOR_N_ARGS(3) {} MATRIX_CTOR_N_ARGS(4) {} DEFER_RESOLUTION matrix(vector::vector const & other, DEFERRED_ENABLE_IF_T(C == 1, bool) = true) { VECTOR_FOR_EACH_RANGE(i, R) { _data[i][0] = other[i]; } } matrix(std::array, R> const & init) { MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; } } matrix(std::array, R> const & init) { MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; } } template matrix(matrix const & other) { MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) { _data[i][j] = other(i, j); } } matrix transpose() const { matrix out; MATRIX_FOR_EACH(i, j) { out(j, i) = _data[i][j]; } return out; } template matrix concat(matrix const & other, concat_strategy::horizontal_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); } return accum; } template matrix concat(matrix const & other, concat_strategy::vertical_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); } return accum; } template matrix concat(matrix const & other, concat_strategy::diagonal_concat_t) const { matrix accum{*this}; MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); } return accum; } // In C++23 - we can use operator[](row, col) T const & operator()(std::size_t row, std::size_t col) const { return _data[row][col]; } T & operator()(std::size_t row, std::size_t col) { return _data[row][col]; } row_reference operator[](std::size_t row) const { return {_data[row]}; } row_reference operator[](std::size_t row) { return {_data[row]}; } row_reference at(std::size_t row) const { expects(row < R, std::out_of_range, "row index out of range"); return operator[](row); } row_reference at(std::size_t row) { expects(row < R, std::out_of_range, "row index out of range"); return operator[](row); } value_type const & at(std::size_t row, std::size_t col) const { expects(row < R && col < C, std::out_of_range, "coordinates out of range"); return _data[row][col]; } value_type & at(std::size_t row, std::size_t col) { expects(row < R && col < C, std::out_of_range, "coordinates out of range"); return _data[row][col]; } matrix & operator+=(matrix const & other) { MATRIX_FOR_EACH(i, j) { _data[i][j] += other(i, j); } return *this; } matrix operator+(matrix const & other) const { return matrix{*this} += other; } matrix & operator-=(matrix const & other) { MATRIX_FOR_EACH(i, j) { _data[i][j] -= other(i, j); } return *this; } matrix operator-(matrix const & other) const { return matrix{*this} -= other; } matrix operator-() const { matrix tmp; MATRIX_FOR_EACH(i, j) { tmp(i, j) = -_data[i][j]; } return tmp; } vector::vector operator*(vector::vector const & vec) const { vector::vector rval; MATRIX_FOR_EACH(i, j) { rval[i] += _data[i][j] * vec[j]; } return rval; } template matrix operator*(matrix const & other) const { matrix rval; MATRIX_FOR_EACH(i, j) { for (size_t k = 0; k < C2; ++k) { rval(i, k) += _data[i][j] * other(j, k); } } return rval; } matrix & operator*=(T c) { MATRIX_FOR_EACH(i, j) { _data[i][j] *= c; } return *this; } template MATRIX_DISABLE_IF_MATRIX(M, mul_t, R, C) operator*(M c) const { return matrix, R, C>{*this} *= c; } template friend MATRIX_DISABLE_IF_MATRIX(M, mul_t, R, C) operator*(M c, matrix const & matr) { return matrix, R, C>{matr} *= c; } template matrix, R, C> & operator/=(M c) { MATRIX_FOR_EACH(i, j) { _data[i][j] /= c; } return *this; } template matrix, R, C> operator/(M c) const { return matrix, R, C>{*this} /= c; } bool operator==(matrix const & other) const { MATRIX_FOR_EACH(i, j) { if (_data[i][j] != other(i, j)) { return false; } } return true; } bool operator!=(matrix const & other) const { return !operator==(other); } private: value_type _data[R][C] = {value_type()}; }; MATRIX_CTOR_DEDUCTION(1); MATRIX_CTOR_DEDUCTION(2); MATRIX_CTOR_DEDUCTION(3); MATRIX_CTOR_DEDUCTION(4); } #include "math/matrix/undef.h"