ソースを参照

feat: implement proper UTF8 matching

Sam Jaffe 3 ヶ月 前
コミット
504147d3ed
6 ファイル変更145 行追加8 行削除
  1. 27 0
      include/abnf/code_point.h
  2. 5 3
      include/abnf/grammar.h
  3. 103 0
      src/code_point.cxx
  4. 4 1
      src/grammar.cxx
  5. 4 2
      src/io.cxx
  6. 2 2
      src/parser.cxx

+ 27 - 0
include/abnf/code_point.h

@@ -0,0 +1,27 @@
+#pragma once
+
+#include <string>
+
+#include <abnf/forward.h>
+
+namespace abnf {
+/// A single Unicode code point held as a plain integer value.
+///
+/// Values compare by their numeric value (defaulted == and <=>).
+class code_point {
+public:
+  /// Wraps a raw numeric value. No range validation is performed here.
+  code_point(int value = 0) : value_(value) {}
+  /// Builds a code point from the leading bytes of `str`, read as UTF-8.
+  explicit code_point(std::string_view str);
+
+  /// Number of bytes (1-4) needed to encode this code point as UTF-8.
+  size_t width() const;
+
+  /// The raw numeric value.
+  explicit operator int() const { return value_; }
+
+  /// The UTF-8 encoded form of this code point.
+  /// @throws std::domain_error if the value is not a legal code point.
+  operator std::string() const;
+
+private:
+  // parse_char_range writes value_ directly through std::from_chars.
+  friend char_range parse_char_range(std::string_view token);
+  friend bool operator==(code_point const &, code_point const &) = default;
+  friend auto operator<=>(code_point const &, code_point const &) = default;
+
+private:
+  // Default-initialized so that a constructor without a mem-initializer never
+  // reads an indeterminate value.
+  int value_{0}; // 0x00000000 - 0x0010FFFF
+};
+}

+ 5 - 3
include/abnf/grammar.h

@@ -1,11 +1,13 @@
 #pragma once
 
+#include "abnf/code_point.h"
 #include <initializer_list>
 #include <limits>
 #include <map>
 #include <string>
 #include <vector>
 
+#include <abnf/code_point.h>
 #include <abnf/detail/iless.h>
 #include <abnf/forward.h>
 
@@ -38,15 +40,15 @@ struct reference {
   std::string value;
 };
 
