瀏覽代碼

feat: add Boolean distribution

Sam Jaffe 2 年之前
父節點
當前提交
7e1868948a
共有 2 個文件被更改,包括 8 次插入0 次删除
  1. 6 0
      include/random/distribution.h
  2. 2 0
      src/distribution.cxx

+ 6 - 0
include/random/distribution.h

@@ -54,4 +54,10 @@ class Index : public Uniform<size_t> {
 public:
   explicit Index(size_t size) : Uniform(0, size - 1) {}
 };
+
+class Boolean : public Uniform<int> {
+public:
+  Boolean() : Uniform(0, 1) {}
+  void describe(std::ostream &os) const override;
+};
 }

+ 2 - 0
src/distribution.cxx

@@ -46,5 +46,7 @@ template <typename Rand> void Uniform<Rand>::describe(std::ostream & os) const {
   os << typeid(Rand).name() << "Uniform[" << min_ << ',' << max_ << close;
 }
 
+void Boolean::describe(std::ostream &os) const { os << "Boolean"; }
+
 IMPLEMENT_DEVICE(INSTANTIATE_DISTRIBUTION_TEMPLATES)
 }