Parcourir la source

feat: implement primitive Predicate interface

Sam Jaffe il y a 4 semaines
Parent
commit
c4da2e3d26
6 fichiers modifiés avec 171 ajouts et 29 suppressions
  1. 16 5
      src/cipy/__init__.py
  2. 3 3
      src/cipy/action.py
  3. 11 19
      src/cipy/common.py
  4. 119 0
      src/cipy/predicate.py
  5. 19 0
      src/cipy/status.py
  6. 3 2
      src/cipy/workflow.py

+ 16 - 5
src/cipy/__init__.py

@@ -10,8 +10,10 @@ import sys
 import pydantic
 import pydantic
 
 
 from cipy.action import Call, Composite, NodeScript, Script
 from cipy.action import Call, Composite, NodeScript, Script
-from cipy.common import Context, Factory, Inputs, Outputs, Ref, Status
+from cipy.common import Context, Factory, Inputs, Outputs, Ref
+from cipy.predicate import Predicate, Success, Always, Cancelled, Failure
 from cipy.shell import Shell
 from cipy.shell import Shell
+from cipy.status import Status
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 
 from . import settings
 from . import settings
@@ -30,19 +32,28 @@ logging.basicConfig(
 )
 )
 
 
 __all__ = [
 __all__ = [
-    "Call",
-    "Composite",
+    # Predicate functions
+    "Predicate",
+    "Always",
+    "Cancelled",
+    "Failure",
+    "Success",
+    # Common objects/stubs
+    "Status",
     "Context",
     "Context",
     "Inputs",
     "Inputs",
+    "Outputs",
+    # Action composition
+    "Call",
+    "Composite",
     "Job",
     "Job",
     "Matrix",
     "Matrix",
     "MatrixParams",
     "MatrixParams",
     "NodeScript",
     "NodeScript",
-    "Outputs",
     "Script",
     "Script",
     "Shell",
     "Shell",
-    "Status",
     "Workflow",
     "Workflow",
+    # Helper functions
     "compute",
     "compute",
     "context",
     "context",
     "outputs",
     "outputs",

+ 3 - 3
src/cipy/action.py

@@ -59,8 +59,8 @@ class Call(Action, extra="allow"):
         self.name = using.name
         self.name = using.name
 
 
     @final
     @final
-    def enabled(self, status: Status, context: Context) -> bool:
-        return self.using.enabled(status, context)
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        return self.using.is_enabled(status, context)
 
 
     @final
     @final
     def run(self, context: Context) -> Status:
     def run(self, context: Context) -> Status:
@@ -159,7 +159,7 @@ class Composite(Action):
                 context.fabricate(step, "inputs")
                 context.fabricate(step, "inputs")
                 self._counter += 1
                 self._counter += 1
 
 
-                if step.enabled(status, context):
+                if step.is_enabled(status, context):
                     stat = step.run(context)
                     stat = step.run(context)
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:
                 else:

+ 11 - 19
src/cipy/common.py

@@ -6,32 +6,21 @@ import logging
 import os
 import os
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
-from enum import Enum, auto
-from functools import reduce
+from functools import reduce, wraps
 from types import SimpleNamespace, NoneType
 from types import SimpleNamespace, NoneType
 from typing import Annotated, Any, Callable, Iterator, Literal, Self, final, overload
 from typing import Annotated, Any, Callable, Iterator, Literal, Self, final, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from pydantic_core import PydanticUndefined
 from pydantic_core import PydanticUndefined
 
 
+from cipy.status import Status
+from cipy.predicate import Predicate, Success, _SENTINEL
+
 type Scalar = bool | int | float | str
 type Scalar = bool | int | float | str
 type Computed = Ref | Factory
 type Computed = Ref | Factory
 type Value = Scalar | Computed
 type Value = Scalar | Computed
 
 
 
 
-class Status(Enum):
-    """Result status of a runner, higher numbers take priority"""
-
-    NOT_RUN = auto()
-    SKIPPED = auto()
-    SUCCESS = auto()
-    FAILURE = auto()
-    CANCELLED = auto()
-
-    def __ior__(self, other: Status) -> Status:
-        return self if self.value > other.value else other
-
-
 class Inputs(BaseModel):
 class Inputs(BaseModel):
     """Stub class describing input arguments"""
     """Stub class describing input arguments"""
 
 
@@ -164,6 +153,7 @@ class Action(BaseModel, abc.ABC):
 
 
     name: str
     name: str
     id: str = Field(default="", exclude_if=lambda v: not v)
     id: str = Field(default="", exclude_if=lambda v: not v)
+    enabled: Predicate = Success()
     inputs: Inputs = Inputs()
     inputs: Inputs = Inputs()
     outputs: Outputs = Outputs()
     outputs: Outputs = Outputs()
 
 
@@ -172,10 +162,12 @@ class Action(BaseModel, abc.ABC):
         """Get this class's logger"""
         """Get this class's logger"""
         return logging.getLogger(self.__class__.__name__)
         return logging.getLogger(self.__class__.__name__)
 
 
-    # pylint: disable=unused-argument
-    def enabled(self, status: Status, context: Context) -> bool:
-        """Should this action even be run?"""
-        return status.value <= Status.SUCCESS.value
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        """Proxy function to test if this Action is permitted to run"""
+        with context.extend(
+            inputs=self.inputs, status=status, logger=self.logger
+        ) as ctx:
+            return self.enabled(ctx) and (hasattr(ctx, _SENTINEL) or Success()(ctx))
 
 
     @abc.abstractmethod
     @abc.abstractmethod
     def run(self, context: Context) -> Status:
     def run(self, context: Context) -> Status:

+ 119 - 0
src/cipy/predicate.py

@@ -0,0 +1,119 @@
+import abc
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Final, Protocol, final
+
+from cipy.status import Status
+
+_SENTINEL: Final[str] = "__visited_status_predicate"
+
+class IContext(Protocol):
+    status: Status
+    inputs: SimpleNamespace
+    logger: logging.Logger
+
+
+@dataclass
+class Predicate(abc.ABC):
+    @abc.abstractmethod
+    def __call__(self, context: IContext) -> bool: ...
+
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([self, other])
+
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([self, other])
+
+    def __not__(self) -> Predicate:
+        return NotPredicate(self)
+
+
+@dataclass
+class AndPredicate(Predicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " && ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return all(p(context) for p in self.preds)
+
+    @final
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([*self.preds, other])
+
+
+@dataclass
+class OrPredicate(Predicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " || ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return any(p(context) for p in self.preds)
+
+    @final
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([*self.preds, other])
+
+
+@dataclass
+class NotPredicate(Predicate):
+    pred: Predicate
+
+    def __repr__(self) -> str:
+        return f"!{self.pred}"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return not self.pred(context)
+
+    @final
+    def __not__(self) -> Predicate:
+        return self.pred
+
+
+class Success(Predicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        context.logger.debug(self)
+        setattr(context, _SENTINEL, True)
+        return context.status.value <= Status.SUCCESS.value
+
+
+class Always(Predicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return True
+
+
+class Cancelled(Predicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return context.status is not Status.CANCELLED
+
+
+class Failure(Predicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return context.status is Status.FAILURE

+ 19 - 0
src/cipy/status.py

@@ -0,0 +1,19 @@
+from enum import Enum, auto
+
+from typing import TypeVar, final
+
+Action = TypeVar("Action")
+Context = TypeVar("Context")
+
+
+class Status(Enum):
+    """Result status of a runner, higher numbers take priority"""
+
+    NOT_RUN = auto()
+    SKIPPED = auto()
+    SUCCESS = auto()
+    FAILURE = auto()
+    CANCELLED = auto()
+
+    def __ior__(self, other: Status) -> Status:
+        return self if self.value > other.value else other

+ 3 - 2
src/cipy/workflow.py

@@ -9,7 +9,8 @@ from pydantic import BaseModel, PrivateAttr
 
 
 import cipy.runner
 import cipy.runner
 
 
-from cipy.common import Action, Context, Results, Scalar, Status, Value
+from cipy.common import Action, Context, Results, Scalar, Value
+from cipy.status import Status
 
 
 
 
 class Job(BaseModel):
 class Job(BaseModel):
@@ -54,7 +55,7 @@ class Workflow(Action):
                 self.logger.info("Running Job: %s", job.id)
                 self.logger.info("Running Job: %s", job.id)
                 visited.add(job.id)
                 visited.add(job.id)
 
 
-                if job.action.enabled(status, context):
+                if job.action.is_enabled(status, context):
                     stat = job.action.run(context)
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
                     outctx.needs[job.id] = Results.Item(
                         stat, job.action.outputs.validated()
                         stat, job.action.outputs.validated()