matrix.hpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
  8. #include "math/vector/vector.hpp"
  9. #include "expect/expect.hpp"
  10. namespace math { namespace matrix {
  template <typename T, std::size_t R, std::size_t C>
  class matrix;

  /// Compile-time predicate: `is_matrix<X>::value` is `true` only when X is
  /// exactly some `matrix<T, R, C>` specialization. Any other type yields
  /// `false`, including cv-qualified matrix types, since only the
  /// unqualified form is specialized below.
  template <typename T>
  struct is_matrix : std::false_type {};

  template <typename T, std::size_t R, std::size_t C>
  struct is_matrix<matrix<T, R, C> > : std::true_type {};
  /// Tag types that select the layout used by `matrix::concat()`.
  ///
  /// Each strategy is a distinct, named, empty struct with a `constexpr`
  /// instance. Namespace-scope `constexpr` objects are const and therefore
  /// have internal linkage. This header can thus be included from several
  /// translation units without defining the same mutable global repeatedly.
  namespace concat_strategy {
    /// Selects side-by-side concatenation (columns of the argument appended).
    struct horizontal_concat_t {};
    /// Selects stacked concatenation (rows of the argument appended).
    struct vertical_concat_t {};
    /// Selects block-diagonal concatenation.
    struct diagonal_concat_t {};

    constexpr horizontal_concat_t horizontal{};
    /// Original (misspelled) name of `horizontal`, kept so existing callers
    /// continue to compile.
    constexpr horizontal_concat_t horizonal{};
    constexpr vertical_concat_t vertical{};
    constexpr diagonal_concat_t diagonal{};
  }
  // MATRIX_DISABLE_IF_MATRIX(_type, t, r, c) expands to the return type
  // `matrix<t, r, c>` and removes the function from overload resolution
  // (SFINAE) when `_type` is a matrix, per the `is_matrix` trait.
  #define MATRIX_DISABLE_IF_MATRIX(_type, t, r, c) \
  typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
  // MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) declares two nested loops with
  // i in [0, i_max) and j in [0, j_max). The bounds are parenthesised so a
  // compound expression such as `flag ? 2 : 0` is compared as a whole.
  // The bounds are re-evaluated on every iteration.
  #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) \
  for (std::size_t i = 0; i < (i_max); ++i) \
  for (std::size_t j = 0; j < (j_max); ++j)
  // MATRIX_FOR_EACH(i, j) visits every index pair of an R x C matrix; the
  // names `R` and `C` must be in scope where it is used.
  #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  29. template <typename T, std::size_t R, std::size_t C>
  30. class matrix {
  31. public:
  32. using value_type = T;
  33. template <typename M>
  34. using mul_t = decltype(std::declval<T>()*std::declval<M>());
  35. template <typename M>
  36. using div_t = decltype(std::declval<T>()/std::declval<M>());
  37. public:
  38. template <typename S>
  39. class row_reference {
  40. public:
  41. row_reference(S ( & h )[C]) : _handle(h) {}
  42. row_reference(row_reference const &) = delete;
  43. row_reference & operator=(row_reference const & other) {
  44. return operator=<S>(other);
  45. }
  46. template <typename S2>
  47. row_reference & operator=(row_reference<S2> const & other) {
  48. VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
  49. return *this;
  50. }
  51. S const & operator[](std::size_t col) const { return _handle[col]; }
  52. S & operator[](std::size_t col) { return _handle[col]; }
  53. S const & at(std::size_t col) const {
  54. expects(col < C, std::out_of_range, "column index out of range");
  55. return operator[](col);
  56. }
  57. S & at(std::size_t col) {
  58. expects(col < C, std::out_of_range, "column index out of range");
  59. return operator[](col);
  60. }
  61. private:
  62. S ( & _handle )[C];
  63. };
  64. matrix() = default;
  65. matrix(std::array<std::array<T, C>, R> const & init) {
  66. MATRIX_FOR_EACH(i, j) {
  67. _data[i][j] = init[i][j];
  68. }
  69. }
  70. matrix(std::initializer_list<std::array<T, C>> const & init) {
  71. expects(init.size() == R, "initializer size mismatch");
  72. size_t i = 0;
  73. for (auto it = init.begin(), end = init.end(); it != end && i < R; ++it, ++i) {
  74. for (size_t j = 0; j < C; ++j) {
  75. _data[i][j] = (*it)[j];
  76. }
  77. }
  78. }
  79. template <size_t N>
  80. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type, N> const & other) {
  81. VECTOR_FOR_EACH(i) {
  82. _data[i][0] = other[i];
  83. }
  84. }
  85. matrix(matrix const& other) {
  86. *this = other;
  87. }
  88. matrix(matrix && other) {
  89. *this = std::move(other);
  90. }
  91. matrix & operator=(matrix const& other) {
  92. MATRIX_FOR_EACH(i, j) {
  93. _data[i][j] = other._data[i][j];
  94. }
  95. return *this;
  96. }
  97. matrix & operator=(matrix && other) {
  98. MATRIX_FOR_EACH(i, j) {
  99. _data[i][j] = std::move(other._data[i][j]);
  100. }
  101. return *this;
  102. }
  103. template <size_t R2, size_t C2>
  104. matrix(matrix<T, R2, C2> const & other) {
  105. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  106. _data[i][j] = other(i, j);
  107. }
  108. }
  109. matrix<T, C, R> transpose() const {
  110. matrix<T, C, R> out;
  111. MATRIX_FOR_EACH(i, j) { out(j,i) = _data[i][j]; }
  112. return out;
  113. }
  114. template <size_t C2>
  115. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other, concat_strategy::horizontal_concat_t) const {
  116. matrix<T, R, C + C2> accum{*this};
  117. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  118. return accum;
  119. }
  120. template <size_t R2>
  121. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other, concat_strategy::vertical_concat_t) const {
  122. matrix<T, R + R2, C> accum{*this};
  123. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  124. return accum;
  125. }
  126. template <size_t R2, size_t C2>
  127. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other, concat_strategy::diagonal_concat_t) const {
  128. matrix<T, R + R2, C + C2> accum{*this};
  129. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  130. return accum;
  131. }
  132. T const & operator()(std::size_t row, std::size_t col) const {
  133. return _data[row][col];
  134. }
  135. T & operator()(std::size_t row, std::size_t col) {
  136. return _data[row][col];
  137. }
  138. row_reference<const T> operator[](std::size_t row) const {
  139. return { _data[row] };
  140. }
  141. row_reference<T> operator[](std::size_t row) {
  142. return { _data[row] };
  143. }
  144. row_reference<const T> at(std::size_t row) const {
  145. expects(row >= R, std::out_of_range, "row index out of range");
  146. return operator[](row);
  147. }
  148. row_reference<T> at(std::size_t row) {
  149. expects(row >= R, std::out_of_range, "row index out of range");
  150. return operator[](row);
  151. }
  152. value_type const & at(std::size_t row, std::size_t col) const {
  153. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  154. return _data[row][col];
  155. }
  156. value_type & at(std::size_t row, std::size_t col) {
  157. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  158. return _data[row][col];
  159. }
  160. matrix& operator+=(matrix const & other) {
  161. MATRIX_FOR_EACH(i, j) {
  162. _data[i][j] += other[i][j];
  163. }
  164. return *this;
  165. }
  166. matrix operator+(matrix const & other) const {
  167. return matrix{*this} += other;
  168. }
  169. matrix& operator-=(matrix const & other) {
  170. MATRIX_FOR_EACH(i, j) {
  171. _data[i][j] -= other[i][j];
  172. }
  173. return *this;
  174. }
  175. matrix operator-(matrix const & other) const {
  176. return matrix{*this} -= other;
  177. }
  178. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  179. vector::vector<T, C> rval;
  180. MATRIX_FOR_EACH(i, j) {
  181. rval[i] += _data[i][j] * vec[j];
  182. }
  183. return rval;
  184. }
  185. template <std::size_t C2>
  186. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  187. matrix<T, R, C2> rval;
  188. MATRIX_FOR_EACH(i, j) {
  189. for (size_t k = 0; k < C2; ++k) {
  190. rval[i][k] += _data[i][j] * other[j][k];
  191. }
  192. }
  193. return rval;
  194. }
  195. matrix<T, R, C>& operator*=(T c) {
  196. MATRIX_FOR_EACH(i, j) {
  197. _data[i][j] *= c;
  198. }
  199. return *this;
  200. }
  201. template <typename M>
  202. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c) const {
  203. return matrix<mul_t<M>, R, C>{*this} *= c;
  204. }
  205. template <typename M>
  206. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c, matrix const& matr) {
  207. return matrix<mul_t<M>, R, C>{matr} *= c;
  208. }
  209. template <typename M>
  210. matrix<div_t<M>, R, C>& operator/=(M c) {
  211. MATRIX_FOR_EACH(i, j) {
  212. _data[i][j] /= c;
  213. }
  214. return *this;
  215. }
  216. template <typename M>
  217. matrix<div_t<M>, R, C> operator/(M c) const {
  218. return matrix<mul_t<M>, R, C>{*this} /= c;
  219. }
  220. bool operator==(matrix const & other) const {
  221. MATRIX_FOR_EACH(i, j) {
  222. if (_data[i][j] != other._data[i][j]) {
  223. return false;
  224. }
  225. }
  226. return true;
  227. }
  228. bool operator!=(matrix const & other) const {
  229. return !operator==(other);
  230. }
  231. private:
  232. value_type _data[R][C] = {value_type()};
  233. };
  234. } }