-struct char_range { // TODO: UTF8/Codepoint handling
+struct char_range { // inclusive span of code points, first through last
   char_range() = default;
   char_range(int val) : first(val), last(val) {} // range holding the single value val
   char_range(int first, int last) : first(first), last(last) {} // range from first to last
 
   friend bool operator==(char_range const &, char_range const &) = default;
 
-  int first;
-  int last;
+  code_point first; // lower bound
+  code_point last; // upper bound
 };
 
 struct repeated {

+ 103 - 0
src/code_point.cxx

@@ -0,0 +1,103 @@
+#include <abnf/code_point.h>
+
+#include <algorithm>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <stdexcept>
+#include <string>
+
+namespace {
+struct bytes { // bit-field overlay applied to code_point::value_ through reinterpret_cast
+  constexpr size_t ordinal() const { return width() - 1; } // zero-based width; used as the index into g_bits
+  constexpr size_t width() const { return b4 ? 4 : (b3 ? 3 : (b2 ? 2 : 1)); } // 4 if b4 is non-zero, else 3 if b3, else 2 if b2, else 1
+
+  union { // NOTE(review): anonymous structs are a compiler extension and bit-field allocation order is implementation-defined; confirm on the target ABI
+    struct { // 8 reserved bits followed by six 4-bit fields
+      uint32_t reserved : 8; // operator std::string() throws when this is non-zero
+      uint32_t u : 4; // presumably the most significant nibble of the code point; verify layout
+      uint32_t v : 4;
+      uint32_t w : 4;
+      uint32_t x : 4;
+      uint32_t y : 4;
+      uint32_t z : 4; // presumably the least significant nibble; verify layout
+    };
+    struct {
+      uint32_t b2 : 25; // presumably the bits above the low 7, so non-zero means at least 2 bytes; TODO confirm
+      uint32_t _2 : 7; // remainder of the word, not read anywhere in this file
+    };
+    struct {
+      uint32_t b3 : 21; // presumably the bits above the low 11, so non-zero means at least 3 bytes; TODO confirm
+      uint32_t _3 : 11; // remainder of the word, not read anywhere in this file
+    };
+    struct {
+      uint32_t b4 : 16; // presumably the bits above the low 16, so non-zero means 4 bytes; TODO confirm
+      uint32_t _4 : 16; // remainder of the word, not read anywhere in this file
+    };
+  };
+};
+
+struct utf8_bits { // one row of the g_bits table: the layout of a UTF-8 sequence of a given width
+  uint32_t filter; // marker bits of the lead and continuation bytes, packed into one 32-bit word
+  uint32_t shift; // low-order bits of that word left unused by this width (24, 16, 8 or 0 in g_bits)
+  uint8_t bits[4]; // number of payload bits carried by each byte of the sequence
+};
+
+constexpr size_t MAX_WIDTH = 4UL; // longest UTF-8 sequence handled, in bytes; sizes g_bits and the decode buffer
+}
+
+// Helper macros for code_point::operator std::string(). Both expect a local
+// object named `data` whose member X is a 4-bit field.
+//
+// SHIFT moves the whole nibble to bit position BY.
+// SPLIT_SHIFT moves the low two bits to position BY and the high two bits two
+// places further up, leaving room for the `10` continuation marker that sits
+// between them in the encoded form. The fields are only 4 bits wide, so the
+// halves are selected with 0xC / 0x3; masking with 0xF0 would always be zero.
+#define SHIFT(X, BY) (data.X << (BY))
+#define SPLIT_SHIFT(X, BY)                                                     \
+  (((data.X & 0xC) << ((BY) + 2)) | ((data.X & 0x3) << (BY)))
+
+// UTF-8 layout table, one row per encoded width, indexed by (width - 1).
+// Each row is {filter, shift, bits}: the marker bits packed into a 32-bit
+// word, the count of unused low-order bits, and the payload bits per byte.
+static constexpr std::array<utf8_bits, MAX_WIDTH> g_bits{{
+    // 0xxxxxxx
+    {0x00000000U, 24U, {7, 0, 0, 0}},
+    // 110xxxxx 10xxxxxx
+    {0xC0800000U, 16U, {5, 6, 0, 0}},
+    // 1110xxxx 10xxxxxx 10xxxxxx
+    {0xE0808000U, 8U, {4, 6, 6, 0}},
+    // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+    {0xF0808080U, 0U, {3, 6, 6, 6}},
+}};
+
+namespace abnf {
+/// Decodes the UTF-8 sequence at the start of `str`.
+///
+/// Only the first sequence is read; any bytes after it are ignored. If `str`
+/// is empty, truncated, or does not start with a well-formed sequence
+/// (bad lead or continuation byte, overlong form, surrogate, or a value above
+/// 0x10FFFF), the code point is set to -1, which lies outside every valid
+/// range and therefore compares unequal to all real code points.
+///
+/// The bytes are examined one at a time, so the result does not depend on the
+/// host's byte order or alignment rules.
+code_point::code_point(std::string_view str) : value_(-1) {
+  if (str.empty()) { return; }
+
+  auto const byte_at = [&str](size_t i) {
+    return static_cast<unsigned char>(str[i]);
+  };
+
+  // Classify the lead byte: sequence length, its payload bits, and the
+  // smallest value that legitimately needs that length (rejects overlongs).
+  unsigned char const lead = byte_at(0);
+  size_t length = 0;
+  int decoded = 0;
+  int minimum = 0;
+  if (lead < 0x80) {
+    length = 1;
+    decoded = lead;
+    minimum = 0x0;
+  } else if ((lead & 0xE0) == 0xC0) {
+    length = 2;
+    decoded = lead & 0x1F;
+    minimum = 0x80;
+  } else if ((lead & 0xF0) == 0xE0) {
+    length = 3;
+    decoded = lead & 0x0F;
+    minimum = 0x800;
+  } else if ((lead & 0xF8) == 0xF0) {
+    length = 4;
+    decoded = lead & 0x07;
+    minimum = 0x10000;
+  } else {
+    return; // stray continuation byte or invalid lead byte
+  }
+
+  if (str.size() < length) { return; } // truncated sequence
+
+  // Every following byte must look like 10xxxxxx and contributes 6 bits.
+  for (size_t i = 1; i < length; ++i) {
+    unsigned char const next = byte_at(i);
+    if ((next & 0xC0) != 0x80) { return; }
+    decoded = (decoded << 6) | (next & 0x3F);
+  }
+
+  bool const overlong = decoded < minimum;
+  bool const surrogate = decoded >= 0xD800 && decoded <= 0xDFFF;
+  if (overlong || surrogate || decoded > 0x10FFFF) { return; }
+
+  value_ = decoded;
+}
+
+/// Number of bytes needed to encode this code point as UTF-8.
+///
+/// @return 1 for values below 0x80, 2 below 0x800, 3 below 0x10000, else 4.
+size_t code_point::width() const {
+  // Compare as unsigned so that a negative value lands in the widest bucket
+  // instead of being mistaken for ASCII. Plain range checks avoid relying on
+  // implementation-defined bit-field layout.
+  auto const cp = static_cast<uint32_t>(value_);
+  if (cp < 0x80) { return 1; }
+  if (cp < 0x800) { return 2; }
+  if (cp < 0x10000) { return 3; }
+  return 4;
+}
+
+/// Encodes this code point as a UTF-8 byte string of 1 to 4 bytes.
+///
+/// @return the encoded bytes, lead byte first.
+/// @throws std::domain_error if the value is negative or above 0x10FFFF.
+code_point::operator std::string() const {
+  // Negative values become huge when viewed as unsigned, so one comparison
+  // rejects everything outside 0 - 0x10FFFF.
+  auto const cp = static_cast<uint32_t>(value_);
+  if (cp > 0x10FFFF) {
+    throw std::domain_error("Illegal Codepoint (>0x10FFFF)");
+  }
+
+  // Lead byte: width marker plus the topmost payload bits.
+  auto const lead = [cp](uint32_t marker, unsigned shift) {
+    return static_cast<char>(marker | (cp >> shift));
+  };
+  // Continuation byte: 10xxxxxx carrying six payload bits.
+  auto const cont = [cp](unsigned shift) {
+    return static_cast<char>(0x80 | ((cp >> shift) & 0x3F));
+  };
+
+  // Bytes are appended in sequence order, independent of host byte order.
+  if (cp < 0x80) { return std::string(1, static_cast<char>(cp)); }
+  if (cp < 0x800) { return std::string{lead(0xC0, 6), cont(0)}; }
+  if (cp < 0x10000) { return std::string{lead(0xE0, 12), cont(6), cont(0)}; }
+  return std::string{lead(0xF0, 18), cont(12), cont(6), cont(0)};
+}
+}

+ 4 - 1
src/grammar.cxx

@@ -1,3 +1,4 @@
+#include "abnf/code_point.h"
 #include <abnf/grammar.h>
 
 #include <initializer_list>
@@ -71,7 +72,9 @@ auto grammar::satisfies(std::string_view text, literal const & lit) const
 template <> // specialization of satisfies for char_range terminals
 auto grammar::satisfies(std::string_view text, char_range const & rng) const
     -> satisfies_result {
-  if (text[0] >= rng.first && text[0] <= rng.last) { return 1UL; }
+  if (code_point cp(text); cp >= rng.first && cp <= rng.last) { // build a code point from text and test it against the inclusive range
+    return cp.width(); // match: the result is the code point's width
+  }
   return false; // the code point lies outside the range
 }
 

+ 4 - 2
src/io.cxx

@@ -11,8 +11,10 @@ std::ostream & operator<<(std::ostream & os, rule const & rule) {
       [&os](literal const & l) { os << '"' << l.value << '"'; },
       [&os](reference const & r) { os << '<' << r.value << '>'; },
       [&os](char_range const & r) {
-        os << "%x" << std::hex << r.first << std::dec;
-        if (r.first != r.last) { os << '-' << std::hex << r.last << std::dec; }
+        os << "%x" << std::hex << int(r.first) << std::dec;
+        if (r.first != r.last) {
+          os << '-' << std::hex << int(r.last) << std::dec;
+        }
       },
       [&os](repeated const & r) {
         if (r.min == 0 && r.max == 1) {

+ 2 - 2
src/parser.cxx

@@ -53,9 +53,9 @@ char_range parse_char_range(std::string_view token) {
   char_range rval;
   token.remove_prefix(2);
   char const * const last = token.end();
-  auto [end, ec] = std::from_chars(token.data(), last, rval.first, 16);
+  auto [end, ec] = std::from_chars(token.data(), last, rval.first.value_, 16);
   if (*end == '-') {
-    ec = std::from_chars(end + 1, last, rval.last, 16).ec;
+    ec = std::from_chars(end + 1, last, rval.last.value_, 16).ec;
   } else {
     rval.last = rval.first;
   }