matrix.hpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. //
  2. // matrix.hpp
  3. // math
  4. //
  5. // Created by Sam Jaffe on 5/28/16.
  6. //
  7. #pragma once
  8. #include <expect/expect.hpp>
  9. #include <math/vector/vector.hpp>
  10. #include "math/matrix/forward.h"
  11. #include "math/matrix/row_reference.hpp"
  12. #include "math/matrix/traits.hpp"
  13. #include "math/matrix/macro.h"
  14. namespace math::matrix {
  15. template <typename T, std::size_t R, std::size_t C> class matrix {
  16. public:
  17. using value_type = T;
  18. template <typename M>
  19. using mul_t = decltype(std::declval<T>() * std::declval<M>());
  20. template <typename M>
  21. using div_t = decltype(std::declval<T>() / std::declval<M>());
  22. public:
  23. matrix() = default;
  24. MATRIX_CTOR_N_ARGS(1) {}
  25. MATRIX_CTOR_N_ARGS(2) {}
  26. MATRIX_CTOR_N_ARGS(3) {}
  27. MATRIX_CTOR_N_ARGS(4) {}
  28. DEFER_RESOLUTION matrix(vector::vector<T, R> const & other,
  29. DEFERRED_ENABLE_IF_T(C == 1, bool) = true) {
  30. VECTOR_FOR_EACH_RANGE(i, R) { _data[i][0] = other[i]; }
  31. }
  32. matrix(std::array<std::array<T, C>, R> const & init) {
  33. MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; }
  34. }
  35. matrix(std::array<vector::vector<T, C>, R> const & init) {
  36. MATRIX_FOR_EACH(i, j) { _data[i][j] = init[i][j]; }
  37. }
  38. template <size_t R2, size_t C2> matrix(matrix<T, R2, C2> const & other) {
  39. MATRIX_FOR_EACH_RANGE(i, std::min(R, R2), j, std::min(C, C2)) {
  40. _data[i][j] = other(i, j);
  41. }
  42. }
  43. matrix<T, C, R> transpose() const {
  44. matrix<T, C, R> out;
  45. MATRIX_FOR_EACH(i, j) { out(j, i) = _data[i][j]; }
  46. return out;
  47. }
  48. template <size_t C2>
  49. matrix<T, R, C + C2> concat(matrix<T, R, C2> const & other,
  50. concat_strategy::horizontal_concat_t) const {
  51. matrix<T, R, C + C2> accum{*this};
  52. MATRIX_FOR_EACH_RANGE(i, R, j, C2) { accum(i, j + C) = other(i, j); }
  53. return accum;
  54. }
  55. template <size_t R2>
  56. matrix<T, R + R2, C> concat(matrix<T, R2, C> const & other,
  57. concat_strategy::vertical_concat_t) const {
  58. matrix<T, R + R2, C> accum{*this};
  59. MATRIX_FOR_EACH_RANGE(i, R2, j, C) { accum(i + R, j) = other(i, j); }
  60. return accum;
  61. }
  62. template <size_t R2, size_t C2>
  63. matrix<T, R + R2, C + C2> concat(matrix<T, R2, C2> const & other,
  64. concat_strategy::diagonal_concat_t) const {
  65. matrix<T, R + R2, C + C2> accum{*this};
  66. MATRIX_FOR_EACH_RANGE(i, R2, j, C2) { accum(i + R, j + C) = other(i, j); }
  67. return accum;
  68. }
  69. // In C++23 - we can use operator[](row, col)
  70. T const & operator()(std::size_t row, std::size_t col) const {
  71. return _data[row][col];
  72. }
  73. T & operator()(std::size_t row, std::size_t col) { return _data[row][col]; }
  74. row_reference<const T, C> operator[](std::size_t row) const {
  75. return {_data[row]};
  76. }
  77. row_reference<T, C> operator[](std::size_t row) { return {_data[row]}; }
  78. row_reference<const T, C> at(std::size_t row) const {
  79. expects(row < R, std::out_of_range, "row index out of range");
  80. return operator[](row);
  81. }
  82. row_reference<T, C> at(std::size_t row) {
  83. expects(row < R, std::out_of_range, "row index out of range");
  84. return operator[](row);
  85. }
  86. value_type const & at(std::size_t row, std::size_t col) const {
  87. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  88. return _data[row][col];
  89. }
  90. value_type & at(std::size_t row, std::size_t col) {
  91. expects(row < R && col < C, std::out_of_range, "coordinates out of range");
  92. return _data[row][col];
  93. }
  94. matrix & operator+=(matrix const & other) {
  95. MATRIX_FOR_EACH(i, j) { _data[i][j] += other(i, j); }
  96. return *this;
  97. }
  98. matrix operator+(matrix const & other) const {
  99. return matrix{*this} += other;
  100. }
  101. matrix & operator-=(matrix const & other) {
  102. MATRIX_FOR_EACH(i, j) { _data[i][j] -= other(i, j); }
  103. return *this;
  104. }
  105. matrix operator-(matrix const & other) const {
  106. return matrix{*this} -= other;
  107. }
  108. matrix operator-() const {
  109. matrix tmp;
  110. MATRIX_FOR_EACH(i, j) { tmp(i, j) = -_data[i][j]; }
  111. return tmp;
  112. }
  113. vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
  114. vector::vector<T, C> rval;
  115. MATRIX_FOR_EACH(i, j) { rval[i] += _data[i][j] * vec[j]; }
  116. return rval;
  117. }
  118. template <std::size_t C2>
  119. matrix<T, R, C2> operator*(matrix<T, C, C2> const & other) const {
  120. matrix<T, R, C2> rval;
  121. MATRIX_FOR_EACH(i, j) {
  122. for (size_t k = 0; k < C2; ++k) {
  123. rval(i, k) += _data[i][j] * other(j, k);
  124. }
  125. }
  126. return rval;
  127. }
  128. matrix<T, R, C> & operator*=(T c) {
  129. MATRIX_FOR_EACH(i, j) { _data[i][j] *= c; }
  130. return *this;
  131. }
  132. template <typename M>
  133. MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  134. operator*(M c) const {
  135. return matrix<mul_t<M>, R, C>{*this} *= c;
  136. }
  137. template <typename M>
  138. friend MATRIX_DISABLE_IF_MATRIX(M, mul_t<M>, R, C)
  139. operator*(M c, matrix const & matr) {
  140. return matrix<mul_t<M>, R, C>{matr} *= c;
  141. }
  142. template <typename M> matrix<div_t<M>, R, C> & operator/=(M c) {
  143. MATRIX_FOR_EACH(i, j) { _data[i][j] /= c; }
  144. return *this;
  145. }
  146. template <typename M> matrix<div_t<M>, R, C> operator/(M c) const {
  147. return matrix<mul_t<M>, R, C>{*this} /= c;
  148. }
  149. bool operator==(matrix const & other) const {
  150. MATRIX_FOR_EACH(i, j) {
  151. if (_data[i][j] != other(i, j)) { return false; }
  152. }
  153. return true;
  154. }
  155. bool operator!=(matrix const & other) const { return !operator==(other); }
  156. private:
  157. value_type _data[R][C] = {value_type()};
  158. };
  159. MATRIX_CTOR_DEDUCTION(1);
  160. MATRIX_CTOR_DEDUCTION(2);
  161. MATRIX_CTOR_DEDUCTION(3);
  162. MATRIX_CTOR_DEDUCTION(4);
  163. }
  164. #include "math/matrix/undef.h"