Sam Jaffe 1 месяц назад
Родитель
Сommit
bd73cb7a0d
2 измененных файлов с 55 добавлено и 6 удалено
  1. 12 4
      src/cipy/common.py
  2. 43 2
      src/cipy/workflow.py

+ 12 - 4
src/cipy/common.py

@@ -1,16 +1,17 @@
 """Common objects and base classes in the CI hierarchy"""
 """Common objects and base classes in the CI hierarchy"""
 
 
 import abc
 import abc
+import dataclasses
 import os
 import os
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
-from dataclasses import dataclass, field
 from enum import Enum, auto
 from enum import Enum, auto
 from functools import reduce
 from functools import reduce
 from types import SimpleNamespace
 from types import SimpleNamespace
 from typing import Any, Callable, Iterator, Literal, overload
 from typing import Any, Callable, Iterator, Literal, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
+from pydantic_core import core_schema
 
 
 
 
 class Status(Enum):
 class Status(Enum):
@@ -37,8 +38,15 @@ class Outputs(BaseModel):
 class Ref(str):
 class Ref(str):
     """Annotation class describing a reference into Context or another place"""
     """Annotation class describing a reference into Context or another place"""
 
 
+    @classmethod
+    # pylint: disable=unused-argument
+    def __get_pydantic_core_schema__(cls, source, handler) -> core_schema.CoreSchema:
+        return core_schema.general_plain_validator_function(
+            lambda s: s and all(t for t in s.split("."))
+        )
+
 
 
-@dataclass
+@dataclasses.dataclass
 class Factory:
 class Factory:
     """Annotation class describing a non-trivial synthesized property"""
     """Annotation class describing a non-trivial synthesized property"""
 
 
@@ -50,12 +58,12 @@ class Results(SimpleNamespace):
     Holder object for tracking the result of an action for Composite/Workflow actions
     Holder object for tracking the result of an action for Composite/Workflow actions
     """
     """
 
 
-    @dataclass
+    @dataclasses.dataclass
     class Item:
     class Item:
         """Result of a single action that needs to be tracked"""
         """Result of a single action that needs to be tracked"""
 
 
         conclusion: Status = Status.NOT_RUN
         conclusion: Status = Status.NOT_RUN
-        outputs: Outputs = field(default_factory=Outputs)
+        outputs: Outputs = dataclasses.field(default_factory=Outputs)
 
 
     def __contains__(self, subscript: str) -> bool:
     def __contains__(self, subscript: str) -> bool:
         return hasattr(self, subscript)
         return hasattr(self, subscript)

+ 43 - 2
src/cipy/workflow.py

@@ -1,10 +1,15 @@
 """Module containing basic Workflow definitions, which perform non-linear operations"""
 """Module containing basic Workflow definitions, which perform non-linear operations"""
 
 
-from typing import Any, final, override
+import copy
+import itertools
+
+from typing import Any, Iterable, final, override
 
 
 from pydantic import BaseModel, PrivateAttr
 from pydantic import BaseModel, PrivateAttr
 
 
-from cipy.common import Action, Context, Results, Status, _validate
+from cipy.common import Action, Context, Ref, Results, Status, _validate
+
+type Scalar = bool | int | float | str
 
 
 
 
 class Job(BaseModel):
 class Job(BaseModel):
@@ -63,3 +68,39 @@ class Workflow(Action):
             outctx.fabricate(self, "outputs")
             outctx.fabricate(self, "outputs")
 
 
         return status
         return status
+
+
+class Matrix(Action):
+    """
+    Actions that represent running a single Workflow/Action across multiple configurations
+    """
+    matrix: dict[str, list[Scalar | Ref]] | list[dict[str, Scalar | Ref]]
+    uses: Action
+    fail_fast: bool = True
+
+    def _resolve(
+        self, d: dict[str, Scalar | Ref], context: Context
+    ) -> dict[str, Scalar]:
+        return {k: context.access(v) if isinstance(v, Ref) else v for k, v in d.items()}
+
+    def _expand(self, context: Context) -> Iterable[dict[str, Scalar]]:
+        if isinstance(self.matrix, list):
+            return (self._resolve(d, context) for d in self.matrix)
+
+        flatten = [itertools.product([k], vs) for k, vs in self.matrix.items()]
+        return (self._resolve(dict(d), context) for d in itertools.product(*flatten))
+
+    @final
+    def run(self, context: Context) -> Status:
+        status = Status.NOT_RUN
+
+        for matrix in self._expand(context):
+            with context.extend(matrix=matrix) as matctx:
+                tmp = copy.deepcopy(self.uses)
+                status |= tmp.run(matctx)
+                self.outputs = tmp.outputs
+
+            if self.fail_fast and status is Status.FAILURE:
+                break
+
+        return status