Sfoglia il codice sorgente

feat: separate different predicate types, add an engine to help with things and logs

Sam Jaffe 4 settimane fa
parent
commit
631d1330b3
5 ha cambiato i file con 114 aggiunte e 72 eliminazioni
  1. 2 2
      src/cipy/__init__.py
  2. 42 0
      src/cipy/_predicate_engine.py
  3. 5 6
      src/cipy/action.py
  4. 17 64
      src/cipy/predicate.py
  5. 48 0
      src/cipy/status.py

+ 2 - 2
src/cipy/__init__.py

@@ -12,10 +12,10 @@ import pydantic
 from cipy.action import Action, Call, Composite
 from cipy.common import Inputs, Outputs, Ref
 from cipy.context import Context, Factory
-from cipy.predicate import Predicate, Success, Always, Cancelled, Failure
+from cipy.predicate import Predicate
 from cipy.script import NodeScript, Script
 from cipy.shell import Shell
-from cipy.status import Status
+from cipy.status import Status, Success, Always, Cancelled, Failure
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 from . import settings

+ 42 - 0
src/cipy/_predicate_engine.py

@@ -0,0 +1,42 @@
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+from cipy.predicate import Predicate, CompoundPredicate
+from cipy.status import Status, StatusPredicate, Success
+
+
+@dataclass
+class PredicateEngine:
+    status: Status
+    context: SimpleNamespace
+    logger: logging.Logger
+
+    _depth: int = -1
+    _has_visited_status: bool = False
+
+    def __call__(self, pred: Predicate) -> bool:
+        if not isinstance(pred, (StatusPredicate, CompoundPredicate)):
+            raise NotImplementedError(type(pred))
+
+        self._depth += 1
+        self.logger.debug("%sif { %r }", '\u2502 ' * self._depth, pred)
+
+        if isinstance(pred, StatusPredicate):
+            self._has_visited_status = True
+            rval = pred(self.status)
+
+        if isinstance(pred, CompoundPredicate):
+            rval = pred(self)
+
+        self.logger.debug("%s\u2500%s\u27A4 %s", "\u2514" if self._depth == 0 else "\u251C",
+                          '\u2500\u2500' * self._depth, rval)
+        self._depth -= 1
+
+        if rval and self._depth == -1 and not self._has_visited_status:
+            self.logger.debug("if success() [because no status predicate was requested]")
+            rval = Success()(self.status)
+            self.logger.debug("\u2514\u2500\u27A4 %s", rval)
+
+        return rval

+ 5 - 6
src/cipy/action.py

@@ -10,8 +10,9 @@ from pydantic import BaseModel, Field, PrivateAttr
 from cipy import settings
 from cipy.common import Inputs, Outputs
 from cipy.context import Context, Results, Value
-from cipy.predicate import Predicate, Success, _SENTINEL
-from cipy.status import Status
+from cipy.predicate import Predicate
+from cipy.status import Status, Success
+from cipy._predicate_engine import PredicateEngine
 
 
 class Action(BaseModel, abc.ABC):
@@ -30,10 +31,8 @@ class Action(BaseModel, abc.ABC):
 
     def is_enabled(self, status: Status, context: Context) -> bool:
         """Proxy function to test if this Action is permitted to run"""
-        with context.extend(
-            inputs=self.inputs, status=status, logger=self.logger
-        ) as ctx:
-            return self.enabled(ctx) and (hasattr(ctx, _SENTINEL) or Success()(ctx))
+        with context.extend(inputs=self.inputs) as ctx:
+            return PredicateEngine(status, context, self.logger)(self.enabled)
 
     @abc.abstractmethod
     def run(self, context: Context) -> Status:

+ 17 - 64
src/cipy/predicate.py

@@ -3,43 +3,37 @@ import logging
 
 from dataclasses import dataclass
 from types import SimpleNamespace
-from typing import Final, Protocol, final
-
-from cipy.status import Status
-
-_SENTINEL: Final[str] = "__visited_status_predicate"
-
-class IContext(Protocol):
-    status: Status
-    inputs: SimpleNamespace
-    logger: logging.Logger
+from typing import Callable, Final, Protocol, final
 
 
 @dataclass
 class Predicate(abc.ABC):
-    @abc.abstractmethod
-    def __call__(self, context: IContext) -> bool: ...
-
     def __and__(self, other: Predicate) -> Predicate:
         return AndPredicate([self, other])
 
     def __or__(self, other: Predicate) -> Predicate:
         return OrPredicate([self, other])
 
-    def __not__(self) -> Predicate:
+    def __invert__(self) -> Predicate:
         return NotPredicate(self)
 
 
+class CompoundPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        ...
+
+
 @dataclass
-class AndPredicate(Predicate):
+class AndPredicate(CompoundPredicate):
     preds: list[Predicate]
 
     def __repr__(self) -> str:
         return " && ".join([str(p) for p in self.preds])
 
     @final
-    def __call__(self, context: IContext) -> bool:
-        return all(p(context) for p in self.preds)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return all(callback(p) for p in self.preds)
 
     @final
     def __and__(self, other: Predicate) -> Predicate:
@@ -47,15 +41,15 @@ class AndPredicate(Predicate):
 
 
 @dataclass
-class OrPredicate(Predicate):
+class OrPredicate(CompoundPredicate):
     preds: list[Predicate]
 
     def __repr__(self) -> str:
         return " || ".join([str(p) for p in self.preds])
 
     @final
-    def __call__(self, context: IContext) -> bool:
-        return any(p(context) for p in self.preds)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return any(callback(p) for p in self.preds)
 
     @final
     def __or__(self, other: Predicate) -> Predicate:
@@ -63,57 +57,16 @@ class OrPredicate(Predicate):
 
 
 @dataclass
-class NotPredicate(Predicate):
+class NotPredicate(CompoundPredicate):
     pred: Predicate
 
     def __repr__(self) -> str:
         return f"!{self.pred}"
 
     @final
-    def __call__(self, context: IContext) -> bool:
-        return not self.pred(context)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return not callback(self.pred)
 
     @final
     def __not__(self) -> Predicate:
         return self.pred
-
-
-class Success(Predicate):
-    def __repr__(self) -> str:
-        return "success()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        context.logger.debug(self)
-        setattr(context, _SENTINEL, True)
-        return context.status.value <= Status.SUCCESS.value
-
-
-class Always(Predicate):
-    def __repr__(self) -> str:
-        return "always()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return True
-
-
-class Cancelled(Predicate):
-    def __repr__(self) -> str:
-        return "cancelled()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return context.status is not Status.CANCELLED
-
-
-class Failure(Predicate):
-    def __repr__(self) -> str:
-        return "failure()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return context.status is Status.FAILURE

+ 48 - 0
src/cipy/status.py

@@ -1,5 +1,11 @@
 """Status enum"""
+
+import abc
+
 from enum import Enum, auto
+from typing import TypeVar, final
+
+from cipy.predicate import Predicate
 
 
 class Status(Enum):
@@ -13,3 +19,45 @@ class Status(Enum):
 
     def __ior__(self, other: Status) -> Status:
         return self if self.value > other.value else other
+
+
+class StatusPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, status: Status) -> bool:
+        ...
+
+
+class Success(StatusPredicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status.value <= Status.SUCCESS.value
+
+
+class Always(StatusPredicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return True
+
+
+class Cancelled(StatusPredicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is not Status.CANCELLED
+
+
+class Failure(StatusPredicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.FAILURE