소스 검색

feat: separate different predicate types, add an engine to help with things and logs

Sam Jaffe 4 주 전
부모
커밋
631d1330b3
5개의 변경된 파일114개의 추가작업 그리고 72개의 파일을 삭제
  1. 2 2
      src/cipy/__init__.py
  2. 42 0
      src/cipy/_predicate_engine.py
  3. 5 6
      src/cipy/action.py
  4. 17 64
      src/cipy/predicate.py
  5. 48 0
      src/cipy/status.py

+ 2 - 2
src/cipy/__init__.py

@@ -12,10 +12,10 @@ import pydantic
 from cipy.action import Action, Call, Composite
 from cipy.action import Action, Call, Composite
 from cipy.common import Inputs, Outputs, Ref
 from cipy.common import Inputs, Outputs, Ref
 from cipy.context import Context, Factory
 from cipy.context import Context, Factory
-from cipy.predicate import Predicate, Success, Always, Cancelled, Failure
+from cipy.predicate import Predicate
 from cipy.script import NodeScript, Script
 from cipy.script import NodeScript, Script
 from cipy.shell import Shell
 from cipy.shell import Shell
-from cipy.status import Status
+from cipy.status import Status, Success, Always, Cancelled, Failure
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 
 from . import settings
 from . import settings

+ 42 - 0
src/cipy/_predicate_engine.py

@@ -0,0 +1,42 @@
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+from cipy.predicate import Predicate, CompoundPredicate
+from cipy.status import Status, StatusPredicate, Success
+
+
+@dataclass
+class PredicateEngine:
+    status: Status
+    context: SimpleNamespace
+    logger: logging.Logger
+
+    _depth: int = -1
+    _has_visited_status: bool = False
+
+    def __call__(self, pred: Predicate) -> bool:
+        if not isinstance(pred, (StatusPredicate, CompoundPredicate)):
+            raise NotImplementedError(type(pred))
+
+        self._depth += 1
+        self.logger.debug("%sif { %r }", '\u2502 ' * self._depth, pred)
+
+        if isinstance(pred, StatusPredicate):
+            self._has_visited_status = True
+            rval = pred(self.status)
+
+        if isinstance(pred, CompoundPredicate):
+            rval = pred(self)
+
+        self.logger.debug("%s\u2500%s\u27A4 %s", "\u2514" if self._depth == 0 else "\u251C",
+                          '\u2500\u2500' * self._depth, rval)
+        self._depth -= 1
+
+        if rval and self._depth == -1 and not self._has_visited_status:
+            self.logger.debug("if success() [because no status predicate was requested]")
+            rval = Success()(self.status)
+            self.logger.debug("\u2514\u2500\u27A4 %s", rval)
+
+        return rval

+ 5 - 6
src/cipy/action.py

@@ -10,8 +10,9 @@ from pydantic import BaseModel, Field, PrivateAttr
 from cipy import settings
 from cipy import settings
 from cipy.common import Inputs, Outputs
 from cipy.common import Inputs, Outputs
 from cipy.context import Context, Results, Value
 from cipy.context import Context, Results, Value
-from cipy.predicate import Predicate, Success, _SENTINEL
-from cipy.status import Status
+from cipy.predicate import Predicate
+from cipy.status import Status, Success
+from cipy._predicate_engine import PredicateEngine
 
 
 
 
 class Action(BaseModel, abc.ABC):
 class Action(BaseModel, abc.ABC):
@@ -30,10 +31,8 @@ class Action(BaseModel, abc.ABC):
 
 
     def is_enabled(self, status: Status, context: Context) -> bool:
     def is_enabled(self, status: Status, context: Context) -> bool:
         """Proxy function to test if this Action is permitted to run"""
         """Proxy function to test if this Action is permitted to run"""
-        with context.extend(
-            inputs=self.inputs, status=status, logger=self.logger
-        ) as ctx:
-            return self.enabled(ctx) and (hasattr(ctx, _SENTINEL) or Success()(ctx))
+        with context.extend(inputs=self.inputs) as ctx:
+            return PredicateEngine(status, context, self.logger)(self.enabled)
 
 
     @abc.abstractmethod
     @abc.abstractmethod
     def run(self, context: Context) -> Status:
     def run(self, context: Context) -> Status:

+ 17 - 64
src/cipy/predicate.py

@@ -3,43 +3,37 @@ import logging
 
 
 from dataclasses import dataclass
 from dataclasses import dataclass
 from types import SimpleNamespace
 from types import SimpleNamespace
