Browse Source

feat: use bindings to implement type coercion automatically

Sam Jaffe 3 years ago
parent
commit
8926adb5ad

+ 21 - 1
include/reflection/forward.h

@@ -12,14 +12,17 @@
 #include <map>
 #include <string>
 #include <type_traits>
+#include <typeindex>
+#include <unordered_map>
 
 namespace reflection {
 class Object;
 template <typename Obj, typename = void> class Reflection;
 template <typename T> class Proxy;
-template <typename O, typename I> struct TypeConversion;
+template <typename T> class TypeCast;
 
 template <typename Func> using Cache = std::map<std::string_view, Func>;
+template <typename T> using TypeMap = std::unordered_map<std::type_index, T>;
 template <typename Obj>
 using Accessor = std::function<Object(Obj &, std::string)>;
 template <typename Obj>
@@ -27,3 +30,20 @@ using Getter = std::function<Object(Obj const &, std::string)>;
 }
 
 #define reflect(object) reflection::Object(object, #object)
+
+#define INSTANTIATE_REFLECTION(T)                                              \
+  template Cache<Accessor<T>> Reflection<T>::members_;                         \
+  template Cache<Getter<T>> Reflection<T>::const_members_
+
+#define INSTANTIATE_REFLECTION_TYPECAST(T)                                     \
+  template TypeMap<std::function<T(Object const &)>> TypeCast<T>::get_;        \
+  template TypeMap<std::function<void(Object &, T const &)>> TypeCast<T>::set_
+
+#define CONCAT_IMPL(A, B) A##B
+#define CONCAT(A, B) CONCAT_IMPL(A, B)
+
+#define REFLECTION(T)                                                          \
+  bool const CONCAT(reflection_, __LINE__) = ::reflection::Reflection<T>()
+
+#define REFLECTION_TYPE_CAST(T)                                                \
+  bool const CONCAT(type_cast_, __LINE__) = ::reflection::TypeCast<T>()

+ 37 - 16
include/reflection/object.h

@@ -22,13 +22,15 @@ public:
   template <typename T> Object(T && data, std::string name = "this");
   template <typename T> Object(Proxy<T> data, std::string name = "this");
 
+  template <typename V> Object & operator=(V const & value);
+
   Object own() const { return clone_(*this); }
 
   std::string_view name() const {
     return std::string_view(name_).substr(name_.rfind('.') + 1);
   }
   std::string_view path() const { return name_; }
-  char const * type() const { return type_.name(); }
+  std::type_index type() const { return type_; }
 
   template <typename T> bool is_a() const { return type_ == typeid(T); }
 
@@ -36,8 +38,10 @@ public:
   template <typename T> operator T const &() const & {
     return cast<T const &>();
   }
+  template <typename T> operator T() const & { return cast<T>(); }
+  template <typename T> operator T() & { return cast<T>(); }
   template <typename T> operator T() && {
-    return is_a<T>() ? std::move(cast<T &>()) : cast<T const &>();
+    return const_ ? cast<T>() : std::move(cast<T &>());
   }
 
   Object get(std::string_view id) & { return (this->*get_)(id); }
@@ -59,18 +63,7 @@ private:
     return o;
   }
 
-  template <typename T> T cast() const {
-    // Why can we decay this away?
-    using V = std::remove_const_t<std::remove_reference_t<T>>;
-    // 1) typeid does not respect const
-    if (!is_a<V>()) { throw std::bad_cast(); }
-    // 2) We guard against using mutable-reference on immutable Object here
-    if (std::is_same_v<T, V &> && const_) { throw std::bad_cast(); }
-    // 3) Proxy always contains mutable data
-    if (proxy_) { return T(*std::static_pointer_cast<Proxy<V>>(data_)); }
-    // 4) The const will be re-added by type coercion
-    return *std::static_pointer_cast<V>(data_);
-  }
+  template <typename T> T cast() const;
 
   template <typename T> Object getter(std::string_view id) const;
   template <typename T> Object accessor(std::string_view id) const;
