Explorar o código

breaking/feat: add Distribution class to allow non-uniform distributions

Sam Jaffe %!s(int64=2) %!d(string=hai) anos
pai
achega
d528ca2813

+ 20 - 4
include/random/device.h

@@ -7,11 +7,27 @@
 
 #pragma once
 
+#include <random/distribution.h>
+#include <random/forwards.h>
+
+#define APPLY_DEVICE(T)                                                        \
+  virtual T apply(Distribution<T> const & dist) { return dist(*this); }
+
 namespace engine::random {
-struct Device {
+class Device {
+public:
+  using result_type = uint32_t;
+  constexpr static result_type min() { return 0; }
+  constexpr static result_type max() { return ~0u; }
+
   virtual ~Device() = default;
-  virtual int32_t inclusive(int32_t min, int32_t max) = 0;
-  virtual uint32_t inclusive(uint32_t min, uint32_t max) = 0;
-  virtual double exclusive(double min, double max) = 0;
+  virtual result_type operator()() = 0;
+
+private:
+  friend class Random;
+  APPLY_DEVICE(size_t)
+  APPLY_DEVICE(double)
+  APPLY_DEVICE(int32_t)
+  APPLY_DEVICE(uint32_t)
 };
 }

+ 57 - 0
include/random/distribution.h

@@ -0,0 +1,57 @@
+//
+//  distribution.h
+//  shared_random_generator
+//
+//  Created by Sam Jaffe on 3/25/23.
+//  Copyright © 2023 Sam Jaffe. All rights reserved.
+//
+
+#pragma once
+
+#include <iosfwd>
+#include <type_traits>
+
+#include <random/forwards.h>
+
+namespace engine::random {
+template <typename Rand> class Distribution {
+public:
+  using result_type = Rand;
+
+  virtual ~Distribution() = default;
+  virtual result_type operator()(Device & device) const = 0;
+  // Allow us to operator on rvalue Devices as well
+  result_type operator()(Device && device) const { return (*this)(device); }
+
+private:
+  virtual void describe(std::ostream & os) const {}
+  friend std::ostream & operator<<(std::ostream & os,
+                                   Distribution const & dist) {
+    dist.describe(os);
+    return os;
+  }
+};
+
+template <typename Rand> class Uniform : public Distribution<Rand> {
+public:
+  using result_type = typename Distribution<Rand>::result_type;
+
+  Uniform() : Uniform(0) {}
+  Uniform(Rand min);
+  Uniform(Rand min, Rand max);
+
+  result_type operator()(Device & device) const override;
+
+private:
+  void describe(std::ostream &) const override;
+
+private:
+  Rand min_;
+  Rand max_;
+};
+
+class Index : public Uniform<size_t> {
+public:
+  explicit Index(size_t size) : Uniform(0, size - 1) {}
+};
+}

+ 7 - 0
include/random/forwards.h

@@ -10,8 +10,15 @@
 
 class scope_exit;
 
+/**
+ * Random(Device):
+ *   operator()(Distribution<R>) => R where R is double, uint32_t, or int32_t
+ */
+
 namespace engine::random {
 class Device;
+template <typename> class Distribution;
+class Mock;
 class Random;
 class Tape;
 }

+ 39 - 0
include/random/mock/mock_device.h

@@ -0,0 +1,39 @@
+//
+//  mock_device.h
+//  shared_random_generator
+//
+//  Created by Sam Jaffe on 3/25/23.
+//  Copyright © 2023 Sam Jaffe. All rights reserved.
+//
+
+#pragma once
+
+#include <sstream>
+
+#include <gmock/gmock.h>
+#include <random/device.h>
+#include <random/distribution.h>
+
+#ifndef CONCAT
+#define CONCAT2(A, B) A##B
+#define CONCAT(A, B) CONCAT2(A, B)
+#endif
+
+#define MOCK_APPLY_DEVICE(T)                                                   \
+  T apply(Distribution<T> const & dist) final {                                \
+    std::stringstream ss;                                                      \
+    ss << dist;                                                                \
+    return CONCAT(apply_, T)(ss.str());                                        \
+  }                                                                            \
+  MOCK_METHOD1(CONCAT(apply_, T), T(std::string const &))
+
+namespace engine::random {
+class MockDevice : public Device {
+public:
+  result_type operator()() final { return apply_uint32_t("?"); }
+  MOCK_APPLY_DEVICE(size_t);
+  MOCK_APPLY_DEVICE(double);
+  MOCK_APPLY_DEVICE(int32_t);
+  MOCK_APPLY_DEVICE(uint32_t);
+};
+}

+ 2 - 10
include/random/random.h

@@ -18,16 +18,8 @@ public:
   template <typename P> Random(P const & p) : impl_(std::make_shared<P>(p)) {}
   template <typename P> Random(std::shared_ptr<P> const & p) : impl_(p) {}
 
-  int32_t exclusive(int32_t min, int32_t max) const;
-  int32_t inclusive(int32_t min, int32_t max) const;
-
-  uint32_t exclusive(uint32_t min, uint32_t max) const;
-  uint32_t inclusive(uint32_t min, uint32_t max) const;
-
-  double exclusive(double min, double max) const;
-  double inclusive(double min, double max) const;
-
-  std::shared_ptr<Device> device() const { return impl_; }
+  template <typename Rand>
+  Rand operator()(Distribution<Rand> const & dist) const;
 
   /**
    * Create a scope in which all calls to the generator functions will record

+ 7 - 35
include/random/tape.h

@@ -8,9 +8,7 @@
 
 #pragma once
 
-#include <string>
-#include <tuple>
-#include <variant>
+#include <utility>
 #include <vector>
 
 #include <random/device.h>
@@ -19,44 +17,18 @@ namespace engine::random {
 class Tape : public Device {
 public:
   // Allow for the serialization and de-serialization of this Tape
-  using serial_type =
-      std::vector<std::tuple<std::string, uint64_t, uint64_t, uint64_t>>;
-
-private:
-  template <typename T> struct EntryT {
-    EntryT(T min, T max) : min(min), max(max), result() {}
-    EntryT(T min, T max, T result) : min(min), max(max), result(result) {}
-    EntryT(uint64_t min, uint64_t max, uint64_t result);
-
-    T min;
-    T max;
-    T result;
-  };
-
-  using Entry = std::variant<EntryT<double>, EntryT<uint32_t>, EntryT<int32_t>>;
+  using serial_type = std::vector<result_type>;
 
 private:
   size_t index_{0}; // transient
-  std::vector<Entry> entries_;
+  std::vector<result_type> entries_;
 
 public:
   Tape() = default;
-  Tape(serial_type serial);
-  explicit operator serial_type() const;
+  Tape(serial_type serial) : entries_(std::move(serial)) {}
+  explicit operator serial_type() const { return entries_; }
 
-  int32_t inclusive(int32_t min, int32_t max) override {
-    return fetch(min, max);
-  }
-  uint32_t inclusive(uint32_t min, uint32_t max) override {
-    return fetch(min, max);
-  }
-  double exclusive(double min, double max) override { return fetch(min, max); }
-
-  void inclusive(int32_t min, int32_t max, int32_t result);
-  void inclusive(uint32_t min, uint32_t max, uint32_t result);
-  void exclusive(double min, double max, double result);
-
-private:
-  template <typename T> T fetch(T min, T max);
+  result_type operator()() override;
+  result_type operator()(result_type result);
 };
 }

+ 1 - 4
include/random/thread_safe.h

@@ -21,10 +21,7 @@ private:
 
 public:
   ThreadSafeDevice(std::unique_ptr<Device> impl);
-
-  int32_t inclusive(int32_t min, int32_t max) final;
-  uint32_t inclusive(uint32_t min, uint32_t max) final;
-  double exclusive(double min, double max) final;
+  uint32_t operator()() final;
 };
 
 }

+ 20 - 0
shared_random_generator.xcodeproj/project.pbxproj

@@ -10,6 +10,8 @@
 		CD89E51824E6F3FD008167A8 /* libshared_random_generator.dylib in Frameworks */ = {isa = PBXBuildFile; fileRef = CDED6A4221B2F5A700AB91D0 /* libshared_random_generator.dylib */; };
 		CD89E51F24E6F40B008167A8 /* GoogleMock.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = CD89E50824E6F3EA008167A8 /* GoogleMock.framework */; };
 		CD89E52124E6F424008167A8 /* random_test.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CD89E52024E6F424008167A8 /* random_test.cxx */; };
