浏览代码

Making operator* take rhs by value to optimize chaining.

Samuel Jaffe 8 年之前
父节点
当前提交
dd4812d507
共有 2 个文件被更改,包括 32 次插入26 次删除
  1. 1 1
      include/biginteger.h
  2. 31 25
      src/biginteger.cpp

+ 1 - 1
include/biginteger.h

@@ -31,7 +31,7 @@ namespace math {
     // Binary operators
     friend biginteger operator+(biginteger, biginteger const &);
     friend biginteger operator-(biginteger, biginteger const &);
-    friend biginteger operator*(biginteger const &, biginteger const &);
+    friend biginteger operator*(biginteger, biginteger const &);
     friend biginteger operator/(biginteger const &, biginteger const &);
     friend biginteger operator%(biginteger const &, biginteger const &);
     friend biginteger & operator+=(biginteger &, biginteger const &);

+ 31 - 25
src/biginteger.cpp

@@ -20,7 +20,7 @@ namespace detail {
   int compare(data_type const & rhs, data_type const & lhs);
   void add(data_type & rhs, data_type const & lhs);
   void subtract_nounderflow(data_type & rhs, data_type const & lhs);
-  data_type multiply(data_type const & rhs, data_type const & lhs);
+  void multiply(data_type & rhs, data_type const & lhs);
   std::pair<data_type, data_type> divide(data_type remainder, data_type const & divisor);
 }
 
@@ -119,7 +119,17 @@ biginteger & math::operator-=(biginteger & rhs, biginteger const & lhs) {
 }
 
 biginteger & math::operator*=(biginteger & rhs, biginteger const & lhs) {
-  swap(rhs, rhs * lhs);
+  bool is_neg = rhs.is_negative != lhs.is_negative;
+  if (rhs == biginteger::ZERO || lhs == biginteger::ZERO) {
+    rhs = biginteger::ZERO;
+  } else if (detail::compare(lhs.data, biginteger::ONE.data) == 0) {
+    rhs.is_negative = is_neg;
+  } else if (detail::compare(rhs.data, biginteger::ONE.data) == 0) {
+    rhs = lhs;
+    rhs.is_negative = is_neg;
+  } else {
+    detail::multiply(rhs.data, lhs.data);
+  }
   return rhs;
 }
 
@@ -140,11 +150,8 @@ biginteger math::operator-(biginteger rhs, biginteger const & lhs) {
   return rhs -= lhs;
 }
 
-biginteger math::operator*(biginteger const & rhs, biginteger const & lhs) {
-  if (rhs == biginteger::ZERO || lhs == biginteger::ZERO) {
-    return biginteger::ZERO;
-  }
-  return {rhs.is_negative != lhs.is_negative, detail::multiply(rhs.data, lhs.data)};
+biginteger math::operator*(biginteger rhs, biginteger const & lhs) {
+  return rhs *= lhs;
 }
 
 biginteger math::operator/(biginteger const & rhs, biginteger const & lhs) {
@@ -163,8 +170,8 @@ biginteger math::operator/(biginteger const & rhs, biginteger const & lhs) {
 
 biginteger math::operator%(biginteger const & rhs, biginteger const & lhs) {
   if (lhs == biginteger::ZERO) { throw std::domain_error("cannot divide by 0"); }
-  else if (rhs == biginteger::ZERO || lhs == biginteger::ONE ||
-           lhs == biginteger::NEGATIVE_ONE) { return biginteger::ZERO; }
+  else if (detail::compare(lhs.data, biginteger::ONE.data) == 0 ||
+           rhs == biginteger::ZERO) { return biginteger::ZERO; }
   else {
     auto cmp = detail::compare(rhs.data, lhs.data);
     if (cmp < 0) { return rhs; }
@@ -229,33 +236,32 @@ namespace detail {
     if (rhs[rbnd-1] == 0 && rbnd > 1) { rhs.pop_back(); }
   }
   
-  data_type multiply(data_type const & rhs, data_type const & lhs) {
-    if (compare(rhs, {1}) == 0) { return lhs; }
-    else if (compare(lhs, {1}) == 0) { return rhs; }
+  void multiply(data_type & rhs, data_type const & lhs) {
     size_t const rbnd = rhs.size(), lbnd = lhs.size();
     size_t const ubnd = rbnd + lbnd;
-    data_type rval(ubnd + 1);
+    rhs.resize(ubnd);
     // Multiply
-    for (size_t i = 0; i < rbnd; ++i) {
-      for (size_t j = 0; j < lbnd; ++j) {
+    for (size_t i = rbnd; i > 0; --i) {
+      int32_t const value = rhs[i-1];
+      for (size_t j = lbnd; j > 0; --j) {
         // Max input              999,999,999
         // Max output 999,999,998,000,000,001
-        int64_t product = static_cast<int64_t>(rhs[i]) * static_cast<int64_t>(lhs[j]);
+        int64_t product = static_cast<int64_t>(value) * static_cast<int64_t>(lhs[j-1]);
         int64_t overflow = product / biginteger::OVER_SEG;
-        rval[i+j] += static_cast<int32_t>(product - (overflow * biginteger::OVER_SEG));
-        rval[i+j+1] += static_cast<int32_t>(overflow);
+        rhs[i+j-2] += static_cast<int32_t>(product - (overflow * biginteger::OVER_SEG));
+        rhs[i+j-1] += static_cast<int32_t>(overflow);
       }
+      rhs[i-1] -= value;
     }
     // Carry
-    for (size_t i = 0; i < ubnd; ++i) {
-      if (rval[i] > biginteger::MAX_SEG) {
-        int32_t overflow = rval[i] / biginteger::OVER_SEG;
-        rval[i] -= (overflow * biginteger::OVER_SEG);
-        rval[i+1] += overflow;
+    for (size_t i = 0; i < ubnd-1; ++i) {
+      if (rhs[i] > biginteger::MAX_SEG) {
+        int32_t overflow = rhs[i] / biginteger::OVER_SEG;
+        rhs[i] -= (overflow * biginteger::OVER_SEG);
+        rhs[i+1] += overflow;
       }
     }
-    while (rval.back() == 0 && rval.size() > 1) { rval.pop_back(); }
-    return rval;
+    while (rhs.back() == 0) { rhs.pop_back(); }
   }
   
   data_type shift10(data_type const & data, int32_t pow, size_t shift) {