@@ -81,8 +74,8 @@ private:
 private:
   std::type_index type_; //
   std::string name_;     // The property name as a dot-separated path
-  bool proxy_{
-      false}; // Are we using a proxy class to store non-trivial setter behavior
+  // Are we using a proxy class to store non-trivial setter behavior
+  bool proxy_{false};
   bool const_; // Is the object immutable (const-ref or rvalue)
   std::shared_ptr<void> data_;
   Object (Object::*get_)(std::string_view) const;
@@ -91,6 +84,7 @@ private:
 }
 
 #include "reflection/reflection.h"
+#include "reflection/typecast.h"
 
 namespace reflection {
 template <typename T> Object Object::getter(std::string_view id) const {
@@ -124,4 +118,31 @@ Object::Object(Proxy<T> data, std::string name)
     : type_(typeid(T)), name_(std::move(name)), proxy_(true), const_(false),
       data_(std::make_shared<Proxy<T>>(std::move(data))),
       get_(&Object::getter<T>) {}
+
+template <typename V> Object & Object::operator=(V const & value) {
+  if (TypeCast<V>::has_cast(*this)) {
+    TypeCast<V>::set(*this, value);
+  } else {
+    cast<V &>() = value;
+  }
+  return *this;
+}
+
+template <typename T> T Object::cast() const {
+  // Why can we decay this away?
+  using V = std::remove_const_t<std::remove_reference_t<T>>;
+  // 1) Casting is only allowed on non-reference rvals
+  if constexpr (std::is_same_v<V, T>) {
+    if (TypeCast<V>::has_cast(*this)) { return TypeCast<V>::get(*this); }
+  }
+  // 2) typeid does not respect const
+  if (!is_a<V>()) { throw std::bad_cast(); }
+  // 3) We guard against using mutable-reference on immutable Object here
+  if (std::is_same_v<T, V &> && const_) { throw std::bad_cast(); }
+  // 4) Proxy always contains mutable data
+  if (proxy_) { return T(*std::static_pointer_cast<Proxy<V>>(data_)); }
+  // 5) The const will be re-added by type coercion
+  return *std::static_pointer_cast<V>(data_);
+}
+
 }

+ 17 - 48
include/reflection/reflection.h

@@ -10,9 +10,9 @@
 #include "reflection/object.h"
 #include "reflection/proxy.h"
 
-#define REFLECTION(expr)                                                       \
+#define REFLECTION_F(expr)                                                     \
   [=](Obj & obj, std::string name) { return Object(expr, std::move(name)); }
-#define CONST_REFLECTION(expr)                                                 \
+#define CONST_REFLECTION_F(expr)                                               \
   [=](Obj const & obj, std::string name) {                                     \
     return Object(expr, std::move(name));                                      \
   }
@@ -21,7 +21,8 @@ namespace reflection {
 
 template <typename T>
 constexpr auto is_final_reflection =
-    std::is_same_v<T, std::string> || std::is_fundamental_v<T> || std::is_enum_v<T>;
+    std::is_same_v<T, std::string> || std::is_fundamental_v<T> ||
+    std::is_enum_v<T>;
 
 template <typename Obj, typename> class Reflection {
 public:
@@ -45,70 +46,38 @@ public:
   Reflection & bind(std::string_view id, R (Obj::*get)() const,
                     void (Obj::*set)(I)) {
     using V = std::decay_t<R>;
-    members_.emplace(id, REFLECTION(Proxy<V>(obj, get, set)));
-    const_members_.emplace(id, CONST_REFLECTION((obj.*get)()));
+    members_.emplace(id, REFLECTION_F(Proxy<V>(obj, get, set)));
+    const_members_.emplace(id, CONST_REFLECTION_F((obj.*get)()));
     return *this;
   }
-  
+
   template <typename R>
   Reflection & bind(std::string_view id, R (Obj::*get)() const,
-                    std::decay_t<R> &(Obj::*acc)()) {
-    members_.emplace(id, REFLECTION((obj.*acc)()));
-    const_members_.emplace(id, CONST_REFLECTION((obj.*get)()));
+                    std::decay_t<R> & (Obj::*acc)()) {
+    members_.emplace(id, REFLECTION_F((obj.*acc)()));
+    const_members_.emplace(id, CONST_REFLECTION_F((obj.*get)()));
     return *this;
   }
-  
+
   template <typename T>
   Reflection & bind(std::string_view id, T (Obj::*get)() const) {
-    const_members_.emplace(id, CONST_REFLECTION((obj.*get)()));
+    const_members_.emplace(id, CONST_REFLECTION_F((obj.*get)()));
     return *this;
   }
 
   template <typename T> Reflection & bind(std::string_view id, T Obj::*member) {
-    members_.emplace(id, REFLECTION(obj.*member));
-    const_members_.emplace(id, CONST_REFLECTION(obj.*member));
+    members_.emplace(id, REFLECTION_F(obj.*member));
+    const_members_.emplace(id, CONST_REFLECTION_F(obj.*member));
     return *this;
   }
 
   template <typename F> Reflection & bind(std::string_view id, F func) {
     if constexpr (!std::is_const_v<decltype(func(std::declval<Obj &>()))>) {
-      members_.emplace(id, REFLECTION(func(obj)));
-    } else {
-      members_.emplace(id, CONST_REFLECTION(func(obj)));
-    }
-    const_members_.emplace(id, CONST_REFLECTION(func(obj)));
-    return *this;
-  }
-
-  template <typename R, typename T>
-  Reflection & bind(std::string_view id, T Obj::*member) {
-    if constexpr (std::is_convertible_v<T, R>) {
-      return bind(id, member, [](T const &v) { return static_cast<R>(v); });
+      members_.emplace(id, REFLECTION_F(func(obj)));
     } else {
-      return bind(id, member, TypeConversion<R, T>());
+      members_.emplace(id, CONST_REFLECTION_F(func(obj)));
     }
-  }
-  
-  template <typename R, typename T>
-  Reflection & bind(std::string_view id, T (Obj::*get)() const) {
-    if constexpr (std::is_convertible_v<T, R>) {
-      return bind(id, get, [](T const &v) { return static_cast<R>(v); });
-    } else {
-      return bind(id, get, TypeConversion<R, T>());
-    }
-  }
-  
-  template <typename T, typename F>
-  Reflection & bind(std::string_view id, T Obj::*member, F convert) {
-    members_.emplace(id, CONST_REFLECTION(convert(obj.*member)));
-    const_members_.emplace(id, CONST_REFLECTION(convert(obj.*member)));
-    return *this;
-  }
-  
-  template <typename T, typename F>
-  Reflection & bind(std::string_view id, T (Obj::*get)() const, F convert) {
-    members_.emplace(id, CONST_REFLECTION(convert((obj.*get)())));
-    const_members_.emplace(id, CONST_REFLECTION(convert((obj.*get)())));
+    const_members_.emplace(id, CONST_REFLECTION_F(func(obj)));
     return *this;
   }
 

+ 141 - 0
include/reflection/typecast.h

@@ -0,0 +1,141 @@
+//
+//  typecast.h
+//  reflection
+//
+//  Created by Sam Jaffe on 7/4/22.
+//  Copyright © 2022 Sam Jaffe. All rights reserved.
+//
+
+#pragma once
+
+#include <functional>
+#include <typeindex>
+#include <unordered_map>
+#include <utility>
+
+#include "reflection/forward.h"
+#include "reflection/object.h"
+
+namespace reflection {
+template <typename T> class TypeCast {
+public:
+  static bool has_cast(Object const & obj) { return get_.count(obj.type()); }
+
+  static T get(Object const & obj) { return get_.at(obj.type())(obj); }
+
+  static void set(Object & obj, T const & value) {
+    set_.at(obj.type())(obj, value);
+  }
+
+  operator bool() const { return true; }
+
+  template <typename I, typename O, typename S>
+  TypeCast & bind(T (*to)(I), O (*from)(S)) {
+    return bind(std::function(to), std::function(from));
+  }
+
+  template <typename I, typename O, typename S>
+  TypeCast & bind(std::function<T(I)> to, std::function<O(S)> from) {
+    bind_impl(to, from);
+    TypeCast<std::decay_t<I>>().template bind_impl(from, to);
+    return *this;
+  }
+
+  template <typename V> TypeCast & bind() {
+    if constexpr (!std::is_same_v<V, T>) {
+      bind_impl<V>();
+      TypeCast<V>().template bind_impl<T>();
+    }
+    return *this;
+  }
+
+  template <typename V1, typename V2, typename... Vs> TypeCast & bind() {
+    return bind<V1>().template bind<V2, Vs...>();
+  }
+
+private:
+  template <typename I, typename O, typename S>
+  void bind_impl(std::function<T(I)> to, std::function<O(S)> from) {
+    static_assert(std::is_same_v<std::decay_t<S>, T>,
+                  "this types must be compatible");
+    static_assert(std::is_same_v<std::decay_t<I>, std::decay_t<I>>,
+                  "cast types must be compatible");
+    using V = std::decay_t<I>;
+    get_.emplace(typeid(V),
+                 [to](auto & obj) { return to(static_cast<V const &>(obj)); });
+    set_.emplace(typeid(V), [from](auto & obj, T const & val) {
+      static_cast<V &>(obj) = from(val);
+    });
+  }
+
+  template <typename V> void bind_impl() {
+    get_.emplace(typeid(V), [](auto & obj) {
+      return static_cast<T>(static_cast<V const &>(obj));
+    });
+    set_.emplace(typeid(V), [](auto & obj, T val) {
+      static_cast<V &>(obj) = static_cast<V>(val);
+    });
+  }
+
+private:
+  template <typename S> friend class TypeCast;
+  static TypeMap<std::function<T(Object const &)>> get_;
+  static TypeMap<std::function<void(Object &, T const &)>> set_;
+};
+
+template <typename T>
+TypeMap<std::function<T(Object const &)>> TypeCast<T>::get_;
+template <typename T>
+TypeMap<std::function<void(Object &, T const &)>> TypeCast<T>::set_;
+}
+
+#if defined REFLECTION_TYPE_CAST_IMPLEMENTATION
+namespace reflection {
+INSTANTIATE_REFLECTION_TYPECAST(bool);
+INSTANTIATE_REFLECTION_TYPECAST(int16_t);
+INSTANTIATE_REFLECTION_TYPECAST(int32_t);
+INSTANTIATE_REFLECTION_TYPECAST(int64_t);
+INSTANTIATE_REFLECTION_TYPECAST(uint16_t);
+INSTANTIATE_REFLECTION_TYPECAST(uint32_t);
+INSTANTIATE_REFLECTION_TYPECAST(uint64_t);
+INSTANTIATE_REFLECTION_TYPECAST(float);
+INSTANTIATE_REFLECTION_TYPECAST(double);
+INSTANTIATE_REFLECTION_TYPECAST(long double);
+INSTANTIATE_REFLECTION_TYPECAST(std::string);
+
+// Bind boolean types to all integer types
+REFLECTION_TYPE_CAST(bool)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+// Bind all integer types (wider than 1byte) to each other.
+REFLECTION_TYPE_CAST(int32_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+REFLECTION_TYPE_CAST(uint32_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+REFLECTION_TYPE_CAST(int16_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+REFLECTION_TYPE_CAST(uint16_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+REFLECTION_TYPE_CAST(int64_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+REFLECTION_TYPE_CAST(uint64_t)
+    .bind<int16_t, int32_t, int64_t>()
+    .bind<uint16_t, uint32_t, uint64_t>();
+
+// Mututally bind all floating point numbers
+REFLECTION_TYPE_CAST(float).bind<float, double, long double>();
+REFLECTION_TYPE_CAST(double).bind<float, double, long double>();
+REFLECTION_TYPE_CAST(long double).bind<float, double, long double>();
+}
+#endif

+ 6 - 0
reflection.xcodeproj/project.pbxproj

@@ -11,6 +11,7 @@
 		CD2FF9092310BE9200ABA548 /* GoogleMock.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = CD2FF8F32310BE6D00ABA548 /* GoogleMock.framework */; };
 		CDA9561B28726949006ACEC2 /* reflection in Headers */ = {isa = PBXBuildFile; fileRef = CDA95611287266E5006ACEC2 /* reflection */; settings = {ATTRIBUTES = (Public, ); }; };
 		CDA9561E28731972006ACEC2 /* object_test.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CDA9561D28731972006ACEC2 /* object_test.cxx */; };
+		CDA956D9287493FE006ACEC2 /* typecast_test.cxx in Sources */ = {isa = PBXBuildFile; fileRef = CDA956D8287493FE006ACEC2 /* typecast_test.cxx */; };
 /* End PBXBuildFile section */
 
 /* Begin PBXContainerItemProxy section */
@@ -57,6 +58,8 @@
 		CDA9561728726941006ACEC2 /* libreflection.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libreflection.a; sourceTree = BUILT_PRODUCTS_DIR; };
 		CDA9561C28726A68006ACEC2 /* xcode_gtest_helper.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = xcode_gtest_helper.h; sourceTree = "<group>"; };
 		CDA9561D28731972006ACEC2 /* object_test.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = object_test.cxx; sourceTree = "<group>"; };
+		CDA956D32873E1C7006ACEC2 /* typecast.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = typecast.h; sourceTree = "<group>"; };
+		CDA956D8287493FE006ACEC2 /* typecast_test.cxx */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = typecast_test.cxx; sourceTree = "<group>"; };
 /* End PBXFileReference section */
 
 /* Begin PBXFrameworksBuildPhase section */
@@ -131,6 +134,7 @@
 			children = (
 				CD5535251EEC689700108F81 /* reflection_test.cxx */,
 				CDA9561D28731972006ACEC2 /* object_test.cxx */,
+				CDA956D8287493FE006ACEC2 /* typecast_test.cxx */,
 				CDA9561C28726A68006ACEC2 /* xcode_gtest_helper.h */,
 			);
 			path = test;
@@ -151,6 +155,7 @@
 				CDA9560F287266C9006ACEC2 /* forward.h */,
 				CDA9561228726723006ACEC2 /* proxy.h */,
 				CDA95610287266CE006ACEC2 /* object.h */,
+				CDA956D32873E1C7006ACEC2 /* typecast.h */,
 			);
 			path = reflection;
 			sourceTree = "<group>";
@@ -293,6 +298,7 @@
 			isa = PBXSourcesBuildPhase;
 			buildActionMask = 2147483647;
 			files = (
+				CDA956D9287493FE006ACEC2 /* typecast_test.cxx in Sources */,
 				CDA9561E28731972006ACEC2 /* object_test.cxx in Sources */,
 				CD2FF9072310BE8500ABA548 /* reflection_test.cxx in Sources */,
 			);

+ 14 - 9
test/object_test.cxx

@@ -27,20 +27,25 @@ struct Example {
 struct Compound {
   Example ex;
 };
+}
 
-bool _1 = Reflection<Example>()
-              .bind("a", &Example::a)
-              .bind("c", &Example::get_c, &Example::set_c);
-
-bool _2 = Reflection<Compound>().bind("ex", &Compound::ex);
+namespace reflection {
+INSTANTIATE_REFLECTION(Example);
+INSTANTIATE_REFLECTION(Compound);
 }
 
+REFLECTION(Example)
+    .bind("a", &Example::a)
+    .bind("c", &Example::get_c, &Example::set_c);
+
+REFLECTION(Compound).bind("ex", &Compound::ex);
+
 TEST(ObjectTest, CanFetchData) {
   Example ex{.a = 1};
 
   Object a = Object(ex).get("a");
 
-  EXPECT_NO_THROW((void)int(a)) << a.type();
+  EXPECT_NO_THROW((void)int(a)) << a.type().name();
   EXPECT_THAT(int(a), 1);
 }
 
@@ -49,7 +54,7 @@ TEST(ObjectTest, CanSetDataOnNonConst) {
 
   {
     Object a = Object(ex).get("a");
-    EXPECT_NO_THROW((void)static_cast<int &>(a)) << a.type();
+    EXPECT_NO_THROW((void)static_cast<int &>(a)) << a.type().name();
     static_cast<int &>(a) = 2;
   }
   EXPECT_THAT(ex.a, 2);
@@ -78,7 +83,7 @@ TEST(ObjectTest, CanFetchWithGetter) {
   Example ex{.a = 1};
 
   Object c = Object(ex).get("c");
-  EXPECT_NO_THROW((void)int(c)) << c.type();
+  EXPECT_NO_THROW((void)int(c)) << c.type().name();
   EXPECT_THAT(int(c), Eq(5));
   static_cast<int &>(c) = 4;
 }
@@ -88,7 +93,7 @@ TEST(ObjectTest, CanModifyWithSetter) {
 
   {
     Object c = Object(ex).get("c");
-    EXPECT_NO_THROW((void)static_cast<int &>(c)) << c.type();
+    EXPECT_NO_THROW((void)static_cast<int &>(c)) << c.type().name();
     static_cast<int &>(c) = 4;
     // Notice that the setter is scoped on the Object expiring
     EXPECT_THAT(ex.a, Eq(1));

+ 8 - 3
test/reflection_test.cxx

@@ -15,13 +15,18 @@ struct Example {
   int d;
 
   int get() const { return d; }
-  int &get() { return d; }
+  int & get() { return d; }
 };
+}
+
+namespace reflection {
+INSTANTIATE_REFLECTION(Example);
+}
 
-bool _ = Reflection<Example>().bind("a", &Example::a)
+REFLECTION(Example)
+    .bind("a", &Example::a)
     .bind("c", &Example::b)
     .bind("d", &Example::get, &Example::get);
-}
 
 TEST(ReflectionTest, AquiresGetter) {
   EXPECT_NO_THROW(Reflection<Example>::getter("a"));

+ 56 - 0
test/typecast_test.cxx

@@ -0,0 +1,56 @@
+//
+//  typecast_test.cxx
+//  reflection-test
+//
+//  Created by Sam Jaffe on 7/5/22.
+//  Copyright © 2022 Sam Jaffe. All rights reserved.
+//
+
+#define REFLECTION_TYPE_CAST_IMPLEMENTATION
+#include "reflection/typecast.h"
+
+#include "xcode_gtest_helper.h"
+
+using reflection::Object;
+using reflection::TypeCast;
+
+using testing::Eq;
+using testing::NotNull;
+
+enum class Status { OK, BAD_INPUT, INTERNAL_ERROR };
+
+std::string status_to_string(Status status) {
+  switch (status) {
+  case Status::OK:
+    return "OK";
+  case Status::BAD_INPUT:
+    return "BAD";
+  case Status::INTERNAL_ERROR:
+    return "ERROR";
+  }
+}
+
+Status status_from_string(std::string const & str) {
+  if (str == "OK")
+    return Status::OK;
+  else if (str == "BAD")
+    return Status::BAD_INPUT;
+  else if (str == "ERROR")
+    return Status::INTERNAL_ERROR;
+  throw std::invalid_argument(str);
+}
+
+namespace reflection {
+INSTANTIATE_REFLECTION_TYPECAST(Status);
+REFLECTION_TYPE_CAST(Status).bind<int>().bind(&status_from_string,
+                                              &status_to_string);
+}
+
+TEST(TypeCastTest, CanPerformTypeCoercion) {
+  Object stat(Status::OK);
+
+  EXPECT_THAT(Status(stat), Eq(Status::OK));
+  EXPECT_THAT(int(stat), Eq(0));
+  std::string str(stat);
+  EXPECT_THAT(str, Eq("OK"));
+}