+		CD8EA29E29CF57C500D45186 /* distribution.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CD8EA29D29CF57C500D45186 /* distribution.cxx */; };
+		CD8EA2A229CF79D200D45186 /* distribution_test.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CD8EA2A129CF79D200D45186 /* distribution_test.cxx */; };
 		CDD2138229C76EF500A4582C /* tape_test.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CDD2138129C76EF500A4582C /* tape_test.cxx */; };
 		CDE943B829C75E170086A8CA /* tape.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CDE943B629C75E170086A8CA /* tape.cxx */; };
 		CDE943B929C75E170086A8CA /* tape.h in Headers */ = {isa = PBXBuildFile; fileRef = CDE943B729C75E170086A8CA /* tape.h */; };
@@ -103,6 +105,10 @@
 		CD89E51324E6F3FD008167A8 /* shared_random_generator-test.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "shared_random_generator-test.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
 		CD89E51724E6F3FD008167A8 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
 		CD89E52024E6F424008167A8 /* random_test.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = random_test.cxx; sourceTree = "<group>"; };
+		CD8EA27A29CF4C0500D45186 /* distribution.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = distribution.h; sourceTree = "<group>"; };
+		CD8EA29D29CF57C500D45186 /* distribution.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = distribution.cxx; sourceTree = "<group>"; };
+		CD8EA2A029CF767100D45186 /* mock_device.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = mock_device.h; sourceTree = "<group>"; };
+		CD8EA2A129CF79D200D45186 /* distribution_test.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = distribution_test.cxx; sourceTree = "<group>"; };
 		CDD2137C29C76E5600A4582C /* xcode_gtest_helper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = xcode_gtest_helper.h; sourceTree = "<group>"; };
 		CDD2138129C76EF500A4582C /* tape_test.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = tape_test.cxx; sourceTree = "<group>"; };
 		CDD2138329C7723700A4582C /* expect.xcodeproj */ = {isa = PBXFileReference; lastKnownFileType = "wrapper.pb-project"; name = expect.xcodeproj; path = external/expect/expect.xcodeproj; sourceTree = "<group>"; };
