Sam Jaffe 1 месяц назад
Родитель
Сommit
943aefb1f5
3 измененных файлов с 74 добавлено и 19 удалено
  1. 3 1
      src/cipy/__init__.py
  2. 27 2
      src/cipy/action.py
  3. 44 16
      src/cipy/common.py

+ 3 - 1
src/cipy/__init__.py

@@ -2,16 +2,18 @@
 Entry point for cipy library, re-exporting all of the default items
 Entry point for cipy library, re-exporting all of the default items
 """
 """
 
 
+import types
 import typing
 import typing
 
 
 import pydantic
 import pydantic
 
 
-from cipy.action import Composite, NodeScript, Script
+from cipy.action import Call, Composite, NodeScript, Script
 from cipy.common import Context, Factory, Inputs, Outputs, Ref, Status
 from cipy.common import Context, Factory, Inputs, Outputs, Ref, Status
 from cipy.shell import Shell
 from cipy.shell import Shell
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 
 __all__ = [
 __all__ = [
+    "Call",
     "Composite",
     "Composite",
     "Context",
     "Context",
     "Inputs",
     "Inputs",

+ 27 - 2
src/cipy/action.py

@@ -4,15 +4,40 @@ import pathlib
 import subprocess
 import subprocess
 import tempfile
 import tempfile
 
 
-from typing import final
+from typing import Any, final
 
 
 from pydantic import Field, PrivateAttr
 from pydantic import Field, PrivateAttr
 
 
 import cipy.runner
 import cipy.runner
-from cipy.common import Action, Context, Results, Status, _validate
+from cipy.common import Action, Context, Factory, Ref, Results, Status, _validate
 from cipy.shell import Shell
 from cipy.shell import Shell
 
 
 
 
+class Call(Action, extra='allow'):
+    name: str = ""
+    using: Action
+    __pydantic_extra__: dict[str, bool | int | float | str | Ref | Factory]
+
+    def __init__(self, using: Action, /, **kwargs: Any) -> None:
+        super().__init__(using=using, **kwargs)  # type: ignore
+
+    @final
+    def enabled(self, status: Status, context: Context) -> bool:
+        return self.using.enabled(status, context)
+
+    @final
+    def run(self, context: Context) -> Status:
+        context.fabricate(self.using, "inputs", self.__pydantic_extra__)
+        try:
+            return self.using.run(context)
+        finally:
+            self.outputs = self.using.outputs
+
+    @final
+    def cleanup(self, context: Context) -> None:
+        self.using.cleanup(context)
+
+
 class NodeScript(Action):
 class NodeScript(Action):
     """
     """
     A special script that is run as a node.js file, with optional post-script
     A special script that is run as a node.js file, with optional post-script

+ 44 - 16
src/cipy/common.py

@@ -8,7 +8,7 @@ import os
 from contextlib import contextmanager
 from contextlib import contextmanager
 from enum import Enum, auto
 from enum import Enum, auto
 from functools import reduce
 from functools import reduce
-from types import SimpleNamespace
+from types import SimpleNamespace, NoneType
 from typing import Any, Callable, Iterator, Literal, overload
 from typing import Any, Callable, Iterator, Literal, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
@@ -47,6 +47,7 @@ class Ref(str):
             if s and all(t for t in s.split(".")):
             if s and all(t for t in s.split(".")):
                 return s
                 return s
             raise ValueError("References must be of the form A.B.C etc.")
             raise ValueError("References must be of the form A.B.C etc.")
+
         return core_schema.general_plain_validator_function(validate)
         return core_schema.general_plain_validator_function(validate)
 
 
 
 
@@ -91,16 +92,31 @@ class Context(SimpleNamespace):
             assert len(path) == 2
             assert len(path) == 2
             return os.environ.get(path[1])
             return os.environ.get(path[1])
 
 
-        return reduce(getattr, path, self)
+        return reduce(
+            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), path, self
+        )
 
 
     @overload
     @overload
-    def fabricate(self, state: BaseModel, attr: Literal["inputs"]) -> Inputs: ...
+    def fabricate(
+        self,
+        state: BaseModel,
+        attr: Literal["inputs"],
+        extra: dict[str, Any] = {},
+    ) -> Inputs: ...
 
 
     @overload
     @overload
-    def fabricate(self, state: BaseModel, attr: Literal["outputs"]) -> Outputs: ...
+    def fabricate(
+        self,
+        state: BaseModel,
+        attr: Literal["outputs"],
+        extra: dict[str, Any] = {},
+    ) -> Outputs: ...
 
 
     def fabricate(
     def fabricate(
-        self, state: BaseModel, attr: Literal["inputs"] | Literal["outputs"]
+        self,
+        state: BaseModel,
+        attr: Literal["inputs"] | Literal["outputs"],
+        extra: dict[str, Ref | Factory] = {},
     ) -> Inputs | Outputs:
     ) -> Inputs | Outputs:
         """Fabricate and validate an Inputs or Outputs object"""
         """Fabricate and validate an Inputs or Outputs object"""
         model = getattr(state, attr)
         model = getattr(state, attr)
@@ -109,15 +125,27 @@ class Context(SimpleNamespace):
             assert annotation is not None
             assert annotation is not None
             model = annotation()
             model = annotation()
 
 
-        for k, field in model.__pydantic_fields__.items():
-            value = getattr(model, k)
-            if field.annotation is None:
-                continue
-
+        for k, value in extra.items():
             if isinstance(value, Ref):
             if isinstance(value, Ref):
                 setattr(model, k, self.access(value))
                 setattr(model, k, self.access(value))
             elif isinstance(value, Factory):
             elif isinstance(value, Factory):
                 setattr(model, k, value(self))
                 setattr(model, k, value(self))
+            else:
+                setattr(model, k, value)
+
+        for k, fld in model.__pydantic_fields__.items():
+            value = getattr(model, k)
+            if fld.annotation is None:
+                continue
+
+            if k in extra:
+                pass
+            elif isinstance(value, Ref):
+                value = self.access(value)
+            elif isinstance(value, Factory):
+                value = value(self)
+            if value is not None:
+                setattr(model, k, fld.annotation(value))
 
 
         _validate(model)
         _validate(model)
         setattr(state, attr, model)
         setattr(state, attr, model)
@@ -167,14 +195,14 @@ class Action(BaseModel, abc.ABC):
 
 
 def _validate(model: BaseModel):
 def _validate(model: BaseModel):
     """Perform the actual model validation that we sabotaged w/ required() and similar functions"""
     """Perform the actual model validation that we sabotaged w/ required() and similar functions"""
-    for k, field in model.__pydantic_fields__.items():
+    for k, fld in model.__pydantic_fields__.items():
         attr = getattr(model, k)
         attr = getattr(model, k)
-        if field.annotation is None:
+        if fld.annotation is None:
             continue
             continue
 
 
-        if isinstance(attr, (Ref, Factory)):
-            raise TypeError(f"field '{k}' in {type(model)} is unset")
-        if not isinstance(attr, field.annotation):
+        if isinstance(attr, (Ref, Factory, NoneType)):
+            raise TypeError(f"fld '{k}' in {type(model).__qualname__} is unset")
+        if not isinstance(attr, fld.annotation):
             raise TypeError(
             raise TypeError(
-                f"field '{k}' in {type(model)} is of the wrong type (should be {field.annotation})"
+                f"field '{k}' in {type(model).__qualname__} is of the wrong type (should be {fld.annotation})"
             )
             )