浏览代码

Fixing division. Cleaning up multiplication code. Using getter functions where appropriate.

Samuel Jaffe 8 年之前
父节点
当前提交
c122c8ff1e
共有 3 个文件被更改,包括 56 次插入86 次删除
  1. 4 2
      include/bigdecimal.h
  2. 41 35
      src/bigdecimal.cpp
  3. 11 49
      test/bigdecimal.t.h

+ 4 - 2
include/bigdecimal.h

@@ -61,17 +61,19 @@ namespace math {
     std::string to_string() const;
   private:
     bigdecimal(bool, uint64_t);
-    int32_t get_impl_scale() const;
+    void set_scale(int32_t);
     void subtract_impl(bigdecimal const &, bool);
+    
     friend void swap(bigdecimal & rhs, bigdecimal & lhs) {
       using std::swap;
       swap(rhs.is_negative, lhs.is_negative);
       swap(rhs.scale_, lhs.scale_);
+      swap(rhs.steps_, lhs.steps_);
       swap(rhs.data, lhs.data);
     }
     
     bool is_negative;
-    int32_t scale_{0};
+    int32_t scale_{0}, steps_{0};
     data_type data{};
   };
 }

+ 41 - 35
src/bigdecimal.cpp

@@ -53,8 +53,8 @@ is_negative(number[0] == '-') {
   if (auto p = strchr(number, '.')) {
     read_segment(data, number, size_t(p - number));
     number = p + 1;
-    scale_ = int32_t(strlen(number));
-    size_t elems = size_t(scale_/SEG_DIGITS);
+    set_scale(int32_t(strlen(number)));
+    size_t elems = size_t(scale()/SEG_DIGITS);
     data.resize(elems+data.size());
     for (size_t idx = elems; idx > 0; --idx) {
       read(data[idx-1], number, SEG_DIGITS);
@@ -68,33 +68,34 @@ is_negative(number[0] == '-') {
 }
 
 void bigdecimal::set_value(bigdecimal const & other) {
-  int32_t scale = scale_;
+  int32_t oscale = scale();
   operator=(other);
-  rescale(scale);
+  rescale(oscale);
 }
 
-int32_t bigdecimal::get_impl_scale() const {
-  return add_scale(scale_);
+void bigdecimal::set_scale(int32_t nscale) {
+  scale_ = nscale;
+  steps_ = add_scale(nscale);
 }
 
 void bigdecimal::rescale(int32_t nscale) {
-  int32_t steps = add_scale(scale_);
   int32_t const nsteps = add_scale(nscale);
-  if (steps > nsteps) {
-    data.erase(data.begin(), data.begin() + steps - nsteps);
-  } else if (steps < nsteps) {
-    data.insert(data.begin(), size_t(nsteps-steps), 0);
+  if (steps_ > nsteps) {
+    data.erase(data.begin(), data.begin() + steps_ - nsteps);
+    if (data.empty()) { data.push_back(0); }
+  } else if (steps_ < nsteps) {
+    data.insert(data.begin(), size_t(nsteps-steps_), 0);
   }
-  auto const idx = (scale_ - nscale) % SEG_DIGITS;
-  if (scale_ > nscale && nscale < 0 && idx) {
+  auto const idx = -nscale % SEG_DIGITS;
+  if (scale() > nscale && nscale < 0 && idx) {
     data.back() /= detail::powers[idx];
     data.back() *= detail::powers[idx];
   }
-  scale_ = nscale;
+  set_scale(nscale);
 }
 
 void bigdecimal::subtract_impl(bigdecimal const & lhs, bool is_sub) {
-  size_t const offset = size_t(get_impl_scale() - lhs.get_impl_scale());
+  size_t const offset = size_t(steps_ - lhs.steps_);
   auto cmp = detail::compare(data, lhs.data, offset);
   if (cmp == 0) {
     set_value(bigdecimal::ZERO);
@@ -114,9 +115,9 @@ bigdecimal bigdecimal::operator-() const {
 }
 
 bigdecimal & math::operator+=(bigdecimal & rhs, bigdecimal const & lhs) {
-  int32_t const new_scale = std::max(rhs.scale_, lhs.scale_);
+  int32_t const new_scale = std::max(rhs.scale(), lhs.scale());
   rhs.rescale(new_scale);
-  size_t const offset = size_t(rhs.get_impl_scale() - lhs.get_impl_scale());
+  size_t const offset = size_t(rhs.steps_ - lhs.steps_);
   if (lhs == bigdecimal::ZERO) { return rhs; }
   else if (rhs == bigdecimal::ZERO) { rhs.set_value(lhs); }
   else if (rhs.is_negative == lhs.is_negative) {
@@ -128,9 +129,9 @@ bigdecimal & math::operator+=(bigdecimal & rhs, bigdecimal const & lhs) {
 }
 
 bigdecimal & math::operator-=(bigdecimal & rhs, bigdecimal const & lhs) {
-  int32_t const new_scale = std::max(rhs.scale_, lhs.scale_);
+  int32_t const new_scale = std::max(rhs.scale(), lhs.scale());
   rhs.rescale(new_scale);
-  size_t const offset = size_t(rhs.get_impl_scale() - lhs.get_impl_scale());
+  size_t const offset = size_t(rhs.steps_ - lhs.steps_);
   if (lhs == bigdecimal::ZERO) { return rhs; }
   else if (rhs == bigdecimal::ZERO) { rhs.set_value(-lhs); }
   else if (rhs.is_negative != lhs.is_negative) {
@@ -155,7 +156,7 @@ static bool is_one(bigdecimal::data_type const & data, int32_t scale) {
 }
 
 bigdecimal & math::operator*=(bigdecimal & rhs, bigdecimal const & lhs) {
-  int32_t const new_scale = rhs.scale_ + lhs.scale_;
+  int32_t const new_scale = rhs.scale() + lhs.scale();
   bool is_neg = rhs.is_negative != lhs.is_negative;
   if (rhs == bigdecimal::ZERO || lhs == bigdecimal::ZERO) {
     rhs = bigdecimal::ZERO;
@@ -166,14 +167,11 @@ bigdecimal & math::operator*=(bigdecimal & rhs, bigdecimal const & lhs) {
     rhs.is_negative = is_neg;
   } else {
     detail::multiply(rhs.data, lhs.data);
-    int32_t const steps = add_scale(rhs.scale_) + add_scale(lhs.scale_), nsteps = add_scale(new_scale);
-    int32_t const off2 = (steps < 0 ? 0 : 1);
-    if (steps > nsteps) {
-      rhs.data.erase(rhs.data.begin(), rhs.data.begin() + steps - nsteps);
-    } else if (steps + off2 < nsteps) {
-      rhs.data.insert(rhs.data.begin(), size_t(nsteps-steps-off2), 0);
+    auto diff = add_scale(new_scale) - rhs.steps_ - lhs.steps_;
+    if (diff < 0) {
+      rhs.data.erase(rhs.data.begin(), rhs.data.begin() - diff);
     }
-    rhs.scale_ = new_scale; // TODO: more steps in certain cases
+    rhs.set_scale(new_scale);
     return rhs;
   }
   rhs.rescale(new_scale);
@@ -181,18 +179,26 @@ bigdecimal & math::operator*=(bigdecimal & rhs, bigdecimal const & lhs) {
 }
 
 bigdecimal & math::operator/=(bigdecimal & rhs, bigdecimal const & lhs) {
-  int32_t const new_scale = rhs.scale_ - lhs.scale_;
+  int32_t const new_scale = rhs.scale() - lhs.scale();
   rhs.is_negative ^= lhs.is_negative;
   if (lhs == bigdecimal::ZERO) { throw std::domain_error("cannot divide by 0"); }
   else if (rhs == bigdecimal::ZERO) {
     rhs = bigdecimal::ZERO;
-  } else if (detail::compare(lhs.data, bigdecimal::ONE.data) != 0) {
+  } else if (!is_one(lhs.data, lhs.scale())) {
+    if (rhs.scale() < new_scale) rhs.rescale(new_scale);
     auto cmp = detail::compare(rhs.data, lhs.data);
     if (cmp < 0) { rhs = bigdecimal::ZERO; }
     else if (cmp == 0) { rhs.data = {1}; }
     else {
       rhs.data = detail::divide(rhs.data, lhs.data);
-      rhs.scale_ = new_scale;
+      auto diff = add_scale(new_scale) - rhs.steps_ - lhs.steps_;
+      if (diff > 0) {
+        rhs.data.insert(rhs.data.begin(), size_t(diff), 0);
+      }
+      if (rhs.data.size() <= mul_scale(new_scale)) {
+        rhs.data.push_back(0);
+      }
+      rhs.set_scale(new_scale);
       return rhs;
     }
   }
@@ -223,8 +229,8 @@ bool math::operator==(bigdecimal const & lhs, bigdecimal const & rhs) {
 bigdecimal const bigdecimal::ZERO{0}, bigdecimal::ONE{1}, bigdecimal::NEGATIVE_ONE{-1};
 
 std::string bigdecimal::to_string() const {
-  size_t const decimal_split = size_t(std::max(0, add_scale(scale_)));
-  int32_t const hidden = std::max(0, SEG_DIGITS * (-scale_/SEG_DIGITS));
+  size_t const decimal_split = size_t(std::max(0, steps_));
+  int32_t const hidden = std::max(0, SEG_DIGITS * (-scale()/SEG_DIGITS));
   size_t const chars = SEG_DIGITS * data.size() + size_t(hidden);
   std::vector<char> output(chars + 3, '\0');
   char * ptr = output.data();
@@ -233,14 +239,14 @@ std::string bigdecimal::to_string() const {
   for (size_t i = data.size()-1; i > decimal_split; --i) {
     ptr += sprintf(ptr, "%0.*d", SEG_DIGITS, data[i-1]);
   }
-  if (scale_ > 0) {
+  if (scale() > 0) {
     *ptr++ = '.';
     for (size_t i = decimal_split; i > 1; --i) {
       ptr += sprintf(ptr, "%0.*d", SEG_DIGITS, data[i-1]);
     }
-    int32_t const val = scale_ % SEG_DIGITS;
+    int32_t const val = scale() % SEG_DIGITS;
     sprintf(ptr, "%0.*d", val, data[0]/detail::powers[SEG_DIGITS-val]);
-  } else {
+  } else if (data.back()) {
     sprintf(ptr, "%0.*d", hidden, 0);
   }
   return output.data();

+ 11 - 49
test/bigdecimal.t.h

@@ -132,55 +132,6 @@ public:
     TS_ASSERT_EQUALS((G*F).to_string(), "1000000000");
   }
   
-  math::bigdecimal __create(int32_t scale) {
-    if (scale >= 0) return { "1", scale };
-    std::vector<char> data(size_t(std::abs(scale))+3, 0);
-    data[0] = '1';
-    sprintf(data.data()+1, "%0.*d", -scale, 0);
-    return { data.data(), scale };
-  }
-  
-  std::string __createExpect(int32_t scale1, int32_t scale2) {
-    int32_t decimals{0}, magnitude{std::min(scale1, scale2)};
-    if (scale1 < 0 == scale2 < 0) {
-      decimals = scale1 + scale2;
-    } else if (scale1 > 0 && scale1 > -scale2) {
-      decimals = scale1 + scale2;
-    } else if (scale2 > 0 && scale2 > -scale1) {
-      decimals = scale1 + scale2;
-    } else {
-      decimals = std::min(scale1, scale2);
-    }
-    if (decimals < 0) { magnitude = decimals; }
-    std::vector<char> data(size_t(std::abs(decimals) + std::abs(magnitude))+4, 0);
-    char * ptr = data.data();
-    *ptr++ = '1';
-    if (magnitude < 0) {
-      ptr += sprintf(ptr, "%0.*d", -magnitude, 0);
-    }
-    if (decimals > 0) {
-      sprintf(ptr, ".%0.*d", decimals, 0);
-    }
-    return data.data();
-  }
-  
-  void __testMultiplyScales(int32_t scale1, int32_t scale2) {
-    math::bigdecimal A(__create(scale1));
-    math::bigdecimal B(__create(scale2));
-    // This is wrong
-    std::string expected(__createExpect(scale1, scale2));
-    TS_ASSERT_EQUALS((A*B).to_string(), expected);
-    TS_ASSERT_EQUALS((B*A).to_string(), expected);
-  }
-  
-  void testMultiplicationBrutePermutations() {
-    for (int32_t i = -20; i <= +20; ++i) {
-      for (int32_t j = i; j <= +20; ++j) {
-        __testMultiplyScales(i, j);
-      }
-    }
-  }
-
   void testDivideHigherScaleByLowerScale() {
     math::bigdecimal a("1", 2);
     math::bigdecimal b("1", 1);
@@ -192,4 +143,15 @@ public:
     math::bigdecimal b("1", 2);
     TS_ASSERT_EQUALS((a/b).to_string(), "10");
   }
+  
+  void testDivideByLargerNumberGivesDecimalIfScaleAllows() {
+    math::bigdecimal a("1" ,  0);
+    math::bigdecimal b("1" ,  1);
+    math::bigdecimal c("10",  0);
+    math::bigdecimal d("10", -1);
+    TS_ASSERT_EQUALS((a/c).to_string(), "0");
+    TS_ASSERT_EQUALS((a/d).to_string(), "0.1");
+    TS_ASSERT_EQUALS((b/c).to_string(), "0.1");
+    TS_ASSERT_EQUALS((b/d).to_string(), "0.10");
+  }
 };