|
|
@@ -3,43 +3,37 @@ import logging
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
from types import SimpleNamespace
|
|
|
-from typing import Final, Protocol, final
|
|
|
-
|
|
|
-from cipy.status import Status
|
|
|
-
|
|
|
-_SENTINEL: Final[str] = "__visited_status_predicate"
|
|
|
-
|
|
|
-class IContext(Protocol):
|
|
|
- status: Status
|
|
|
- inputs: SimpleNamespace
|
|
|
- logger: logging.Logger
|
|
|
+from typing import Callable, Final, Protocol, final
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class Predicate(abc.ABC):
|
|
|
- @abc.abstractmethod
|
|
|
- def __call__(self, context: IContext) -> bool: ...
|
|
|
-
|
|
|
def __and__(self, other: Predicate) -> Predicate:
|
|
|
return AndPredicate([self, other])
|
|
|
|
|
|
def __or__(self, other: Predicate) -> Predicate:
|
|
|
return OrPredicate([self, other])
|
|
|
|
|
|
- def __not__(self) -> Predicate:
|
|
|
+ def __invert__(self) -> Predicate:
|
|
|
return NotPredicate(self)
|
|
|
|
|
|
|
|
|
+class CompoundPredicate(Predicate):
|
|
|
+ @abc.abstractmethod
|
|
|
+ def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
@dataclass
|
|
|
-class AndPredicate(Predicate):
|
|
|
+class AndPredicate(CompoundPredicate):
|
|
|
preds: list[Predicate]
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
return " && ".join([str(p) for p in self.preds])
|
|
|
|
|
|
@final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- return all(p(context) for p in self.preds)
|
|
|
+ def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
|
|
|
+ return all(callback(p) for p in self.preds)
|
|
|
|
|
|
@final
|
|
|
def __and__(self, other: Predicate) -> Predicate:
|
|
|
@@ -47,15 +41,15 @@ class AndPredicate(Predicate):
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
-class OrPredicate(Predicate):
|
|
|
+class OrPredicate(CompoundPredicate):
|
|
|
preds: list[Predicate]
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
return " || ".join([str(p) for p in self.preds])
|
|
|
|
|
|
@final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- return any(p(context) for p in self.preds)
|
|
|
+ def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
|
|
|
+ return any(callback(p) for p in self.preds)
|
|
|
|
|
|
@final
|
|
|
def __or__(self, other: Predicate) -> Predicate:
|
|
|
@@ -63,57 +57,16 @@ class OrPredicate(Predicate):
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
-class NotPredicate(Predicate):
|
|
|
+class NotPredicate(CompoundPredicate):
|
|
|
pred: Predicate
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
return f"!{self.pred}"
|
|
|
|
|
|
@final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- return not self.pred(context)
|
|
|
+ def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
|
|
|
+ return not callback(self.pred)
|
|
|
|
|
|
@final
|
|
|
def __not__(self) -> Predicate:
|
|
|
return self.pred
|
|
|
-
|
|
|
-
|
|
|
-class Success(Predicate):
|
|
|
- def __repr__(self) -> str:
|
|
|
- return "success()"
|
|
|
-
|
|
|
- @final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- context.logger.debug(self)
|
|
|
- setattr(context, _SENTINEL, True)
|
|
|
- return context.status.value <= Status.SUCCESS.value
|
|
|
-
|
|
|
-
|
|
|
-class Always(Predicate):
|
|
|
- def __repr__(self) -> str:
|
|
|
- return "always()"
|
|
|
-
|
|
|
- @final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- setattr(context, _SENTINEL, True)
|
|
|
- return True
|
|
|
-
|
|
|
-
|
|
|
-class Cancelled(Predicate):
|
|
|
- def __repr__(self) -> str:
|
|
|
- return "cancelled()"
|
|
|
-
|
|
|
- @final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- setattr(context, _SENTINEL, True)
|
|
|
- return context.status is not Status.CANCELLED
|
|
|
-
|
|
|
-
|
|
|
-class Failure(Predicate):
|
|
|
- def __repr__(self) -> str:
|
|
|
- return "failure()"
|
|
|
-
|
|
|
- @final
|
|
|
- def __call__(self, context: IContext) -> bool:
|
|
|
- setattr(context, _SENTINEL, True)
|
|
|
- return context.status is Status.FAILURE
|