trie_iterator.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. //
  2. // trie_iterator.hpp
  3. // trie
  4. //
  5. // Created by Sam Jaffe on 6/16/17.
  6. //
  7. #pragma once
  8. #include "trie.hpp"
  9. #include <stack>
  10. #include <vector>
  11. #include "iterator/end_aware_iterator.hpp"
  12. namespace detail {
  13. template <typename Iter> struct trie_iterator_next {
  14. using iter_t = ::iterator::end_aware_iterator<Iter>;
  15. template <typename Trie> iter_t operator()(Trie & tr) {
  16. return {tr.local_begin(), tr.local_end()};
  17. }
  18. };
  19. template <typename Iter>
  20. struct trie_iterator_next<std::reverse_iterator<Iter>> {
  21. using iter_t = ::iterator::end_aware_iterator<std::reverse_iterator<Iter>>;
  22. template <typename Trie> iter_t operator()(Trie & tr) {
  23. return {tr.local_rbegin(), tr.local_rend()};
  24. }
  25. };
  26. template <typename Trie, typename Iter> class trie_iterator_base {
  27. public:
  28. using reference = decltype(std::declval<Trie>().value());
  29. using value_type = std::remove_reference_t<reference>;
  30. using pointer = value_type *;
  31. using difference_type = std::ptrdiff_t;
  32. using iterator_category = std::forward_iterator_tag;
  33. private:
  34. std::stack<Trie *> parent_trie_;
  35. std::vector<typename Trie::key_type> keys_;
  36. std::stack<iterator::end_aware_iterator<Iter>> iterators_;
  37. bool done_{false};
  38. public:
  39. trie_iterator_base() : done_{true} {}
  40. trie_iterator_base(Trie * tr) { parent_trie_.push(tr); }
  41. auto operator*() const -> reference { return root(); }
  42. auto operator-> () const -> pointer { return std::addressof(operator*()); }
  43. Trie & root() const {
  44. if (parent_trie_.empty()) {
  45. throw std::runtime_error("Dereferencing invalid iterator");
  46. }
  47. return *parent_trie_.top();
  48. }
  49. trie_iterator_base parent() const {
  50. trie_iterator_base tmp{*this};
  51. tmp.pop();
  52. return tmp;
  53. }
  54. protected:
  55. iterator::end_aware_iterator<Iter> & current() { return iterators_.top(); }
  56. bool empty() const { return iterators_.empty(); }
  57. void done(bool new_done) { done_ = new_done; }
  58. bool done() const { return done_; }
  59. void push(iterator::end_aware_iterator<Iter> it) {
  60. if (it.done()) {
  61. done_ = true;
  62. return;
  63. }
  64. iterators_.push(it);
  65. keys_.push_back(it->first);
  66. parent_trie_.push(it->second.get());
  67. }
  68. void pop() {
  69. parent_trie_.pop();
  70. iterators_.pop();
  71. keys_.pop_back();
  72. }
  73. bool poping_empty() {
  74. if (!current().done()) { return false; }
  75. pop();
  76. return !empty();
  77. }
  78. void assign() {
  79. if (empty()) { return; }
  80. parent_trie_.top() = current()->second.get();
  81. keys_.back() = current()->first;
  82. }
  83. bool can_recurse() { return !next().done(); }
  84. void recurse() { push(next()); }
  85. iterator::end_aware_iterator<Iter> next() {
  86. return trie_iterator_next<Iter>()(root());
  87. }
  88. friend bool operator!=(trie_iterator_base const & lhs,
  89. trie_iterator_base const & rhs) {
  90. return !(lhs == rhs);
  91. }
  92. friend bool operator==(trie_iterator_base const & lhs,
  93. trie_iterator_base const & rhs) {
  94. return (lhs.done_ && rhs.done_) ||
  95. (lhs.keys_ == rhs.keys_ && lhs.done_ == rhs.done_ && *lhs == *rhs);
  96. }
  97. friend Trie;
  98. template <typename Self, typename _Trie, typename KS>
  99. friend Self find_impl(_Trie * tr, KS const & keys);
  100. };
  101. }
  102. template <typename Trie, typename Iter>
  103. class trie_iterator : public detail::trie_iterator_base<Trie, Iter> {
  104. private:
  105. using super = detail::trie_iterator_base<Trie, Iter>;
  106. public:
  107. trie_iterator() : super() {}
  108. trie_iterator(Trie * tr) : super(tr) {}
  109. trie_iterator(super && take) : super(std::move(take)) {}
  110. trie_iterator & operator++() {
  111. if (super::done()) return *this;
  112. if (super::empty() || super::can_recurse()) {
  113. super::recurse();
  114. } else {
  115. advance();
  116. }
  117. return *this;
  118. }
  119. private:
  120. void advance() {
  121. ++super::current();
  122. while (super::poping_empty()) {
  123. ++super::current();
  124. }
  125. super::assign();
  126. super::done(super::empty());
  127. }
  128. };
  129. template <typename Trie, typename Iter>
  130. class trie_reverse_iterator : public detail::trie_iterator_base<Trie, Iter> {
  131. private:
  132. using super = detail::trie_iterator_base<Trie, Iter>;
  133. public:
  134. trie_reverse_iterator() : super() {}
  135. trie_reverse_iterator(Trie * tr) : super(tr) {
  136. if (super::can_recurse()) recurse();
  137. }
  138. trie_reverse_iterator & operator++() {
  139. if (super::done() || super::empty()) {
  140. super::done(true);
  141. } else {
  142. advance();
  143. }
  144. return *this;
  145. }
  146. private:
  147. void recurse() {
  148. do {
  149. super::recurse();
  150. } while (super::can_recurse());
  151. }
  152. void advance() {
  153. ++super::current();
  154. if (super::current().done()) {
  155. while (super::poping_empty())
  156. ;
  157. super::assign();
  158. } else {
  159. super::assign();
  160. if (super::can_recurse()) recurse();
  161. }
  162. }
  163. };