matrix.hpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
#pragma once
#include <algorithm>
#include <array>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include "vector/vector.hpp"
#include "expect/expect.hpp"
  10. namespace math { namespace matrix {
  11. template <typename T>
  12. struct is_matrix { static const constexpr bool value = false; };
  13. template <typename T, std::size_t R, std::size_t C>
  14. class matrix;
  15. template <typename T, size_t R, size_t C>
  16. struct is_matrix<matrix<T, R, C> > { static const constexpr bool value = true; };
  17. namespace concat_strategy {
  18. struct {} horizonal;
  19. using horizontal_concat_t = decltype(horizonal);
  20. struct {} vertical;
  21. using vertical_concat_t = decltype(vertical);
  22. struct {} diagonal;
  23. using diagonal_concat_t = decltype(diagonal);
  24. };
  25. #define MATRIX_DISABLE_IF_MATRIX(_type, t, r, c) \
  26. typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
  27. #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) for (size_t i = 0; i < i_max; ++i) for (size_t j = 0; j < j_max; ++j)
  28. #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  29. template <typename T, std::size_t C>
  30. class row_reference {
  31. private:
  32. row_reference(T ( & h )[C]) : _handle(h) {}
  33. public:
  34. row_reference(row_reference const &) = delete;
  35. row_reference(row_reference &&) = default;
  36. template <typename S>
  37. row_reference & operator=(row_reference<S, C> const & other) {
  38. VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
  39. return *this;
  40. }
  41. T const & operator[](std::size_t col) const { return _handle[col]; }
  42. T & operator[](std::size_t col) { return _handle[col]; }
  43. T const & at(std::size_t col) const {
  44. expects(col < C, std::out_of_range, "column index out of range");
  45. return operator[](col);
  46. }
  47. T & at(std::size_t col) {
  48. expects(col < C, std::out_of_range, "column index out of range");
  49. return operator[](col);
  50. }
  51. private:
  52. template <typename _T, std::size_t R, std::size_t _C> friend class matrix;
  53. T ( & _handle )[C];
  54. };
  55. template <typename T, std::size_t R, std::size_t C>
  56. class matrix {
  57. public:
  58. using value_type = T;
  59. template <typename M>
  60. using mul_t = decltype(std::declval<T>()*std::declval<M>());
  61. template <typename M>
  62. using div_t = decltype(std::declval<T>()/std::declval<M>());
  63. public:
  64. matrix() = default;
  65. matrix(std::array<std::array<T, C>, R> const & init) {
  66. MATRIX_FOR_EACH(i, j) {
  67. _data[i][j] = init[i][j];
  68. }
  69. }
  70. template <size_t N>
  71. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type, N> const & other) {
  72. VECTOR_FOR_EACH(i) {
  73. _data[i][0] = other[i];
  74. }
  75. }
  76. matrix(matrix const& other) {
  77. *this = other;
  78. }
  79. matrix(matrix && other) {
  80. *this = std::move(other);
  81. }
  82. matrix & operator=(matrix const& other) {
  83. MATRIX_FOR_EACH(i, j) {
  84. _data[i][j] = other(i, j);
  85. }
  86. return *this;
  87. }
  88. matrix & operator=(matrix && other) {
  89. MATRIX_FOR_EACH(i, j) {
  90. _data[i][j] = std::move(other(i, j));
  91. }
  92. return *this;
  93. }
  94. template <size_t R2, size_t C2>
  95. matrix(matrix<T, R2, C2> const & other) {
  96. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  97. _data[i][j] = other(i, j);
  98. }
  99. }
  100. matrix<T, C, R> transpose() const {
  101. matrix<T, C, R> out;
  102. MATRIX_FOR_EACH(i, j) { out(j,i) = _data[i][j]; }
  103. return out;
  104. }
  105. template <size_t C2>
  106. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other, concat_strategy::horizontal_concat_t) const {
  107. matrix<T, R, C + C2> accum{*this};
  108. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  109. return accum;
  110. }
  111. template <size_t R2>
  112. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other, concat_strategy::vertical_concat_t) const {
  113. matrix<T, R + R2, C> accum{*this};
  114. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  115. return accum;
  116. }
  117. template <size_t R2, size_t C2>
  118. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other, concat_strategy::diagonal_concat_t) const {
  119. matrix<T, R + R2, C + C2> accum{*this};
  120. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  121. return accum;
  122. }
  123. T const & operator()(std::size_t row, std::size_t col) const {
  124. return _data[row][col];
  125. }
  126. T & operator()(std::size_t row, std::size_t col) {
  127. return _data[row][col];
  128. }
  129. row_reference<const T, C> operator[](std::size_t row) const {
  130. return { _data[row] };
  131. }
  132. row_reference<T, C> operator[](std::size_t row) {
  133. return { _data[row] };
  134. }
  135. row_reference<const T, C> at(std::size_t row) const {
  136. expects(row < R, std::out_of_range, "row index out of range");
  137. return operator[](row);
  138. }
  139. row_reference<T, C> at(std::size_t row) {
  140. expects(row < R, std::out_of_range, "row index out of range");
  141. return operator[](row);
  142. }
  143. value_type const & at(std::size_t row, std::size_t col) const {
  144. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  145. return _data[row][col];
  146. }
  147. value_type & at(std::size_t row, std::size_t col) {
  148. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  149. return _data[row][col];
  150. }
  151. matrix& operator+=(matrix const & other) {
  152. MATRIX_FOR_EACH(i, j) {
  153. _data[i][j] += other(i, j);
  154. }
  155. return *this;
  156. }
  157. matrix operator+(matrix const & other) const {
  158. return matrix{*this} += other;
  159. }
  160. matrix& operator-=(matrix const & other) {
  161. MATRIX_FOR_EACH(i, j) {
  162. _data[i][j] -= other(i, j);
  163. }
  164. return *this;
  165. }
  166. matrix operator-(matrix const & other) const {
  167. return matrix{*this} -= other;
  168. }
  169. matrix operator-() const {
  170. matrix tmp;
  171. MATRIX_FOR_EACH(i, j) {
  172. tmp(i, j) = -_data[i][j];
  173. }
  174. return tmp;
  175. }
  176. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  177. vector::vector<T, C> rval;
  178. MATRIX_FOR_EACH(i, j) {
  179. rval[i] += _data[i][j] * vec[j];
  180. }
  181. return rval;
  182. }
  183. template <std::size_t C2>
  184. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  185. matrix<T, R, C2> rval;
  186. MATRIX_FOR_EACH(i, j) {
  187. for (size_t k = 0; k < C2; ++k) {
  188. rval(i, k) += _data[i][j] * other(j, k);
  189. }
  190. }
  191. return rval;
  192. }
  193. matrix<T, R, C>& operator*=(T c) {
  194. MATRIX_FOR_EACH(i, j) {
  195. _data[i][j] *= c;
  196. }
  197. return *this;
  198. }
  199. template <typename M>
  200. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c) const {
  201. return matrix<mul_t<M>, R, C>{*this} *= c;
  202. }
  203. template <typename M>
  204. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C) operator*(M c, matrix const& matr) {
  205. return matrix<mul_t<M>, R, C>{matr} *= c;
  206. }
  207. template <typename M>
  208. matrix<div_t<M>, R, C>& operator/=(M c) {
  209. MATRIX_FOR_EACH(i, j) {
  210. _data[i][j] /= c;
  211. }
  212. return *this;
  213. }
  214. template <typename M>
  215. matrix<div_t<M>, R, C> operator/(M c) const {
  216. return matrix<mul_t<M>, R, C>{*this} /= c;
  217. }
  218. bool operator==(matrix const & other) const {
  219. MATRIX_FOR_EACH(i, j) {
  220. if (_data[i][j] != other(i, j)) {
  221. return false;
  222. }
  223. }
  224. return true;
  225. }
  226. bool operator!=(matrix const & other) const {
  227. return !operator==(other);
  228. }
  229. private:
  230. value_type _data[R][C] = {value_type()};
  231. };
  232. } }