فهرست منبع

refactor: better Computed resolution in Context

Sam Jaffe 1 ماه پیش
والد
کامیت
d537be993c
2فایلهای تغییر یافته به همراه30 افزوده شده و 27 حذف شده
  1. 24 18
      src/cipy/common.py
  2. 6 9
      src/cipy/workflow.py

+ 24 - 18
src/cipy/common.py

@@ -13,6 +13,10 @@ from typing import Annotated, Any, Callable, Iterator, Literal, overload
 
 from pydantic import BaseModel, Field
 
+type Scalar = bool | int | float | str
+type Computed = Ref | Factory
+type Value = Scalar | Computed
+
 
 class Status(Enum):
     """Result status of a runner, higher numbers take priority"""
@@ -51,7 +55,7 @@ class Ref:
 class Factory:
     """Annotation class describing a non-trivial synthesized property"""
 
-    __call__: Callable[[Context], Any]
+    __call__: Callable[[Context], Scalar | None]
 
 
 class Results(SimpleNamespace):
@@ -80,15 +84,23 @@ class Results(SimpleNamespace):
 class Context(SimpleNamespace):
     """Wrapper class for the context of the CI runtime"""
 
-    def access(self, ref: Ref) -> Any:
+    def __call__(self, arg: Value | None) -> Scalar | None:
+        if arg is None:
+            return None
+
+        if isinstance(arg, Factory):
+            return arg(self)
+
+        if not isinstance(arg, Ref):
+            return arg
+
         """Accessor for context state with a dot-separated path"""
-        if ref.path[0] == "env":
-            assert len(ref.path) == 2
-            return os.environ.get(ref.path[1])
+        if arg.path[0] == "env":
+            assert len(arg.path) == 2
+            return os.environ.get(arg.path[1])
 
-        return reduce(
-            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), ref.path, self
-        )
+        attr = lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a)
+        return reduce(attr, arg.path, self)  # type: ignore[return-value]
 
     @overload
     def fabricate(
@@ -122,13 +134,9 @@ class Context(SimpleNamespace):
             assert annotation is not None
             model = annotation()
 
+        value: Value | None = None
         for k, value in extra.items():
-            if isinstance(value, Ref):
-                setattr(model, k, self.access(value))
-            elif isinstance(value, Factory):
-                setattr(model, k, value(self))
-            else:
-                setattr(model, k, value)
+            setattr(model, k, self(value))
 
         for k, fld in model.__pydantic_fields__.items():
             value = getattr(model, k)
@@ -137,10 +145,8 @@ class Context(SimpleNamespace):
 
             if k in extra:
                 pass
-            elif isinstance(value, Ref):
-                value = self.access(value)
-            elif isinstance(value, Factory):
-                value = value(self)
+
+            value = self(value)
             if value is not None:
                 setattr(model, k, fld.annotation(value))
 

+ 6 - 9
src/cipy/workflow.py

@@ -7,10 +7,7 @@ from typing import Any, Iterable, final, override
 
 from pydantic import BaseModel, PrivateAttr
 
-from cipy.common import Action, Context, Ref, Results, Status, _validate
-
-type Scalar = bool | int | float | str
-
+from cipy.common import Action, Context, Results, Scalar, Status, Value, _validate
 
 class Job(BaseModel):
     """A wrapper for a graph node with edges"""
@@ -71,7 +68,7 @@ class Workflow(Action):
         return status
 
 
-type MatrixParams = dict[str, list[Scalar | Ref]] | list[dict[str, Scalar | Ref]]
+type MatrixParams = dict[str, list[Value]] | list[dict[str, Value]]
 
 
 class Matrix(Action):
@@ -84,11 +81,11 @@ class Matrix(Action):
     fail_fast: bool = True
 
     def _resolve(
-        self, d: dict[str, Scalar | Ref], context: Context
-    ) -> dict[str, Scalar]:
-        return {k: context.access(v) if isinstance(v, Ref) else v for k, v in d.items()}
+        self, d: dict[str, Value], context: Context
+    ) -> dict[str, Scalar | None]:
+        return {k: context(v) for k, v in d.items()}
 
-    def _expand(self, context: Context) -> Iterable[dict[str, Scalar]]:
+    def _expand(self, context: Context) -> Iterable[dict[str, Scalar | None]]:
         if isinstance(self.on, list):
             return (self._resolve(d, context) for d in self.on)