matrix.hpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
  8. #include "expect/expect.hpp"
  9. #include "math/vector/vector.hpp"
  10. namespace math { namespace matrix {
  11. template <typename T> struct is_matrix {
  12. static const constexpr bool value = false;
  13. };
  14. template <typename T, std::size_t R, std::size_t C> class matrix;
  15. template <typename T, size_t R, size_t C> struct is_matrix<matrix<T, R, C>> {
  16. static const constexpr bool value = true;
  17. };
  18. template <typename T, size_t N> struct is_matrix<vector::vector<T, N>> {
  19. static const constexpr bool value = true;
  20. };
  21. namespace concat_strategy {
  22. struct {
  23. } horizonal;
  24. using horizontal_concat_t = decltype(horizonal);
  25. struct {
  26. } vertical;
  27. using vertical_concat_t = decltype(vertical);
  28. struct {
  29. } diagonal;
  30. using diagonal_concat_t = decltype(diagonal);
  31. };
  32. #define MATRIX_DISABLE_IF_MATRIX(_type, t, r, c) \
  33. typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c>>::type
  34. #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) \
  35. for (size_t i = 0; i < i_max; ++i) \
  36. for (size_t j = 0; j < j_max; ++j)
  37. #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  38. template <typename T, std::size_t C> class row_reference {
  39. private:
  40. row_reference(T (&h)[C]) : _handle(h) {}
  41. public:
  42. row_reference(row_reference const &) = delete;
  43. row_reference(row_reference &&) = default;
  44. template <typename S>
  45. row_reference & operator=(row_reference<S, C> const & other) {
  46. VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
  47. return *this;
  48. }
  49. T const & operator[](std::size_t col) const { return _handle[col]; }
  50. T & operator[](std::size_t col) { return _handle[col]; }
  51. T const & at(std::size_t col) const {
  52. expects(col < C, std::out_of_range, "column index out of range");
  53. return operator[](col);
  54. }
  55. T & at(std::size_t col) {
  56. expects(col < C, std::out_of_range, "column index out of range");
  57. return operator[](col);
  58. }
  59. private:
  60. template <typename _T, std::size_t R, std::size_t _C> friend class matrix;
  61. T (&_handle)[C];
  62. };
  63. template <typename T, std::size_t R, std::size_t C> class matrix {
  64. public:
  65. using value_type = T;
  66. template <typename M>
  67. using mul_t = decltype(std::declval<T>() * std::declval<M>());
  68. template <typename M>
  69. using div_t = decltype(std::declval<T>() / std::declval<M>());
  70. public:
  71. matrix() = default;
  72. matrix(std::array<std::array<T, C>, R> const & init) {
  73. MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; }
  74. }
  75. template <size_t N>
  76. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type,
  77. N> const & other) {
  78. VECTOR_FOR_EACH(i) { _data[i][0] = other[i]; }
  79. }
  80. matrix(matrix const & other) { *this = other; }
  81. matrix(matrix && other) { *this = std::move(other); }
  82. matrix & operator=(matrix const & other) {
  83. MATRIX_FOR_EACH(i, j) { _data[i][j] = other(i, j); }
  84. return *this;
  85. }
  86. matrix & operator=(matrix && other) {
  87. MATRIX_FOR_EACH(i, j) { _data[i][j] = std::move(other(i, j)); }
  88. return *this;
  89. }
  90. template <size_t R2, size_t C2> matrix(matrix<T, R2, C2> const & other) {
  91. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  92. _data[i][j] = other(i, j);
  93. }
  94. }
  95. matrix<T, C, R> transpose() const {
  96. matrix<T, C, R> out;
  97. MATRIX_FOR_EACH(i, j) { out(j, i) = _data[i][j]; }
  98. return out;
  99. }
  100. template <size_t C2>
  101. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other,
  102. concat_strategy::horizontal_concat_t) const {
  103. matrix<T, R, C + C2> accum{*this};
  104. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  105. return accum;
  106. }
  107. template <size_t R2>
  108. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other,
  109. concat_strategy::vertical_concat_t) const {
  110. matrix<T, R + R2, C> accum{*this};
  111. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  112. return accum;
  113. }
  114. template <size_t R2, size_t C2>
  115. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other,
  116. concat_strategy::diagonal_concat_t) const {
  117. matrix<T, R + R2, C + C2> accum{*this};
  118. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  119. return accum;
  120. }
  121. T const & operator()(std::size_t row, std::size_t col) const {
  122. return _data[row][col];
  123. }
  124. T & operator()(std::size_t row, std::size_t col) { return _data[row][col]; }
  125. row_reference<const T, C> operator[](std::size_t row) const {
  126. return {_data[row]};
  127. }
  128. row_reference<T, C> operator[](std::size_t row) { return {_data[row]}; }
  129. row_reference<const T, C> at(std::size_t row) const {
  130. expects(row < R, std::out_of_range, "row index out of range");
  131. return operator[](row);
  132. }
  133. row_reference<T, C> at(std::size_t row) {
  134. expects(row < R, std::out_of_range, "row index out of range");
  135. return operator[](row);
  136. }
  137. value_type const & at(std::size_t row, std::size_t col) const {
  138. expects(row < R && col < C, std::out_of_range,
  139. "coordinates out of range");
  140. return _data[row][col];
  141. }
  142. value_type & at(std::size_t row, std::size_t col) {
  143. expects(row < R && col < C, std::out_of_range,
  144. "coordinates out of range");
  145. return _data[row][col];
  146. }
  147. matrix & operator+=(matrix const & other) {
  148. MATRIX_FOR_EACH(i, j) { _data[i][j] += other(i, j); }
  149. return *this;
  150. }
  151. matrix operator+(matrix const & other) const {
  152. return matrix{*this} += other;
  153. }
  154. matrix & operator-=(matrix const & other) {
  155. MATRIX_FOR_EACH(i, j) { _data[i][j] -= other(i, j); }
  156. return *this;
  157. }
  158. matrix operator-(matrix const & other) const {
  159. return matrix{*this} -= other;
  160. }
  161. matrix operator-() const {
  162. matrix tmp;
  163. MATRIX_FOR_EACH(i, j) { tmp(i, j) = -_data[i][j]; }
  164. return tmp;
  165. }
  166. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  167. vector::vector<T, C> rval;
  168. MATRIX_FOR_EACH(i, j) { rval[i] += _data[i][j] * vec[j]; }
  169. return rval;
  170. }
  171. template <std::size_t C2>
  172. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  173. matrix<T, R, C2> rval;
  174. MATRIX_FOR_EACH(i, j) {
  175. for (size_t k = 0; k < C2; ++k) {
  176. rval(i, k) += _data[i][j] * other(j, k);
  177. }
  178. }
  179. return rval;
  180. }
  181. matrix<T, R, C> & operator*=(T c) {
  182. MATRIX_FOR_EACH(i, j) { _data[i][j] *= c; }
  183. return *this;
  184. }
  185. template <typename M>
  186. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  187. operator*(M c) const {
  188. return matrix<mul_t<M>, R, C>{*this} *= c;
  189. }
  190. template <typename M>
  191. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  192. operator*(M c, matrix const & matr) {
  193. return matrix<mul_t<M>, R, C>{matr} *= c;
  194. }
  195. template <typename M> matrix<div_t<M>, R, C> & operator/=(M c) {
  196. MATRIX_FOR_EACH(i, j) { _data[i][j] /= c; }
  197. return *this;
  198. }
  199. template <typename M> matrix<div_t<M>, R, C> operator/(M c) const {
  200. return matrix<mul_t<M>, R, C>{*this} /= c;
  201. }
  202. bool operator==(matrix const & other) const {
  203. MATRIX_FOR_EACH(i, j) {
  204. if (_data[i][j] != other(i, j)) { return false; }
  205. }
  206. return true;
  207. }
  208. bool operator!=(matrix const & other) const { return !operator==(other); }
  209. private:
  210. value_type _data[R][C] = {value_type()};
  211. };
  212. }}