matrix.hpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
#include <algorithm>
#include <array>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include "vector/vector.hpp"
#include "expect/expect.hpp"
  10. namespace math { namespace matrix {
  // Type trait: is_matrix<X>::value is true exactly when X is a
  // matrix<T, R, C> from this namespace. cv-qualifiers and references are not
  // stripped, so e.g. is_matrix<matrix<T, R, C> const>::value is false.
  // Deriving from std::false_type / std::true_type (instead of declaring our
  // own static constexpr member) also provides ::type, value_type and the
  // constexpr bool conversion, and leaves the definition of `value` to the
  // standard library.
  template <typename T>
  struct is_matrix : std::false_type {};

  // Forward declaration so the trait can be specialized before the class
  // template itself is defined.
  template <typename T, std::size_t R, std::size_t C>
  class matrix;

  template <typename T, std::size_t R, std::size_t C>
  struct is_matrix<matrix<T, R, C>> : std::true_type {};
  // Tag types selecting the layout used by matrix::concat:
  //   horizontal - the other matrix's columns are appended to the right,
  //   vertical   - the other matrix's rows are appended below,
  //   diagonal   - the other matrix is placed below-right (block-diagonal).
  // The tags are named types whose instances are constexpr (and therefore
  // have internal linkage). The header can then be included from several
  // translation units without multiple-definition link errors, and every TU
  // sees the same tag types.
  namespace concat_strategy {
    struct horizontal_concat_t {};
    struct vertical_concat_t {};
    struct diagonal_concat_t {};

    constexpr horizontal_concat_t horizontal{};
    // Historical misspelling of `horizontal`, kept so existing callers still
    // compile.
    constexpr horizontal_concat_t horizonal{};
    constexpr vertical_concat_t vertical{};
    constexpr diagonal_concat_t diagonal{};
  }
  // SFINAE helper: expands to matrix<t, r, c> when type_ is NOT a matrix and
  // removes the overload otherwise. This keeps the scalar operator*
  // overloads from competing with matrix * matrix.
  #define MATRIX_DISABLE_IF_MATRIX(type_, t, r, c) \
    typename std::enable_if<!is_matrix<type_>::value, matrix<t, r, c> >::type
  // Nested index loops: i over [0, i_max), j over [0, j_max), in row-major
  // order. The bounds are parenthesized so that expressions such as
  // `cond ? a : b` are compared as a whole.
  #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) \
    for (std::size_t i = 0; i < (i_max); ++i) \
      for (std::size_t j = 0; j < (j_max); ++j)
  // Visits every (row, col) of an R x C matrix; R and C must be in scope.
  #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  29. template <typename T, std::size_t R, std::size_t C>
  30. class matrix {
  31. public:
  32. using value_type = T;
  33. template <typename M>
  34. using mul_t = decltype(std::declval<T>()*std::declval<M>());
  35. template <typename M>
  36. using div_t = decltype(std::declval<T>()/std::declval<M>());
  37. public:
  38. template <typename S>
  39. class row_reference {
  40. public:
  41. row_reference(S ( & h )[C]) : _handle(h) {}
  42. row_reference(row_reference const &) = delete;
  43. row_reference(row_reference &&) = default;
  44. row_reference & operator=(row_reference const & other) {
  45. return operator=<S>(other);
  46. }
  47. template <typename S2>
  48. row_reference & operator=(row_reference<S2> const & other) {
  49. VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
  50. return *this;
  51. }
  52. S const & operator[](std::size_t col) const { return _handle[col]; }
  53. S & operator[](std::size_t col) { return _handle[col]; }
  54. S const & at(std::size_t col) const {
  55. expects(col < C, std::out_of_range, "column index out of range");
  56. return operator[](col);
  57. }
  58. S & at(std::size_t col) {
  59. expects(col < C, std::out_of_range, "column index out of range");
  60. return operator[](col);
  61. }
  62. private:
  63. S ( & _handle )[C];
  64. };
  65. matrix() = default;
  66. matrix(std::array<std::array<T, C>, R> const & init) {
  67. MATRIX_FOR_EACH(i, j) {
  68. _data[i][j] = init[i][j];
  69. }
  70. }
  71. template <size_t N>
  72. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type, N> const & other) {
  73. VECTOR_FOR_EACH(i) {
  74. _data[i][0] = other[i];
  75. }
  76. }
  77. matrix(matrix const& other) {
  78. *this = other;
  79. }
  80. matrix(matrix && other) {
  81. *this = std::move(other);
  82. }
  83. matrix & operator=(matrix const& other) {
  84. MATRIX_FOR_EACH(i, j) {
  85. _data[i][j] = other._data[i][j];
  86. }
  87. return *this;
  88. }
  89. matrix & operator=(matrix && other) {
  90. MATRIX_FOR_EACH(i, j) {
  91. _data[i][j] = std::move(other._data[i][j]);
  92. }
  93. return *this;
  94. }
  95. template <size_t R2, size_t C2>
  96. matrix(matrix<T, R2, C2> const & other) {
  97. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  98. _data[i][j] = other(i, j);
  99. }
  100. }
  101. matrix<T, C, R> transpose() const {
  102. matrix<T, C, R> out;
  103. MATRIX_FOR_EACH(i, j) { out(j,i) = _data[i][j]; }
  104. return out;
  105. }
  106. template <size_t C2>
  107. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other, concat_strategy::horizontal_concat_t) const {
  108. matrix<T, R, C + C2> accum{*this};
  109. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  110. return accum;
  111. }
  112. template <size_t R2>
  113. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other, concat_strategy::vertical_concat_t) const {
  114. matrix<T, R + R2, C> accum{*this};
  115. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  116. return accum;
  117. }
  118. template <size_t R2, size_t C2>
  119. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other, concat_strategy::diagonal_concat_t) const {
  120. matrix<T, R + R2, C + C2> accum{*this};
  121. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  122. return accum;
  123. }
  124. T const & operator()(std::size_t row, std::size_t col) const {
  125. return _data[row][col];
  126. }
  127. T & operator()(std::size_t row, std::size_t col) {
  128. return _data[row][col];
  129. }
  130. row_reference<const T> operator[](std::size_t row) const {
  131. return { _data[row] };
  132. }
  133. row_reference<T> operator[](std::size_t row) {
  134. return { _data[row] };
  135. }
  136. row_reference<const T> at(std::size_t row) const {
  137. expects(row >= R, std::out_of_range, "row index out of range");
  138. return operator[](row);
  139. }
  140. row_reference<T> at(std::size_t row) {
  141. expects(row >= R, std::out_of_range, "row index out of range");
  142. return operator[](row);
  143. }
  144. value_type const & at(std::size_t row, std::size_t col) const {
  145. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  146. return _data[row][col];
  147. }
  148. value_type & at(std::size_t row, std::size_t col) {
  149. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  150. return _data[row][col];
  151. }
  152. matrix& operator+=(matrix const & other) {
  153. MATRIX_FOR_EACH(i, j) {
  154. _data[i][j] += other[i][j];
  155. }
  156. return *this;
  157. }
  158. matrix operator+(matrix const & other) const {
  159. return matrix{*this} += other;
  160. }
  161. matrix& operator-=(matrix const & other) {
  162. MATRIX_FOR_EACH(i, j) {
  163. _data[i][j] -= other[i][j];
  164. }
  165. return *this;
  166. }
  167. matrix operator-(matrix const & other) const {
  168. return matrix{*this} -= other;
  169. }
  170. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  171. vector::vector<T, C> rval;
  172. MATRIX_FOR_EACH(i, j) {
  173. rval[i] += _data[i][j] * vec[j];
  174. }
  175. return rval;
  176. }
  177. template <std::size_t C2>
  178. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  179. matrix<T, R, C2> rval;
  180. MATRIX_FOR_EACH(i, j) {
  181. for (size_t k = 0; k < C2; ++k) {
  182. rval[i][k] += _data[i][j] * other[j][k];
  183. }
  184. }
  185. return rval;
  186. }
  187. matrix<T, R, C>& operator*=(T c) {
  188. MATRIX_FOR_EACH(i, j) {
  189. _data[i][j] *= c;
  190. }
  191. return *this;
  192. }
  193. template <typename M>
  194. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c) const {
  195. return matrix<mul_t<M>, R, C>{*this} *= c;
  196. }
  197. template <typename M>
  198. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c, matrix const& matr) {
  199. return matrix<mul_t<M>, R, C>{matr} *= c;
  200. }
  201. template <typename M>
  202. matrix<div_t<M>, R, C>& operator/=(M c) {
  203. MATRIX_FOR_EACH(i, j) {
  204. _data[i][j] /= c;
  205. }
  206. return *this;
  207. }
  208. template <typename M>
  209. matrix<div_t<M>, R, C> operator/(M c) const {
  210. return matrix<mul_t<M>, R, C>{*this} /= c;
  211. }
  212. bool operator==(matrix const & other) const {
  213. MATRIX_FOR_EACH(i, j) {
  214. if (_data[i][j] != other._data[i][j]) {
  215. return false;
  216. }
  217. }
  218. return true;
  219. }
  220. bool operator!=(matrix const & other) const {
  221. return !operator==(other);
  222. }
  223. private:
  224. value_type _data[R][C] = {value_type()};
  225. };
  226. } }