Bladeren bron

feat: implement primitive Predicate interface

Sam Jaffe 4 weken geleden
bovenliggende
commit
c4da2e3d26
6 gewijzigde bestanden met toevoegingen van 171 en 29 verwijderingen
  1. 16 5
      src/cipy/__init__.py
  2. 3 3
      src/cipy/action.py
  3. 11 19
      src/cipy/common.py
  4. 119 0
      src/cipy/predicate.py
  5. 19 0
      src/cipy/status.py
  6. 3 2
      src/cipy/workflow.py

+ 16 - 5
src/cipy/__init__.py

@@ -10,8 +10,10 @@ import sys
 import pydantic
 
 from cipy.action import Call, Composite, NodeScript, Script
-from cipy.common import Context, Factory, Inputs, Outputs, Ref, Status
+from cipy.common import Context, Factory, Inputs, Outputs, Ref
+from cipy.predicate import Predicate, Success, Always, Cancelled, Failure
 from cipy.shell import Shell
+from cipy.status import Status
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 from . import settings
@@ -30,19 +32,28 @@ logging.basicConfig(
 )
 
 __all__ = [
-    "Call",
-    "Composite",
+    # Predicate functions
+    "Predicate",
+    "Always",
+    "Cancelled",
+    "Failure",
+    "Success",
+    # Common objects/stubs
+    "Status",
     "Context",
     "Inputs",
+    "Outputs",
+    # Action composition
+    "Call",
+    "Composite",
     "Job",
     "Matrix",
     "MatrixParams",
     "NodeScript",
-    "Outputs",
     "Script",
     "Shell",
-    "Status",
     "Workflow",
+    # Helper functions
     "compute",
     "context",
     "outputs",

+ 3 - 3
src/cipy/action.py

@@ -59,8 +59,8 @@ class Call(Action, extra="allow"):
         self.name = using.name
 
     @final
-    def enabled(self, status: Status, context: Context) -> bool:
-        return self.using.enabled(status, context)
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        return self.using.is_enabled(status, context)
 
     @final
     def run(self, context: Context) -> Status:
@@ -159,7 +159,7 @@ class Composite(Action):
                 context.fabricate(step, "inputs")
                 self._counter += 1
 
-                if step.enabled(status, context):
+                if step.is_enabled(status, context):
                     stat = step.run(context)
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:

+ 11 - 19
src/cipy/common.py

@@ -6,32 +6,21 @@ import logging
 import os
 
 from contextlib import contextmanager
-from enum import Enum, auto
-from functools import reduce
+from functools import reduce, wraps
 from types import SimpleNamespace, NoneType
 from typing import Annotated, Any, Callable, Iterator, Literal, Self, final, overload
 
 from pydantic import BaseModel, Field
 from pydantic_core import PydanticUndefined
 
+from cipy.status import Status
+from cipy.predicate import Predicate, Success, _SENTINEL
+
 type Scalar = bool | int | float | str
 type Computed = Ref | Factory
 type Value = Scalar | Computed
 
 
-class Status(Enum):
-    """Result status of a runner, higher numbers take priority"""
-
-    NOT_RUN = auto()
-    SKIPPED = auto()
-    SUCCESS = auto()
-    FAILURE = auto()
-    CANCELLED = auto()
-
-    def __ior__(self, other: Status) -> Status:
-        return self if self.value > other.value else other
-
-
 class Inputs(BaseModel):
     """Stub class describing input arguments"""
 
@@ -164,6 +153,7 @@ class Action(BaseModel, abc.ABC):
 
     name: str
     id: str = Field(default="", exclude_if=lambda v: not v)
+    enabled: Predicate = Success()
     inputs: Inputs = Inputs()
     outputs: Outputs = Outputs()
 
@@ -172,10 +162,12 @@ class Action(BaseModel, abc.ABC):
         """Get this class's logger"""
         return logging.getLogger(self.__class__.__name__)
 
-    # pylint: disable=unused-argument
-    def enabled(self, status: Status, context: Context) -> bool:
-        """Should this action even be run?"""
-        return status.value <= Status.SUCCESS.value
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        """Proxy function to test if this Action is permitted to run"""
+        with context.extend(
+            inputs=self.inputs, status=status, logger=self.logger
+        ) as ctx:
+            return self.enabled(ctx) and (hasattr(ctx, _SENTINEL) or Success()(ctx))
 
     @abc.abstractmethod
     def run(self, context: Context) -> Status:

+ 119 - 0
src/cipy/predicate.py

@@ -0,0 +1,119 @@
+import abc
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Final, Protocol, final
+
+from cipy.status import Status
+
+_SENTINEL: Final[str] = "__visited_status_predicate"
+
+class IContext(Protocol):
+    status: Status
+    inputs: SimpleNamespace
+    logger: logging.Logger
+
+
+@dataclass
+class Predicate(abc.ABC):
+    @abc.abstractmethod
+    def __call__(self, context: IContext) -> bool: ...
+
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([self, other])
+
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([self, other])
+
+    def __not__(self) -> Predicate:
+        return NotPredicate(self)
+
+
+@dataclass
+class AndPredicate(Predicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " && ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return all(p(context) for p in self.preds)
+
+    @final
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([*self.preds, other])
+
+
+@dataclass
+class OrPredicate(Predicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " || ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return any(p(context) for p in self.preds)
+
+    @final
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([*self.preds, other])
+
+
+@dataclass
+class NotPredicate(Predicate):
+    pred: Predicate
+
+    def __repr__(self) -> str:
+        return f"!{self.pred}"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        return not self.pred(context)
+
+    @final
+    def __not__(self) -> Predicate:
+        return self.pred
+
+
+class Success(Predicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        context.logger.debug(self)
+        setattr(context, _SENTINEL, True)
+        return context.status.value <= Status.SUCCESS.value
+
+
+class Always(Predicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return True
+
+
+class Cancelled(Predicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return context.status is not Status.CANCELLED
+
+
+class Failure(Predicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, context: IContext) -> bool:
+        setattr(context, _SENTINEL, True)
+        return context.status is Status.FAILURE

+ 19 - 0
src/cipy/status.py

@@ -0,0 +1,19 @@
+from enum import Enum, auto
+
+from typing import TypeVar, final
+
+Action = TypeVar("Action")
+Context = TypeVar("Context")
+
+
+class Status(Enum):
+    """Result status of a runner, higher numbers take priority"""
+
+    NOT_RUN = auto()
+    SKIPPED = auto()
+    SUCCESS = auto()
+    FAILURE = auto()
+    CANCELLED = auto()
+
+    def __ior__(self, other: Status) -> Status:
+        return self if self.value > other.value else other

+ 3 - 2
src/cipy/workflow.py

@@ -9,7 +9,8 @@ from pydantic import BaseModel, PrivateAttr
 
 import cipy.runner
 
-from cipy.common import Action, Context, Results, Scalar, Status, Value
+from cipy.common import Action, Context, Results, Scalar, Value
+from cipy.status import Status
 
 
 class Job(BaseModel):
@@ -54,7 +55,7 @@ class Workflow(Action):
                 self.logger.info("Running Job: %s", job.id)
                 visited.add(job.id)
 
-                if job.action.enabled(status, context):
+                if job.action.is_enabled(status, context):
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
                         stat, job.action.outputs.validated()