Browse Source

Extract Row reference to new type, this allows up to have

matr3x3()[0] = matr2x3()[1]
Sam Jaffe 7 years ago
parent
commit
fcd6a426c3
1 changed files with 48 additions and 42 deletions
  1. 48 42
      matrix.hpp

+ 48 - 42
matrix.hpp

@@ -35,6 +35,36 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
 #define MATRIX_FOR_EACH_RANGE(i, i_max, j, j_max) for (size_t i = 0; i < i_max; ++i) for (size_t j = 0; j < j_max; ++j)
 #define MATRIX_FOR_EACH(i, j) MATRIX_FOR_EACH_RANGE(i, R, j, C)
   
+  template <typename T, std::size_t C>
+  class row_reference {
+  private:
+    row_reference(T ( & h )[C]) : _handle(h) {}
+  public:
+    row_reference(row_reference const &) = delete;
+    row_reference(row_reference &&) = default;
+    
+    template <typename S>
+    row_reference & operator=(row_reference<S, C> const & other) {
+      VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
+      return *this;
+    }
+    
+    T const & operator[](std::size_t col) const { return _handle[col]; }
+    T & operator[](std::size_t col) { return _handle[col]; }
+    T const & at(std::size_t col) const {
+      expects(col < C, std::out_of_range, "column index out of range");
+      return operator[](col);
+    }
+    T & at(std::size_t col) {
+      expects(col < C, std::out_of_range, "column index out of range");
+      return operator[](col);
+    }
+    
+  private:
+    template <typename _T, std::size_t R, std::size_t _C> friend class matrix;
+    T ( & _handle )[C];
+  };
+  
   template <typename T, std::size_t R, std::size_t C>
   class matrix {
   public:
@@ -45,38 +75,6 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     template <typename M>
     using div_t = decltype(std::declval<T>()/std::declval<M>());
   public:
-    
-    template <typename S>
-    class row_reference {
-    public:
-      row_reference(S ( & h )[C]) : _handle(h) {}
-      row_reference(row_reference const &) = delete;
-      row_reference(row_reference &&) = default;
-
-      row_reference & operator=(row_reference const & other) {
-        return operator=<S>(other);
-      }
-      template <typename S2>
-      row_reference & operator=(row_reference<S2> const & other) {
-        VECTOR_FOR_EACH_RANGE(i, C) { _handle[i] = other[i]; }
-        return *this;
-      }
-
-      S const & operator[](std::size_t col) const { return _handle[col]; }
-      S & operator[](std::size_t col) { return _handle[col]; }
-      S const & at(std::size_t col) const {
-        expects(col < C, std::out_of_range, "column index out of range");
-        return operator[](col);
-      }
-      S & at(std::size_t col) {
-        expects(col < C, std::out_of_range, "column index out of range");
-        return operator[](col);
-      }
-      
-    private:
-      S ( & _handle )[C];
-    };
-    
     matrix() = default;
     matrix(std::array<std::array<T, C>, R> const & init) {
       MATRIX_FOR_EACH(i, j) {
@@ -98,13 +96,13 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     }
     matrix & operator=(matrix const& other) {
       MATRIX_FOR_EACH(i, j) {
-        _data[i][j] = other._data[i][j];
+        _data[i][j] = other(i, j);
       }
       return *this;
     }
     matrix & operator=(matrix && other) {
       MATRIX_FOR_EACH(i, j) {
-        _data[i][j] = std::move(other._data[i][j]);
+        _data[i][j] = std::move(other(i, j));
       }
       return *this;
     }
@@ -149,17 +147,17 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     T & operator()(std::size_t row, std::size_t col) {
       return _data[row][col];
     }
-    row_reference<const T> operator[](std::size_t row) const {
+    row_reference<const T, C> operator[](std::size_t row) const {
       return { _data[row] };
     }
-    row_reference<T> operator[](std::size_t row) {
+    row_reference<T, C> operator[](std::size_t row) {
       return { _data[row] };
     }
-    row_reference<const T> at(std::size_t row) const {
+    row_reference<const T, C> at(std::size_t row) const {
       expects(row >= R, std::out_of_range, "row index out of range");
       return operator[](row);
     }
-    row_reference<T> at(std::size_t row) {
+    row_reference<T, C> at(std::size_t row) {
       expects(row >= R, std::out_of_range, "row index out of range");
       return operator[](row);
     }
@@ -174,7 +172,7 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     
     matrix& operator+=(matrix const & other) {
       MATRIX_FOR_EACH(i, j) {
-        _data[i][j] += other[i][j];
+        _data[i][j] += other(i, j);
       }
       return *this;
     }
@@ -183,7 +181,7 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     }
     matrix& operator-=(matrix const & other) {
       MATRIX_FOR_EACH(i, j) {
-        _data[i][j] -= other[i][j];
+        _data[i][j] -= other(i, j);
       }
       return *this;
     }
@@ -191,6 +189,14 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
       return matrix{*this} -= other;
     }
     
+    matrix operator-() const {
+      matrix tmp;
+      MATRIX_FOR_EACH(i, j) {
+        tmp(i, j) = -_data[i][j];
+      }
+      return tmp;
+    }
+    
     vector::vector<T, C> operator*(vector::vector<T, C> const & vec) const {
       vector::vector<T, C> rval;
       MATRIX_FOR_EACH(i, j) {
@@ -204,7 +210,7 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
       matrix<T, R, C2> rval;
       MATRIX_FOR_EACH(i, j) {
         for (size_t k = 0; k < C2; ++k) {
-          rval[i][k] += _data[i][j] * other[j][k];
+          rval(i, k) += _data[i][j] * other(j, k);
         }
       }
       return rval;
@@ -241,7 +247,7 @@ typename std::enable_if<!is_matrix<_type>::value, matrix<t, r, c> >::type
     
     bool operator==(matrix const & other) const {
       MATRIX_FOR_EACH(i, j) {
-        if (_data[i][j] != other._data[i][j]) {
+        if (_data[i][j] != other(i, j)) {
           return false;
         }
       }