-from typing import Final, Protocol, final
-
-from cipy.status import Status
-
-_SENTINEL: Final[str] = "__visited_status_predicate"
-
-class IContext(Protocol):
-    status: Status
-    inputs: SimpleNamespace
-    logger: logging.Logger
+from typing import Callable, Final, Protocol, final
 
 
 
 
 @dataclass
 @dataclass
 class Predicate(abc.ABC):
 class Predicate(abc.ABC):
-    @abc.abstractmethod
-    def __call__(self, context: IContext) -> bool: ...
-
     def __and__(self, other: Predicate) -> Predicate:
     def __and__(self, other: Predicate) -> Predicate:
         return AndPredicate([self, other])
         return AndPredicate([self, other])
 
 
     def __or__(self, other: Predicate) -> Predicate:
     def __or__(self, other: Predicate) -> Predicate:
         return OrPredicate([self, other])
         return OrPredicate([self, other])
 
 
-    def __not__(self) -> Predicate:
+    def __invert__(self) -> Predicate:
         return NotPredicate(self)
         return NotPredicate(self)
 
 
 
 
+class CompoundPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        ...
+
+
 @dataclass
 @dataclass
-class AndPredicate(Predicate):
+class AndPredicate(CompoundPredicate):
     preds: list[Predicate]
     preds: list[Predicate]
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return " && ".join([str(p) for p in self.preds])
         return " && ".join([str(p) for p in self.preds])
 
 
     @final
     @final
-    def __call__(self, context: IContext) -> bool:
-        return all(p(context) for p in self.preds)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return all(callback(p) for p in self.preds)
 
 
     @final
     @final
     def __and__(self, other: Predicate) -> Predicate:
     def __and__(self, other: Predicate) -> Predicate:
@@ -47,15 +41,15 @@ class AndPredicate(Predicate):
 
 
 
 
 @dataclass
 @dataclass
-class OrPredicate(Predicate):
+class OrPredicate(CompoundPredicate):
     preds: list[Predicate]
     preds: list[Predicate]
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return " || ".join([str(p) for p in self.preds])
         return " || ".join([str(p) for p in self.preds])
 
 
     @final
     @final
-    def __call__(self, context: IContext) -> bool:
-        return any(p(context) for p in self.preds)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return any(callback(p) for p in self.preds)
 
 
     @final
     @final
     def __or__(self, other: Predicate) -> Predicate:
     def __or__(self, other: Predicate) -> Predicate:
@@ -63,57 +57,16 @@ class OrPredicate(Predicate):
 
 
 
 
 @dataclass
 @dataclass
-class NotPredicate(Predicate):
+class NotPredicate(CompoundPredicate):
     pred: Predicate
     pred: Predicate
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"!{self.pred}"
         return f"!{self.pred}"
 
 
     @final
     @final
-    def __call__(self, context: IContext) -> bool:
-        return not self.pred(context)
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return not callback(self.pred)
 
 
     @final
     @final
     def __not__(self) -> Predicate:
     def __not__(self) -> Predicate:
         return self.pred
         return self.pred
-
-
-class Success(Predicate):
-    def __repr__(self) -> str:
-        return "success()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        context.logger.debug(self)
-        setattr(context, _SENTINEL, True)
-        return context.status.value <= Status.SUCCESS.value
-
-
-class Always(Predicate):
-    def __repr__(self) -> str:
-        return "always()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return True
-
-
-class Cancelled(Predicate):
-    def __repr__(self) -> str:
-        return "cancelled()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return context.status is not Status.CANCELLED
-
-
-class Failure(Predicate):
-    def __repr__(self) -> str:
-        return "failure()"
-
-    @final
-    def __call__(self, context: IContext) -> bool:
-        setattr(context, _SENTINEL, True)
-        return context.status is Status.FAILURE

+ 48 - 0
src/cipy/status.py

@@ -1,5 +1,11 @@
 """Status enum"""
 """Status enum"""
+
+import abc
+
 from enum import Enum, auto
 from enum import Enum, auto
+from typing import TypeVar, final
+
+from cipy.predicate import Predicate
 
 
 
 
 class Status(Enum):
 class Status(Enum):
@@ -13,3 +19,45 @@ class Status(Enum):
 
 
     def __ior__(self, other: Status) -> Status:
     def __ior__(self, other: Status) -> Status:
         return self if self.value > other.value else other
         return self if self.value > other.value else other
+
+
+class StatusPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, status: Status) -> bool:
+        ...
+
+
+class Success(StatusPredicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status.value <= Status.SUCCESS.value
+
+
+class Always(StatusPredicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return True
+
+
+class Cancelled(StatusPredicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is not Status.CANCELLED
+
+
+class Failure(StatusPredicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.FAILURE