Sfoglia il codice sorgente

feat: add dispatch

Sam Jaffe 1 mese fa
parent
commit
943aefb1f5
3 ha cambiato i file con 74 aggiunte e 19 eliminazioni
  1. 3 1
      src/cipy/__init__.py
  2. 27 2
      src/cipy/action.py
  3. 44 16
      src/cipy/common.py

+ 3 - 1
src/cipy/__init__.py

@@ -2,16 +2,18 @@
 Entry point for cipy library, re-exporting all of the default items
 """
 
+import types
 import typing
 
 import pydantic
 
-from cipy.action import Composite, NodeScript, Script
+from cipy.action import Call, Composite, NodeScript, Script
 from cipy.common import Context, Factory, Inputs, Outputs, Ref, Status
 from cipy.shell import Shell
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 __all__ = [
+    "Call",
     "Composite",
     "Context",
     "Inputs",

+ 27 - 2
src/cipy/action.py

@@ -4,15 +4,40 @@ import pathlib
 import subprocess
 import tempfile
 
-from typing import final
+from typing import Any, final
 
 from pydantic import Field, PrivateAttr
 
 import cipy.runner
-from cipy.common import Action, Context, Results, Status, _validate
+from cipy.common import Action, Context, Factory, Ref, Results, Status, _validate
 from cipy.shell import Shell
 
 
+class Call(Action, extra='allow'):
+    name: str = ""
+    using: Action
+    __pydantic_extra__: dict[str, bool | int | float | str | Ref | Factory]
+
+    def __init__(self, using: Action, /, **kwargs: Any) -> None:
+        super().__init__(using=using, **kwargs)  # type: ignore
+
+    @final
+    def enabled(self, status: Status, context: Context) -> bool:
+        return self.using.enabled(status, context)
+
+    @final
+    def run(self, context: Context) -> Status:
+        context.fabricate(self.using, "inputs", self.__pydantic_extra__)
+        try:
+            return self.using.run(context)
+        finally:
+            self.outputs = self.using.outputs
+
+    @final
+    def cleanup(self, context: Context) -> None:
+        self.using.cleanup(context)
+
+
 class NodeScript(Action):
     """
     A special script that is run as a node.js file, with optional post-script

+ 44 - 16
src/cipy/common.py

@@ -8,7 +8,7 @@ import os
 from contextlib import contextmanager
 from enum import Enum, auto
 from functools import reduce
-from types import SimpleNamespace
+from types import SimpleNamespace, NoneType
 from typing import Any, Callable, Iterator, Literal, overload
 
 from pydantic import BaseModel, Field
@@ -47,6 +47,7 @@ class Ref(str):
             if s and all(t for t in s.split(".")):
                 return s
             raise ValueError("References must be of the form A.B.C etc.")
+
         return core_schema.general_plain_validator_function(validate)
 
 
@@ -91,16 +92,31 @@ class Context(SimpleNamespace):
             assert len(path) == 2
             return os.environ.get(path[1])
 
-        return reduce(getattr, path, self)
+        return reduce(
+            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), path, self
+        )
 
     @overload
-    def fabricate(self, state: BaseModel, attr: Literal["inputs"]) -> Inputs: ...
+    def fabricate(
+        self,
+        state: BaseModel,
+        attr: Literal["inputs"],
+        extra: dict[str, Any] = {},
+    ) -> Inputs: ...
 
     @overload
-    def fabricate(self, state: BaseModel, attr: Literal["outputs"]) -> Outputs: ...
+    def fabricate(
+        self,
+        state: BaseModel,
+        attr: Literal["outputs"],
+        extra: dict[str, Any] = {},
+    ) -> Outputs: ...
 
     def fabricate(
-        self, state: BaseModel, attr: Literal["inputs"] | Literal["outputs"]
+        self,
+        state: BaseModel,
+        attr: Literal["inputs"] | Literal["outputs"],
+        extra: dict[str, Ref | Factory] = {},
     ) -> Inputs | Outputs:
         """Fabricate and validate an Inputs or Outputs object"""
         model = getattr(state, attr)
@@ -109,15 +125,27 @@ class Context(SimpleNamespace):
             assert annotation is not None
             model = annotation()
 
-        for k, field in model.__pydantic_fields__.items():
-            value = getattr(model, k)
-            if field.annotation is None:
-                continue
-
+        for k, value in extra.items():
             if isinstance(value, Ref):
                 setattr(model, k, self.access(value))
             elif isinstance(value, Factory):
                 setattr(model, k, value(self))
+            else:
+                setattr(model, k, value)
+
+        for k, fld in model.__pydantic_fields__.items():
+            value = getattr(model, k)
+            if fld.annotation is None:
+                continue
+
+            if k in extra:
+                pass
+            elif isinstance(value, Ref):
+                value = self.access(value)
+            elif isinstance(value, Factory):
+                value = value(self)
+            if value is not None:
+                setattr(model, k, fld.annotation(value))
 
         _validate(model)
         setattr(state, attr, model)
@@ -167,14 +195,14 @@ class Action(BaseModel, abc.ABC):
 
 def _validate(model: BaseModel):
     """Perform the actual model validation that we sabotaged w/ required() and similar functions"""
-    for k, field in model.__pydantic_fields__.items():
+    for k, fld in model.__pydantic_fields__.items():
         attr = getattr(model, k)
-        if field.annotation is None:
+        if fld.annotation is None:
             continue
 
-        if isinstance(attr, (Ref, Factory)):
-            raise TypeError(f"field '{k}' in {type(model)} is unset")
-        if not isinstance(attr, field.annotation):
+        if isinstance(attr, (Ref, Factory, NoneType)):
+            raise TypeError(f"fld '{k}' in {type(model).__qualname__} is unset")
+        if not isinstance(attr, fld.annotation):
             raise TypeError(
-                f"field '{k}' in {type(model)} is of the wrong type (should be {field.annotation})"
+                f"field '{k}' in {type(model).__qualname__} is of the wrong type (should be {fld.annotation})"
             )