|
|
@@ -1,10 +1,15 @@
|
|
|
"""Module containing basic Workflow definitions, which perform non-linear operations"""
|
|
|
|
|
|
-from typing import Any, final, override
|
|
|
+import copy
|
|
|
+import itertools
|
|
|
+
|
|
|
+from typing import Any, Iterable, final, override
|
|
|
|
|
|
from pydantic import BaseModel, PrivateAttr
|
|
|
|
|
|
-from cipy.common import Action, Context, Results, Status, _validate
|
|
|
+from cipy.common import Action, Context, Ref, Results, Status, _validate
|
|
|
+
|
|
|
+type Scalar = bool | int | float | str
|
|
|
|
|
|
|
|
|
class Job(BaseModel):
|
|
|
@@ -63,3 +68,39 @@ class Workflow(Action):
|
|
|
outctx.fabricate(self, "outputs")
|
|
|
|
|
|
return status
|
|
|
+
|
|
|
+
|
|
|
+class Matrix(Action):
|
|
|
+ """
|
|
|
+ Actions that represent running a single Workflow/Action across multiple configurations
|
|
|
+ """
|
|
|
+ matrix: dict[str, list[Scalar | Ref]] | list[dict[str, Scalar | Ref]]
|
|
|
+ uses: Action
|
|
|
+ fail_fast: bool = True
|
|
|
+
|
|
|
+ def _resolve(
|
|
|
+ self, d: dict[str, Scalar | Ref], context: Context
|
|
|
+ ) -> dict[str, Scalar]:
|
|
|
+ return {k: context.access(v) if isinstance(v, Ref) else v for k, v in d.items()}
|
|
|
+
|
|
|
+ def _expand(self, context: Context) -> Iterable[dict[str, Scalar]]:
|
|
|
+ if isinstance(self.matrix, list):
|
|
|
+ return (self._resolve(d, context) for d in self.matrix)
|
|
|
+
|
|
|
+ flatten = [itertools.product([k], vs) for k, vs in self.matrix.items()]
|
|
|
+ return (self._resolve(dict(d), context) for d in itertools.product(*flatten))
|
|
|
+
|
|
|
+ @final
|
|
|
+ def run(self, context: Context) -> Status:
|
|
|
+ status = Status.NOT_RUN
|
|
|
+
|
|
|
+ for matrix in self._expand(context):
|
|
|
+ with context.extend(matrix=matrix) as matctx:
|
|
|
+ tmp = copy.deepcopy(self.uses)
|
|
|
+ status |= tmp.run(matctx)
|
|
|
+ self.outputs = tmp.outputs
|
|
|
+
|
|
|
+ if self.fail_fast and status is Status.FAILURE:
|
|
|
+ break
|
|
|
+
|
|
|
+ return status
|