Ver código fonte

feat: add the ability to stack Tapes on top of one another
refactor: partially seal Tape

Sam Jaffe 2 anos atrás
pai
commit
dd6a208b1d
3 arquivos alterados com 6 adições e 1 exclusões
  1. 1 0
      include/random/random.h
  2. 4 1
      include/random/tape.h
  3. 1 0
      src/tape.cxx

+ 1 - 0
include/random/random.h

@@ -28,5 +28,6 @@ public:
    * @return A scope object that stops recording once it is descructed
    */
   scope_exit record(std::shared_ptr<Tape> tape);
+  std::shared_ptr<Tape> tape() const { return tape_; }
 };
 }

+ 4 - 1
include/random/tape.h

@@ -22,19 +22,22 @@ class Tape : public Device {
 public:
   // Allow for the serialization and de-serialization of this Tape
   using serial_type = std::vector<std::pair<std::string, uint64_t>>;
+  friend class Random;
 
 private:
   size_t index_{0}; // transient
   serial_type entries_;
+  std::shared_ptr<Tape> child_{nullptr};
 
 public:
   Tape() = default;
+  Tape(std::shared_ptr<Tape> child) : child_(child) {}
   Tape(serial_type serial) : entries_(std::move(serial)) {}
   explicit operator serial_type() const { return entries_; }
 
+private:
   IMPLEMENT_DEVICE(DECLARE_TAPE)
 
-private:
   result_type operator()() final { throw; }
   template <typename T> T poll(std::string const & dist);
   template <typename T> T operator()(Distribution<T> const & dist);

+ 1 - 0
src/tape.cxx

@@ -15,6 +15,7 @@
   T Tape::apply(Distribution<T> const & dist) { return (*this)(dist); }        \
   T Tape::apply(Distribution<T> const & dist, T value) {                       \
     entries_.emplace_back(to_string(dist), ByteSerial<T>()(value));            \
+    if (child_) child_->apply(dist, value);                                    \
     return value;                                                              \
   }