matrix.hpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
  8. #include "vector/vector.hpp"
  9. #include "expect/expect.hpp"
  10. namespace math { namespace matrix {
  /// Type trait telling whether `T` is matrix-like.
  ///
  /// The primary template answers "no"; the specializations for `matrix` and
  /// `vector::vector` set `value` to true.
  template <typename T>
  struct is_matrix {
    static constexpr bool value = false;
  };
  13. template <typename T, std::size_t R, std::size_t C>
  14. class matrix;
  15. template <typename T, size_t R, size_t C>
  16. struct is_matrix<matrix<T, R, C> > { static const constexpr bool value = true; };
  17. template <typename T, size_t N>
  18. struct is_matrix<vector::vector<T, N>> { static const constexpr bool value = true; };
  /// Tags selecting the layout used by `matrix::concat`.
  ///
  /// Each strategy is a distinct named empty type plus a `constexpr` tag
  /// object. `const` namespace-scope objects have internal linkage, so
  /// defining them in a header is safe in every translation unit that
  /// includes it. This replaces mutable, non-inline variables of unnamed type.
  namespace concat_strategy {
    struct horizontal_concat_t {};
    struct vertical_concat_t {};
    struct diagonal_concat_t {};

    /// The second matrix is placed to the right of the first.
    constexpr horizontal_concat_t horizontal{};
    /// Original (misspelled) name of `horizontal`; kept so existing callers
    /// keep compiling.
    constexpr horizontal_concat_t horizonal{};
    /// The second matrix is placed below the first.
    constexpr vertical_concat_t vertical{};
    /// The second matrix is placed below and to the right of the first; the
    /// two off-diagonal blocks stay value-initialized.
    constexpr diagonal_concat_t diagonal{};
  }
  // Return type for the scalar operator* overloads: `matrix<t, r, c>` when
  // `scalar_t` is not matrix-like (see `is_matrix`); otherwise a substitution
  // failure that removes the overload from consideration.
  #define MATRIX_DISABLE_IF_MATRIX(scalar_t, t, r, c) \
    typename std::enable_if<!is_matrix<scalar_t>::value, matrix<t, r, c> >::type

  // Nested loops: `i` over [0, i_max), then `j` over [0, j_max).
  // The bounds are parenthesized so that an expression argument (for example
  // `a & b` or `x ? y : z`) is evaluated as a whole before the comparison.
  #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) \
    for (std::size_t i = 0; i < (i_max); ++i) \
      for (std::size_t j = 0; j < (j_max); ++j)

  // Visits every (row, column) index pair; expects `R` and `C` in scope.
  #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  31. template <typename T, std::size_t C>
  32. class row_reference {
  33. private:
  34. row_reference(T ( & h )[C]) : _handle(h) {}
  35. public:
  36. row_reference(row_reference const &) = delete;
  37. row_reference(row_reference &&) = default;
  38. template <typename S>
  39. row_reference & operator=(row_reference<S, C> const & other) {
  40. VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
  41. return *this;
  42. }
  43. T const & operator[](std::size_t col) const { return _handle[col]; }
  44. T & operator[](std::size_t col) { return _handle[col]; }
  45. T const & at(std::size_t col) const {
  46. expects(col < C, std::out_of_range, "column index out of range");
  47. return operator[](col);
  48. }
  49. T & at(std::size_t col) {
  50. expects(col < C, std::out_of_range, "column index out of range");
  51. return operator[](col);
  52. }
  53. private:
  54. template <typename _T, std::size_t R, std::size_t _C> friend class matrix;
  55. T ( & _handle )[C];
  56. };
  57. template <typename T, std::size_t R, std::size_t C>
  58. class matrix {
  59. public:
  60. using value_type = T;
  61. template <typename M>
  62. using mul_t = decltype(std::declval<T>()*std::declval<M>());
  63. template <typename M>
  64. using div_t = decltype(std::declval<T>()/std::declval<M>());
  65. public:
  66. matrix() = default;
  67. matrix(std::array<std::array<T, C>, R> const & init) {
  68. MATRIX_FOR_EACH(i, j) {
  69. _data[i][j] = init[i][j];
  70. }
  71. }
  72. template <size_t N>
  73. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type, N> const & other) {
  74. VECTOR_FOR_EACH(i) {
  75. _data[i][0] = other[i];
  76. }
  77. }
  78. matrix(matrix const& other) {
  79. *this = other;
  80. }
  81. matrix(matrix && other) {
  82. *this = std::move(other);
  83. }
  84. matrix & operator=(matrix const& other) {
  85. MATRIX_FOR_EACH(i, j) {
  86. _data[i][j] = other(i, j);
  87. }
  88. return *this;
  89. }
  90. matrix & operator=(matrix && other) {
  91. MATRIX_FOR_EACH(i, j) {
  92. _data[i][j] = std::move(other(i, j));
  93. }
  94. return *this;
  95. }
  96. template <size_t R2, size_t C2>
  97. matrix(matrix<T, R2, C2> const & other) {
  98. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  99. _data[i][j] = other(i, j);
  100. }
  101. }
  102. matrix<T, C, R> transpose() const {
  103. matrix<T, C, R> out;
  104. MATRIX_FOR_EACH(i, j) { out(j,i) = _data[i][j]; }
  105. return out;
  106. }
  107. template <size_t C2>
  108. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other, concat_strategy::horizontal_concat_t) const {
  109. matrix<T, R, C + C2> accum{*this};
  110. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  111. return accum;
  112. }
  113. template <size_t R2>
  114. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other, concat_strategy::vertical_concat_t) const {
  115. matrix<T, R + R2, C> accum{*this};
  116. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  117. return accum;
  118. }
  119. template <size_t R2, size_t C2>
  120. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other, concat_strategy::diagonal_concat_t) const {
  121. matrix<T, R + R2, C + C2> accum{*this};
  122. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  123. return accum;
  124. }
  125. T const & operator()(std::size_t row, std::size_t col) const {
  126. return _data[row][col];
  127. }
  128. T & operator()(std::size_t row, std::size_t col) {
  129. return _data[row][col];
  130. }
  131. row_reference<const T, C> operator[](std::size_t row) const {
  132. return { _data[row] };
  133. }
  134. row_reference<T, C> operator[](std::size_t row) {
  135. return { _data[row] };
  136. }
  137. row_reference<const T, C> at(std::size_t row) const {
  138. expects(row < R, std::out_of_range, "row index out of range");
  139. return operator[](row);
  140. }
  141. row_reference<T, C> at(std::size_t row) {
  142. expects(row < R, std::out_of_range, "row index out of range");
  143. return operator[](row);
  144. }
  145. value_type const & at(std::size_t row, std::size_t col) const {
  146. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  147. return _data[row][col];
  148. }
  149. value_type & at(std::size_t row, std::size_t col) {
  150. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  151. return _data[row][col];
  152. }
  153. matrix& operator+=(matrix const & other) {
  154. MATRIX_FOR_EACH(i, j) {
  155. _data[i][j] += other(i, j);
  156. }
  157. return *this;
  158. }
  159. matrix operator+(matrix const & other) const {
  160. return matrix{*this} += other;
  161. }
  162. matrix& operator-=(matrix const & other) {
  163. MATRIX_FOR_EACH(i, j) {
  164. _data[i][j] -= other(i, j);
  165. }
  166. return *this;
  167. }
  168. matrix operator-(matrix const & other) const {
  169. return matrix{*this} -= other;
  170. }
  171. matrix operator-() const {
  172. matrix tmp;
  173. MATRIX_FOR_EACH(i, j) {
  174. tmp(i, j) = -_data[i][j];
  175. }
  176. return tmp;
  177. }
  178. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  179. vector::vector<T, C> rval;
  180. MATRIX_FOR_EACH(i, j) {
  181. rval[i] += _data[i][j] * vec[j];
  182. }
  183. return rval;
  184. }
  185. template <std::size_t C2>
  186. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  187. matrix<T, R, C2> rval;
  188. MATRIX_FOR_EACH(i, j) {
  189. for (size_t k = 0; k < C2; ++k) {
  190. rval(i, k) += _data[i][j] * other(j, k);
  191. }
  192. }
  193. return rval;
  194. }
  195. matrix<T, R, C>& operator*=(T c) {
  196. MATRIX_FOR_EACH(i, j) {
  197. _data[i][j] *= c;
  198. }
  199. return *this;
  200. }
  201. template <typename M>
  202. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c) const {
  203. return matrix<mul_t<M>, R, C>{*this} *= c;
  204. }
  205. template <typename M>
  206. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c, matrix const& matr) {
  207. return matrix<mul_t<M>, R, C>{matr} *= c;
  208. }
  209. template <typename M>
  210. matrix<div_t<M>, R, C>& operator/=(M c) {
  211. MATRIX_FOR_EACH(i, j) {
  212. _data[i][j] /= c;
  213. }
  214. return *this;
  215. }
  216. template <typename M>
  217. matrix<div_t<M>, R, C> operator/(M c) const {
  218. return matrix<mul_t<M>, R, C>{*this} /= c;
  219. }
  220. bool operator==(matrix const & other) const {
  221. MATRIX_FOR_EACH(i, j) {
  222. if (_data[i][j] != other(i, j)) {
  223. return false;
  224. }
  225. }
  226. return true;
  227. }
  228. bool operator!=(matrix const & other) const {
  229. return !operator==(other);
  230. }
  231. private:
  232. value_type _data[R][C] = {value_type()};
  233. };
  234. } }