Explorar o código

refactor/fix: logging for Matrix, fix adhoc Outputs bug in fabricate (Composite/Workflow)

Sam Jaffe hai 1 mes
pai
achega
e1f5876bc6
Modificáronse 4 ficheiros con 26 adicións e 6 borrados
  1. 15 2
      src/cipy/__init__.py
  2. 4 3
      src/cipy/common.py
  3. 3 1
      src/cipy/runner.py
  4. 4 0
      src/cipy/workflow.py

+ 15 - 2
src/cipy/__init__.py

@@ -73,11 +73,24 @@ def compute(arg: typing.Callable[[Context], typing.Any], /) -> typing.Any:
     return pydantic.Field(default=Factory(arg))
 
 
-def outputs(**fields: types.UnionType | type[typing.Any]) -> Outputs:
+T = typing.TypeVar("T")
+type TypeInfo[T] = types.UnionType | type[T]
+
+
+def outputs(
+    **fields: TypeInfo[typing.Any] | tuple[TypeInfo[typing.Any], Ref],
+) -> Outputs:
+
+    def to_annotation(field):
+        if isinstance(field, tuple):
+            return typing.Annotated[field[0], pydantic.Field(default=field[1])]
+        else:
+            return typing.Annotated[field, pydantic.Field(default=None)]
+
     frame = sys._getframe(1)
     return pydantic.create_model(  # type: ignore[call-overload]
         f"__{frame.f_lineno}_AnonymousOutputs",
         __base__=Outputs,
         __module__=frame.f_globals["__name__"],
-        **{k: typing.Annotated[t, pydantic.Field(default=None)] for k, t in fields.items()}
+        **{k: to_annotation(t) for k, t in fields.items()},
     )()

+ 4 - 3
src/cipy/common.py

@@ -135,13 +135,14 @@ class Context(SimpleNamespace):
         if extra is None:
             extra = {}
 
+        annotation = state.__pydantic_fields__[attr].annotation
+        assert annotation is not None
+
         fields: dict[str, Any] = {}
         if (model := getattr(state, attr)) is not None:
+            annotation = model.__class__
             fields = vars(model)
 
-        annotation = state.__pydantic_fields__[attr].annotation
-        assert annotation is not None
-
         for name, fld in annotation.__pydantic_fields__.items():
             if name in extra:
                 fields[name] = self(extra[name])

+ 3 - 1
src/cipy/runner.py

@@ -175,7 +175,9 @@ def _log_inputs(self: Action) -> None:
 
 
 def _log_outputs(self: Action) -> None:
-    outputs = [(k, v) for k, v in vars(self.outputs).items() if v is not None and v != ""]
+    outputs = [
+        (k, v) for k, v in vars(self.outputs).items() if v is not None and v != ""
+    ]
     if not outputs:
         return
 

+ 4 - 0
src/cipy/workflow.py

@@ -7,6 +7,8 @@ from typing import Any, Iterable, final, override
 
 from pydantic import BaseModel, PrivateAttr
 
+import cipy.runner
+
 from cipy.common import Action, Context, Results, Scalar, Status, Value
 
 
@@ -108,4 +110,6 @@ class Matrix(Action):
             if self.fail_fast and status is Status.FAILURE:
                 break
 
+        cipy.runner._log_outputs(self)
+
         return status