|
|
@@ -0,0 +1,119 @@
|
|
|
+import abc
|
|
|
+import logging
|
|
|
+
|
|
|
+from dataclasses import dataclass
|
|
|
+from types import SimpleNamespace
|
|
|
+from typing import Final, Protocol, final
|
|
|
+
|
|
|
+from cipy.status import Status
|
|
|
+
|
|
|
+_SENTINEL: Final[str] = "__visited_status_predicate"
|
|
|
+
|
|
|
+class IContext(Protocol):
|
|
|
+ status: Status
|
|
|
+ inputs: SimpleNamespace
|
|
|
+ logger: logging.Logger
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class Predicate(abc.ABC):
|
|
|
+ @abc.abstractmethod
|
|
|
+ def __call__(self, context: IContext) -> bool: ...
|
|
|
+
|
|
|
+ def __and__(self, other: Predicate) -> Predicate:
|
|
|
+ return AndPredicate([self, other])
|
|
|
+
|
|
|
+ def __or__(self, other: Predicate) -> Predicate:
|
|
|
+ return OrPredicate([self, other])
|
|
|
+
|
|
|
+ def __not__(self) -> Predicate:
|
|
|
+ return NotPredicate(self)
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class AndPredicate(Predicate):
|
|
|
+ preds: list[Predicate]
|
|
|
+
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return " && ".join([str(p) for p in self.preds])
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ return all(p(context) for p in self.preds)
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __and__(self, other: Predicate) -> Predicate:
|
|
|
+ return AndPredicate([*self.preds, other])
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class OrPredicate(Predicate):
|
|
|
+ preds: list[Predicate]
|
|
|
+
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return " || ".join([str(p) for p in self.preds])
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ return any(p(context) for p in self.preds)
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __or__(self, other: Predicate) -> Predicate:
|
|
|
+ return OrPredicate([*self.preds, other])
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class NotPredicate(Predicate):
|
|
|
+ pred: Predicate
|
|
|
+
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return f"!{self.pred}"
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ return not self.pred(context)
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __not__(self) -> Predicate:
|
|
|
+ return self.pred
|
|
|
+
|
|
|
+
|
|
|
+class Success(Predicate):
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return "success()"
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ context.logger.debug(self)
|
|
|
+ setattr(context, _SENTINEL, True)
|
|
|
+ return context.status.value <= Status.SUCCESS.value
|
|
|
+
|
|
|
+
|
|
|
+class Always(Predicate):
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return "always()"
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ setattr(context, _SENTINEL, True)
|
|
|
+ return True
|
|
|
+
|
|
|
+
|
|
|
+class Cancelled(Predicate):
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return "cancelled()"
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ setattr(context, _SENTINEL, True)
|
|
|
+ return context.status is not Status.CANCELLED
|
|
|
+
|
|
|
+
|
|
|
+class Failure(Predicate):
|
|
|
+ def __repr__(self) -> str:
|
|
|
+ return "failure()"
|
|
|
+
|
|
|
+ @final
|
|
|
+ def __call__(self, context: IContext) -> bool:
|
|
|
+ setattr(context, _SENTINEL, True)
|
|
|
+ return context.status is Status.FAILURE
|