@@ -165,6 +171,14 @@
 			name = Frameworks;
 			sourceTree = "<group>";
 		};
+		CD8EA29F29CF764400D45186 /* mock */ = {
+			isa = PBXGroup;
+			children = (
+				CD8EA2A029CF767100D45186 /* mock_device.h */,
+			);
+			path = mock;
+			sourceTree = "<group>";
+		};
 		CDD2138429C7723700A4582C /* Products */ = {
 			isa = PBXGroup;
 			children = (
@@ -194,8 +208,10 @@
 		CDE943B229C75C610086A8CA /* random */ = {
 			isa = PBXGroup;
 			children = (
+				CD8EA29F29CF764400D45186 /* mock */,
 				CDE943B529C75CD70086A8CA /* forwards.h */,
 				CDE943B329C75C610086A8CA /* device.h */,
+				CD8EA27A29CF4C0500D45186 /* distribution.h */,
 				CDE943B429C75C610086A8CA /* random.h */,
 				CDE943B729C75E170086A8CA /* tape.h */,
 				CDE943BA29C762C00086A8CA /* thread_safe.h */,
@@ -231,6 +247,7 @@
 		CDED6A4921B2F5DF00AB91D0 /* src */ = {
 			isa = PBXGroup;
 			children = (
+				CD8EA29D29CF57C500D45186 /* distribution.cxx */,
 				CDED6A4E21B2F62C00AB91D0 /* random.cxx */,
 				CDE943B629C75E170086A8CA /* tape.cxx */,
 				CDE943BB29C7643E0086A8CA /* thread_safe.cxx */,
@@ -242,6 +259,7 @@
 			isa = PBXGroup;
 			children = (
 				CDD2137C29C76E5600A4582C /* xcode_gtest_helper.h */,
+				CD8EA2A129CF79D200D45186 /* distribution_test.cxx */,
 				CD89E52024E6F424008167A8 /* random_test.cxx */,
 				CDD2138129C76EF500A4582C /* tape_test.cxx */,
 			);
@@ -424,6 +442,7 @@
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
+				CD8EA2A229CF79D200D45186 /* distribution_test.cxx in Sources */,
 				CD89E52124E6F424008167A8 /* random_test.cxx in Sources */,
 				CDD2138229C76EF500A4582C /* tape_test.cxx in Sources */,
 			);
@@ -434,6 +453,7 @@
 			buildActionMask = 2147483647;
 			files = (
 				CDE943BC29C7643E0086A8CA /* thread_safe.cxx in Sources */,
+				CD8EA29E29CF57C500D45186 /* distribution.cxx in Sources */,
 				CDED6A4F21B2F62C00AB91D0 /* random.cxx in Sources */,
 				CDE943B829C75E170086A8CA /* tape.cxx in Sources */,
 			);

+ 42 - 0
src/distribution.cxx

@@ -0,0 +1,42 @@
+//
+//  distribution.cpp
+//  shared_random_generator
+//
+//  Created by Sam Jaffe on 3/25/23.
+//  Copyright © 2023 Sam Jaffe. All rights reserved.
+//
+
+#include "random/distribution.h"
+
+#include <iostream>
+#include <random>
+
+#include "random/device.h"
+
+namespace engine::random {
+template <typename Rand>
+Uniform<Rand>::Uniform(Rand min)
+    : Uniform(min, std::numeric_limits<Rand>::max()) {}
+
+template <typename Rand>
+Uniform<Rand>::Uniform(Rand min, Rand max) : min_(min), max_(max) {}
+
+template <typename Rand>
+auto Uniform<Rand>::operator()(Device & device) const -> result_type {
+  if constexpr (std::is_floating_point_v<Rand>) {
+    return std::uniform_real_distribution<Rand>(min_, max_)(device);
+  } else {
+    return std::uniform_int_distribution<Rand>(min_, max_)(device);
+  }
+}
+
+template <typename Rand> void Uniform<Rand>::describe(std::ostream & os) const {
+  constexpr char close = std::is_floating_point_v<Rand> ? ')' : ']';
+  os << "Uniform[" << min_ << ',' << max_ << close;
+}
+
+template class Uniform<size_t>;
+template class Uniform<double>;
+template class Uniform<int32_t>;
+template class Uniform<uint32_t>;
+}

+ 19 - 48
src/random.cxx

@@ -7,36 +7,26 @@
 #include <scope_guard/scope_guard.hpp>
 
 #include "random/device.h"
+#include "random/distribution.h"
 #include "random/tape.h"
 
-#define EVALUATE(func, ...)                                                    \
-  [&, this]() {                                                                \
-    auto result = impl_->func(__VA_ARGS__);                                    \
-    if (tape_) { tape_->func(__VA_ARGS__, result); }                           \
-    return result;                                                             \
-  }()
-
 namespace engine::random {
 class DefaultDevice : public Device {
 public:
-  int32_t inclusive(int32_t min, int32_t max) override {
-    return std::uniform_int_distribution<>(min, max)(rng);
-  }
+  result_type operator()() final { return rng(); }
 
-  uint32_t inclusive(uint32_t min, uint32_t max) override {
-    return std::uniform_int_distribution<uint32_t>(min, max)(rng);
-  }
+private:
+  std::mt19937 rng{std::random_device{}()};
+};
 
-  double exclusive(double min, double max) override {
-    return uniform_real(min, max)(rng);
-  }
+class Link : public Device {
+public:
+  Link(Tape & tape, Device & to) : tape_(tape), device_(to) {}
+  result_type operator()() { return tape_(device_()); }
 
 private:
-  using engine = std::mt19937;
-  using uniform = std::uniform_int_distribution<typename engine::result_type>;
-  using uniform_real = std::uniform_real_distribution<>;
-
-  engine rng{std::random_device{}()};
+  Tape & tape_;
+  Device & device_;
 };
 
 Random::Random() : impl_(std::make_shared<DefaultDevice>()) {}
@@ -47,33 +37,9 @@ Random::Random(Random const & other, std::shared_ptr<Tape> with_tape)
 Random::Random(Random const & other, Tape & with_tape)
     : impl_(other.impl_), tape_(&with_tape, [](void *) {}) {}
 
-int32_t Random::exclusive(int32_t min, int32_t max) const {
-  return inclusive(min, max - 1);
-}
-
-int32_t Random::inclusive(int32_t min, int32_t max) const {
-  assert(min < max);
-  return EVALUATE(inclusive, min, max);
-}
-
-uint32_t Random::exclusive(uint32_t min, uint32_t max) const {
-  assert(max != 0);
-  return inclusive(min, max - 1);
-}
-
-uint32_t Random::inclusive(uint32_t min, uint32_t max) const {
-  assert(min < max);
-  return EVALUATE(inclusive, min, max);
-}
-
-double Random::exclusive(double min, double max) const {
-  assert(min < max);
-  return EVALUATE(exclusive, min, max);
-}
-
-double Random::inclusive(double min, double max) const {
-  double real_max = std::nextafter(max, std::numeric_limits<double>::max());
-  return std::min(exclusive(min, real_max), max);
+template <typename Rand>
+Rand Random::operator()(Distribution<Rand> const & dist) const {
+  return tape_ ? dist(Link(*tape_, *impl_)) : impl_->apply(dist);
 }
 
 scope_exit Random::record(std::shared_ptr<Tape> tape) {
@@ -81,4 +47,9 @@ scope_exit Random::record(std::shared_ptr<Tape> tape) {
   tape_ = tape;
   return [this]() { tape_.reset(); };
 }
+
+template size_t Random::operator()(Distribution<size_t> const &) const;
+template double Random::operator()(Distribution<double> const &) const;
+template int32_t Random::operator()(Distribution<int32_t> const &) const;
+template uint32_t Random::operator()(Distribution<uint32_t> const &) const;
 }

+ 7 - 54
src/tape.cxx

@@ -9,64 +9,17 @@
 #include "random/tape.h"
 
 #include <expect/expect.hpp>
-
-#define IS_A(T) tname == typeid(T).name()
-
-namespace {
-uint64_t r(uint32_t j) { return j; }
-uint64_t r(int32_t i) { return static_cast<uint64_t>(i); }
-uint64_t r(double d) { return reinterpret_cast<uint64_t &>(d); }
-double r(uint64_t ul) { return reinterpret_cast<double &>(ul); }
-}
+#include <scope_guard/scope_guard.hpp>
 
 namespace engine::random {
-template <typename T>
-Tape::EntryT<T>::EntryT(uint64_t min, uint64_t max, uint64_t result)
-    : min(static_cast<T>(min)), max(static_cast<T>(max)),
-      result(static_cast<T>(result)) {}
-
-Tape::Tape(serial_type serial) {
-  for (auto const & [tname, min, max, result] : serial) {
-    if (IS_A(uint32_t)) {
-      entries_.emplace_back(EntryT<uint32_t>(min, max, result));
-    } else if (IS_A(int32_t)) {
-      entries_.emplace_back(EntryT<int32_t>(min, max, result));
-    } else {
-      entries_.emplace_back(EntryT<double>(r(min), r(max), r(result)));
-    }
-  }
-}
-
-Tape::operator serial_type() const {
-  serial_type rval;
-  auto insert = [&rval](auto const & e) {
-    rval.emplace_back(typeid(e.min).name(), r(e.min), r(e.max), r(e.result));
-  };
-  for (auto const & entry : entries_) {
-    std::visit(insert, entry);
-  }
-  return rval;
-}
-
-void Tape::inclusive(int32_t min, int32_t max, int32_t result) {
-  entries_.emplace_back(EntryT(min, max, result));
-}
-
-void Tape::inclusive(uint32_t min, uint32_t max, uint32_t result) {
-  entries_.emplace_back(EntryT(min, max, result));
-}
-
-void Tape::exclusive(double min, double max, double result) {
-  entries_.emplace_back(EntryT(min, max, result));
+auto Tape::operator()(result_type result) -> result_type {
+  entries_.push_back(result);
+  return result;
 }
 
-template <typename T> T Tape::fetch(T min, T max) {
+auto Tape::operator()() -> result_type {
   expects(index_ < entries_.size(), std::out_of_range, "End of Tape");
-  auto const * entry = std::get_if<EntryT<T>>(&entries_[index_]);
-  expects(entry, std::domain_error, "Mismatched signature to Device on Tape");
-  expects(entry->min == min && entry->max == max,
-          "Mismatched bounds to Device on Tape");
-  ++index_;
-  return entry->result;
+  scope(exit) { ++index_; };
+  return entries_[index_];
 }
 }

+ 2 - 12
src/thread_safe.cxx

@@ -12,18 +12,8 @@ namespace engine::random {
 ThreadSafeDevice::ThreadSafeDevice(std::unique_ptr<Device> impl)
     : impl_(std::move(impl)) {}
 
-int32_t ThreadSafeDevice::inclusive(int32_t min, int32_t max) {
+uint32_t ThreadSafeDevice::operator()() {
   std::lock_guard lock(mutex_);
-  return impl_->exclusive(min, max);
-}
-
-uint32_t ThreadSafeDevice::inclusive(uint32_t min, uint32_t max) {
-  std::lock_guard lock(mutex_);
-  return impl_->exclusive(min, max);
-}
-
-double ThreadSafeDevice::exclusive(double min, double max) {
-  std::lock_guard lock(mutex_);
-  return impl_->exclusive(min, max);
+  return (*impl_)();
 }
 }

+ 74 - 0
test/distribution_test.cxx

@@ -0,0 +1,74 @@
+//
+//  distribution_test.cxx
+//  shared_random_generator-test
+//
+//  Created by Sam Jaffe on 3/25/23.
+//  Copyright © 2023 Sam Jaffe. All rights reserved.
+//
+
+#include <random/distribution.h>
+
+#include <random/random.h>
+
+#include "xcode_gtest_helper.h"
+
+using testing::DoubleNear;
+
+using engine::random::Random;
+using engine::random::Uniform;
+
+class kahan_summation {
+private:
+  double sum_{0.0};
+  double carry_{0.0};
+
+public:
+  kahan_summation() = default;
+  kahan_summation & operator+=(double d);
+
+  explicit operator double() const { return sum_; }
+};
+
+kahan_summation & kahan_summation::operator+=(double num) {
+  double const yield = num - carry_;
+  double const total = sum_ + yield;
+  carry_ = (total - sum_) - yield;
+  sum_ = total;
+  return *this;
+}
+
+TEST(UniformIntTest, UniformOnDefaultRandom) {
+  Random generator{};
+  kahan_summation sum{};
+  size_t const iters = 2000000;
+  for (size_t i = 0; i < iters; ++i) {
+    // All values from 0 - 100 are allowable
+    sum += generator(Uniform(0, 100));
+  }
+  // Expected result is (0+99)/2 +/- 0.1%
+  EXPECT_THAT(double(sum) / iters, DoubleNear(50, 0.05));
+}
+
+TEST(UniformIntTest, IsClosed) {
+  std::stringstream ss;
+  ss << Uniform(0, 10);
+  EXPECT_THAT(ss.str(), "Uniform[0,10]");
+}
+
+TEST(UniformRealTest, UniformOnDefaultRandom) {
+  Random generator{};
+  kahan_summation sum{};
+  size_t const iters = 2000000;
+  for (size_t i = 0; i < iters; ++i) {
+    // All values from 0 - 99 are allowable
+    sum += generator(Uniform(0.0, 100.0));
+  }
+  // Expected result is [0.0, 100.0) +/- 0.1%
+  EXPECT_THAT(double(sum) / iters, DoubleNear(50, 0.05));
+}
+
+TEST(UniformRealTest, IsOpen) {
+  std::stringstream ss;
+  ss << Uniform(0.0, 10.0);
+  EXPECT_THAT(ss.str(), "Uniform[0,10)");
+}

+ 18 - 85
test/random_test.cxx

@@ -8,112 +8,45 @@
 
 #include "random/random.h"
 
-#include "random/device.h"
+#include "random/distribution.h"
 
+#include "random/mock/mock_device.h"
 #include "xcode_gtest_helper.h"
 
-using testing::DoubleNear;
+using testing::_;
 
-struct MockDevice : engine::random::Device {
-  MOCK_METHOD2(s_inclusive, int32_t(int32_t, int32_t));
-  MOCK_METHOD2(u_inclusive, uint32_t(uint32_t, uint32_t));
-  MOCK_METHOD2(exclusive, double(double, double));
+using engine::random::Index;
+using engine::random::MockDevice;
+using engine::random::Uniform;
 
-  int32_t inclusive(int32_t min, int32_t max) { return s_inclusive(min, max); }
-  uint32_t inclusive(uint32_t min, uint32_t max) {
-    return u_inclusive(min, max);
-  }
-};
-
-TEST(RandomTest, ExclusiveIntegerPassedAsInclusive) {
+TEST(RandomTest, UsesInt32Type) {
   auto mock = std::make_shared<MockDevice>();
   engine::random::Random generator{mock};
 
-  EXPECT_CALL(*mock, s_inclusive(-10, 9)).Times(1);
-  generator.exclusive(-10, 10);
+  EXPECT_CALL(*mock, apply_int32_t(_)).Times(1);
+  generator(Uniform(-10, 10));
 }
 
-TEST(RandomTest, ExclusiveUIntegerPassedAsInclusive) {
+TEST(RandomTest, UsesUInt32Type) {
   auto mock = std::make_shared<MockDevice>();
   engine::random::Random generator{mock};
 
-  EXPECT_CALL(*mock, u_inclusive(0, 9)).Times(1);
-  generator.exclusive(0u, 10u);
+  EXPECT_CALL(*mock, apply_uint32_t(_)).Times(1);
+  generator(Uniform(0u, 10u));
 }
 
-TEST(RandomTest, PassesThroughExclusiveDoubleCall) {
+TEST(RandomTest, UsesSizeType) {
   auto mock = std::make_shared<MockDevice>();
   engine::random::Random generator{mock};
 
-  EXPECT_CALL(*mock, exclusive(1.0, 10.0)).Times(1);
-  generator.exclusive(1.0, 10.0);
+  EXPECT_CALL(*mock, apply_size_t(_)).Times(1);
+  generator(Index(10));
 }
 
-TEST(RandomTest, DoctorsInclusiveDoubleCall) {
+TEST(RandomTest, UsesDoubleType) {
   auto mock = std::make_shared<MockDevice>();
   engine::random::Random generator{mock};
 
-  EXPECT_CALL(*mock, exclusive(1.0, DoubleNear(10.0, 1e-7))).Times(1);
-  generator.inclusive(1.0, 10.0);
-}
-
-class kahan_summation {
-private:
-  double sum_{0.0};
-  double carry_{0.0};
-
-public:
-  kahan_summation() = default;
-  kahan_summation & operator+=(double d);
-
-  explicit operator double() const { return sum_; }
-};
-
-kahan_summation & kahan_summation::operator+=(double num) {
-  double const yield = num - carry_;
-  double const total = sum_ + yield;
-  carry_ = (total - sum_) - yield;
-  sum_ = total;
-  return *this;
-}
-
-TEST(DefaultRandomTest, RandomDistributionIntIsUniform) {
-  engine::random::Random generator{};
-  kahan_summation sum{};
-  size_t const iters = 2000000;
-  for (size_t i = 0; i < iters; ++i) {
-    // All values from 0 - 100 are allowable
-    sum += generator.inclusive(0, 100);
-  }
-  // Expected result is (0+99)/2 +/- 0.1%
-  EXPECT_THAT(double(sum) / iters, DoubleNear(50, 0.05));
-}
-
-TEST(DefaultRandomTest, RandomDistributionDblIsUniform) {
-  engine::random::Random generator{};
-  kahan_summation sum{};
-  size_t const iters = 2000000;
-  for (size_t i = 0; i < iters; ++i) {
-    // All values from 0 - 99 are allowable
-    sum += generator.exclusive(0.0, 100.0);
-  }
-  // Expected result is [0.0, 100.0) +/- 0.1%
-  EXPECT_THAT(double(sum) / iters, DoubleNear(50, 0.05));
-}
-
-TEST(DefaultRandomTest, InclusiveRangeMayIncludeValue) {
-  engine::random::Random generator{};
-  EXPECT_THAT(generator.inclusive(1.0, 1.0), DoubleNear(1.0, 1E-6));
-}
-
-TEST(DefaultRandomTest, RandomDistributionDblInclIsUniform) {
-  engine::random::Random generator{};
-  kahan_summation sum{};
-  size_t const iters = 2000000;
-  for (size_t i = 0; i < iters; ++i) {
-    // All values in [0.0, 100.0] are allowable
-    sum += generator.inclusive(0.0, 100.0);
-  }
-  // Expected result is (0+100)/2 +/- 0.1%
-  EXPECT_THAT(double(sum) / iters, DoubleNear(50, 0.05));
+  EXPECT_CALL(*mock, apply_double(_)).Times(1);
+  generator(Uniform(0.0, 10.0));
 }

+ 40 - 74
test/tape_test.cxx

@@ -19,104 +19,69 @@
 using testing::ElementsAre;
 using testing::FieldsAre;
 
-MATCHER_P(BytesAre, value, "") {
-  *result_listener << "whose bytes are " << std::hex << arg;
-  return reinterpret_cast<double const &>(arg) == value;
-}
-
-TEST(TapeTest, CanRecordIntData) {
-  engine::random::Tape tape;
+class StubDistribution : public engine::random::Distribution<int> {
+public:
+  StubDistribution(int value) : value_(value) {}
+  result_type operator()(engine::random::Device &) const { return value_; }
 
-  tape.inclusive(0u, 100u, 50u);
-  EXPECT_EQ(tape.inclusive(0u, 100u), 50);
-}
+private:
+  int value_;
+};
 
-TEST(TapeTest, CanRecordDoubleData) {
+TEST(TapeTest, CanRecord) {
   engine::random::Tape tape;
 
-  tape.exclusive(0.0, 100.0, 50.0);
-  EXPECT_EQ(tape.exclusive(0.0, 100.0), 50.0);
-
-  tape.exclusive(0.0, 100.0, 50.0);
-  EXPECT_EQ(tape.exclusive(0.0, 100.0), 50.0);
+  tape(50u);
+  EXPECT_EQ(tape(), 50);
 }
 
 TEST(TapeTest, ThrowsOnOverFetch) {
   engine::random::Tape tape;
-  tape.inclusive(0, 100, 50);
-  EXPECT_EQ(tape.inclusive(0, 100), 50);
-  EXPECT_THROW(tape.inclusive(0, 100), std::out_of_range);
-}
-
-TEST(TapeTest, BadRequestIsNoOp) {
-  engine::random::Tape tape;
-  tape.inclusive(0, 100, 50);
-  EXPECT_ANY_THROW(tape.inclusive(0, 1));
-  EXPECT_EQ(tape.inclusive(0, 100), 50);
-}
-
-TEST(TapeTest, ThrowsOnTypeMismatch) {
-  engine::random::Tape tape;
-  tape.inclusive(0u, 100u, 50u);
-  EXPECT_THROW(tape.exclusive(0.0, 100.0), std::domain_error);
-  EXPECT_THROW(tape.inclusive(0, 100), std::domain_error);
-  EXPECT_NO_THROW(tape.inclusive(0u, 100u));
-}
-
-TEST(TapeTest, ThrowsOnBoundsMismatch) {
-  engine::random::Tape tape;
-  tape.exclusive(0.0, 100.0, 50.0);
-  EXPECT_THROW(tape.exclusive(0.0, 1.0), std::logic_error);
+  tape(50);
+  EXPECT_EQ(tape(), 50);
+  EXPECT_THROW(tape(), std::out_of_range);
 }
 
 TEST(TapeTest, IsOrdered) {
   engine::random::Tape tape;
-  tape.exclusive(0.0, 100.0, 50.0);
-  tape.exclusive(10.0, 90.0, 50.0);
-
-  EXPECT_THROW(tape.exclusive(10.0, 90.0), std::logic_error);
+  tape(1);
+  tape(2);
 
-  EXPECT_EQ(tape.exclusive(0.0, 100.0), 50.0);
-  EXPECT_EQ(tape.exclusive(10.0, 90.0), 50.0);
+  EXPECT_EQ(tape(), 1);
+  EXPECT_EQ(tape(), 2);
 }
 
 TEST(TapeTest, IsSerializable) {
   engine::random::Tape tape;
-  tape.inclusive(0, 100, 50);
-  tape.inclusive(0u, 100u, 50u);
-  tape.exclusive(0.0, 100.0, 50.0);
+  tape(1);
+  tape(2);
+  tape(3);
 
   auto serial = engine::random::Tape::serial_type(tape);
-  EXPECT_THAT(
-      serial,
-      ElementsAre(FieldsAre("i", 0, 100, 50), FieldsAre("j", 0, 100, 50),
-                  FieldsAre("d", BytesAre(0), BytesAre(100), BytesAre(50))));
+  EXPECT_THAT(serial, ElementsAre(1, 2, 3));
 
   tape = engine::random::Tape(serial);
-  EXPECT_EQ(tape.inclusive(0, 100), 50);
-  EXPECT_EQ(tape.inclusive(0u, 100u), 50u);
-  EXPECT_EQ(tape.exclusive(0.0, 100.0), 50.0);
+  EXPECT_EQ(tape(), 1);
+  EXPECT_EQ(tape(), 2);
+  EXPECT_EQ(tape(), 3);
 }
 
 TEST(TapeTest, IndexNotIncludedInSerialization) {
   engine::random::Tape tape;
-  tape.inclusive(0, 100, 50);
-  tape.inclusive(0u, 100u, 50u);
-  tape.exclusive(0.0, 100.0, 50.0);
+  tape(1);
+  tape(2);
+  tape(3);
 
-  EXPECT_EQ(tape.inclusive(0, 100), 50);
-  EXPECT_EQ(tape.inclusive(0u, 100u), 50u);
+  EXPECT_EQ(tape(), 1);
+  EXPECT_EQ(tape(), 2);
 
   auto serial = engine::random::Tape::serial_type(tape);
-  EXPECT_THAT(
-      serial,
-      ElementsAre(FieldsAre("i", 0, 100, 50), FieldsAre("j", 0, 100, 50),
-                  FieldsAre("d", BytesAre(0), BytesAre(100), BytesAre(50))));
+  EXPECT_THAT(serial, ElementsAre(1, 2, 3));
 
   tape = engine::random::Tape(serial);
-  EXPECT_EQ(tape.inclusive(0, 100), 50);
-  EXPECT_EQ(tape.inclusive(0u, 100u), 50u);
-  EXPECT_EQ(tape.exclusive(0.0, 100.0), 50.0);
+  EXPECT_EQ(tape(), 1);
+  EXPECT_EQ(tape(), 2);
+  EXPECT_EQ(tape(), 3);
 }
 
 TEST(TapeTest, CanAttachToRandom) {
@@ -125,8 +90,9 @@ TEST(TapeTest, CanAttachToRandom) {
 
   auto scope = random.record(tape);
 
-  auto result = random.inclusive(0, 100);
-  EXPECT_EQ(tape->inclusive(0, 100), result);
+  auto result = random(StubDistribution(5));
+  EXPECT_EQ(result, 5);
+  EXPECT_EQ((*tape)(), 5);
 }
 
 TEST(TapeTest, StopsRecordingOnScopeExit) {
@@ -135,10 +101,10 @@ TEST(TapeTest, StopsRecordingOnScopeExit) {
 
   {
     auto scope = random.record(tape);
-    random.inclusive(0, 100);
+    random(StubDistribution(2));
   }
-  random.inclusive(0, 100);
+  random(StubDistribution(1));
 
-  EXPECT_NO_THROW(tape->inclusive(0, 100));
-  EXPECT_THROW(tape->inclusive(0, 100), std::out_of_range);
+  EXPECT_NO_THROW((*tape)());
+  EXPECT_THROW((*tape)(), std::out_of_range);
 }