matrix.hpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
  8. #include <expect/expect.hpp>
  9. #include <math/vector/vector.hpp>
  10. #include "math/matrix/forward.h"
  11. #include "math/matrix/row_reference.hpp"
  12. #include "math/matrix/traits.hpp"
  // MATRIX_DISABLE_IF_MATRIX(type_, t, r, c)
  //   Expands to the return type matrix<t, r, c>, but only when `type_` is not
  //   itself a matrix (as reported by is_matrix from traits.hpp). Otherwise the
  //   declaration is removed from overload resolution via std::enable_if.
  //   The non-type arguments are parenthesized so that expressions containing
  //   '>' or low-precedence operators stay intact.
  #define MATRIX_DISABLE_IF_MATRIX(type_, t, r, c) \
    typename std::enable_if<!is_matrix<type_>::value, matrix<t, (r), (c)>>::type
  // MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max)
  //   Nested loops with i in [0, i_max) outer and j in [0, j_max) inner.
  //   The bounds are parenthesized: without that, an argument such as `a & b`
  //   or `x ? y : z` would bind to the `<` comparison instead of acting as the
  //   bound. Bounds are re-evaluated on every iteration, as in a plain for loop.
  #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) \
    for (size_t i = 0; i < (i_max); ++i)            \
      for (size_t j = 0; j < (j_max); ++j)
  // MATRIX_FOR_EACH(i, j)
  //   Iterates every (row, column) pair of an R x C matrix. Requires R and C to
  //   be in scope, e.g. inside matrix<T, R, C>.
  #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
  19. namespace math { namespace matrix {
  20. template <typename T, std::size_t R, std::size_t C> class matrix {
  21. public:
  22. using value_type = T;
  23. template <typename M>
  24. using mul_t = decltype(std::declval<T>() * std::declval<M>());
  25. template <typename M>
  26. using div_t = decltype(std::declval<T>() / std::declval<M>());
  27. public:
  28. matrix() = default;
  29. matrix(std::array<std::array<T, C>, R> const & init) {
  30. MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; }
  31. }
  32. template <size_t N>
  33. matrix(vector::vector<typename std::enable_if<C == 1 && N == R, T>::type,
  34. N> const & other) {
  35. VECTOR_FOR_EACH(i) { _data[i][0] = other[i]; }
  36. }
  37. matrix(matrix const & other) { *this = other; }
  38. matrix(matrix && other) { *this = std::move(other); }
  39. matrix & operator=(matrix const & other) {
  40. MATRIX_FOR_EACH(i, j) { _data[i][j] = other(i, j); }
  41. return *this;
  42. }
  43. matrix & operator=(matrix && other) {
  44. MATRIX_FOR_EACH(i, j) { _data[i][j] = std::move(other(i, j)); }
  45. return *this;
  46. }
  47. template <size_t R2, size_t C2> matrix(matrix<T, R2, C2> const & other) {
  48. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  49. _data[i][j] = other(i, j);
  50. }
  51. }
  52. matrix<T, C, R> transpose() const {
  53. matrix<T, C, R> out;
  54. MATRIX_FOR_EACH(i, j) { out(j, i) = _data[i][j]; }
  55. return out;
  56. }
  57. template <size_t C2>
  58. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other,
  59. concat_strategy::horizontal_concat_t) const {
  60. matrix<T, R, C + C2> accum{*this};
  61. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  62. return accum;
  63. }
  64. template <size_t R2>
  65. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other,
  66. concat_strategy::vertical_concat_t) const {
  67. matrix<T, R + R2, C> accum{*this};
  68. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  69. return accum;
  70. }
  71. template <size_t R2, size_t C2>
  72. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other,
  73. concat_strategy::diagonal_concat_t) const {
  74. matrix<T, R + R2, C + C2> accum{*this};
  75. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  76. return accum;
  77. }
  78. T const & operator()(std::size_t row, std::size_t col) const {
  79. return _data[row][col];
  80. }
  81. T & operator()(std::size_t row, std::size_t col) { return _data[row][col]; }
  82. row_reference<const T, C> operator[](std::size_t row) const {
  83. return {_data[row]};
  84. }
  85. row_reference<T, C> operator[](std::size_t row) { return {_data[row]}; }
  86. row_reference<const T, C> at(std::size_t row) const {
  87. expects(row < R, std::out_of_range, "row index out of range");
  88. return operator[](row);
  89. }
  90. row_reference<T, C> at(std::size_t row) {
  91. expects(row < R, std::out_of_range, "row index out of range");
  92. return operator[](row);
  93. }
  94. value_type const & at(std::size_t row, std::size_t col) const {
  95. expects(row < R && col < C, std::out_of_range,
  96. "coordinates out of range");
  97. return _data[row][col];
  98. }
  99. value_type & at(std::size_t row, std::size_t col) {
  100. expects(row < R && col < C, std::out_of_range,
  101. "coordinates out of range");
  102. return _data[row][col];
  103. }
  104. matrix & operator+=(matrix const & other) {
  105. MATRIX_FOR_EACH(i, j) { _data[i][j] += other(i, j); }
  106. return *this;
  107. }
  108. matrix operator+(matrix const & other) const {
  109. return matrix{*this} += other;
  110. }
  111. matrix & operator-=(matrix const & other) {
  112. MATRIX_FOR_EACH(i, j) { _data[i][j] -= other(i, j); }
  113. return *this;
  114. }
  115. matrix operator-(matrix const & other) const {
  116. return matrix{*this} -= other;
  117. }
  118. matrix operator-() const {
  119. matrix tmp;
  120. MATRIX_FOR_EACH(i, j) { tmp(i, j) = -_data[i][j]; }
  121. return tmp;
  122. }
  123. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  124. vector::vector<T, C> rval;
  125. MATRIX_FOR_EACH(i, j) { rval[i] += _data[i][j] * vec[j]; }
  126. return rval;
  127. }
  128. template <std::size_t C2>
  129. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  130. matrix<T, R, C2> rval;
  131. MATRIX_FOR_EACH(i, j) {
  132. for (size_t k = 0; k < C2; ++k) {
  133. rval(i, k) += _data[i][j] * other(j, k);
  134. }
  135. }
  136. return rval;
  137. }
  138. matrix<T, R, C> & operator*=(T c) {
  139. MATRIX_FOR_EACH(i, j) { _data[i][j] *= c; }
  140. return *this;
  141. }
  142. template <typename M>
  143. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  144. operator*(M c) const {
  145. return matrix<mul_t<M>, R, C>{*this} *= c;
  146. }
  147. template <typename M>
  148. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  149. operator*(M c, matrix const & matr) {
  150. return matrix<mul_t<M>, R, C>{matr} *= c;
  151. }
  152. template <typename M> matrix<div_t<M>, R, C> & operator/=(M c) {
  153. MATRIX_FOR_EACH(i, j) { _data[i][j] /= c; }
  154. return *this;
  155. }
  156. template <typename M> matrix<div_t<M>, R, C> operator/(M c) const {
  157. return matrix<mul_t<M>, R, C>{*this} /= c;
  158. }
  159. bool operator==(matrix const & other) const {
  160. MATRIX_FOR_EACH(i, j) {
  161. if (_data[i][j] != other(i, j)) { return false; }
  162. }
  163. return true;
  164. }
  165. bool operator!=(matrix const & other) const { return !operator==(other); }
  166. private:
  167. value_type _data[R][C] = {value_type()};
  168. };
  169. }}