Pārlūkot izejas kodu

Merge branch 'feat/predicate'

Sam Jaffe 3 nedēļas atpakaļ
vecāks
revīzija
c36d20a034

+ 14 - 7
src/cipy/__init__.py

@@ -12,9 +12,10 @@ import pydantic
 from cipy.action import Action, Call, Composite
 from cipy.common import Inputs, Outputs, Ref
 from cipy.context import Context, Factory
+from cipy.predicate import Predicate
 from cipy.script import NodeScript, Script
 from cipy.shell import Shell
-from cipy.status import Status
+from cipy.status import Status, Success, Always, Cancelled, Failure
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 from . import settings
@@ -33,23 +34,29 @@ logging.basicConfig(
 )
 
 __all__ = [
-    # Basic Datamodels
+    # Predicate functions
+    "Predicate",
+    "Always",
+    "Cancelled",
+    "Failure",
+    "Success",
+    # Common objects/stubs
+    "Status",
+    "Context",
     "Inputs",
     "Outputs",
-    "Context",
-    "Status",
     # Actions (Linear)
     "Action",
     "Call",
     "Composite",
-    "Matrix",
-    "MatrixParams",
     "NodeScript",
     "Script",
     "Shell",
     # Workflow (Non-Linear)
-    "Workflow",
     "Job",
+    "Matrix",
+    "MatrixParams",
+    "Workflow",
     # Helpers
     "compute",
     "context",

+ 42 - 0
src/cipy/_predicate_engine.py

@@ -0,0 +1,42 @@
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+from cipy.predicate import Predicate, CompoundPredicate
+from cipy.status import Status, StatusPredicate, Success
+
+
+@dataclass
+class PredicateEngine:
+    status: Status
+    context: SimpleNamespace
+    logger: logging.Logger
+
+    _depth: int = -1
+    _has_visited_status: bool = False
+
+    def __call__(self, pred: Predicate) -> bool:
+        if not isinstance(pred, (StatusPredicate, CompoundPredicate)):
+            raise NotImplementedError(type(pred))
+
+        self._depth += 1
+        self.logger.debug("%sif { %r }", '\u2502 ' * self._depth, pred)
+
+        if isinstance(pred, StatusPredicate):
+            self._has_visited_status = True
+            rval = pred(self.status)
+
+        if isinstance(pred, CompoundPredicate):
+            rval = pred(self)
+
+        self.logger.debug("%s\u2500%s\u27A4 %s", "\u2514" if self._depth == 0 else "\u251C",
+                          '\u2500\u2500' * self._depth, rval)
+        self._depth -= 1
+
+        if rval and self._depth == -1 and not self._has_visited_status:
+            self.logger.debug("if success() [because no status predicate was requested]")
+            rval = Success()(self.status)
+            self.logger.debug("\u2514\u2500\u27A4 %s", rval)
+
+        return rval

+ 13 - 9
src/cipy/action.py

@@ -10,7 +10,9 @@ from pydantic import BaseModel, Field, PrivateAttr
 from cipy import settings
 from cipy.common import Inputs, Outputs
 from cipy.context import Context, Results, Value
-from cipy.status import Status
+from cipy.predicate import Predicate
+from cipy.status import Status, Success
+from cipy._predicate_engine import PredicateEngine
 
 
 class Action(BaseModel, abc.ABC):
@@ -18,6 +20,7 @@ class Action(BaseModel, abc.ABC):
 
     name: str
     id: str = Field(default="", exclude_if=lambda v: not v)
+    enabled: Predicate = Success()
     inputs: Inputs = Inputs()
     outputs: Outputs = Outputs()
 
@@ -26,10 +29,10 @@ class Action(BaseModel, abc.ABC):
         """Get this class's logger"""
         return logging.getLogger(self.__class__.__name__)
 
-    # pylint: disable=unused-argument
-    def enabled(self, status: Status, context: Context) -> bool:
-        """Should this action even be run?"""
-        return status.value <= Status.SUCCESS.value
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        """Proxy function to test if this Action is permitted to run"""
+        with context.extend(inputs=self.inputs) as ctx:
+            return PredicateEngine(status, context, self.logger)(self.enabled)
 
     @abc.abstractmethod
     def run(self, context: Context) -> Status:
@@ -80,8 +83,8 @@ class Call(Action, extra="allow"):
         self.name = using.name
 
     @final
-    def enabled(self, status: Status, context: Context) -> bool:
-        return self.using.enabled(status, context)
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        return self.using.is_enabled(status, context)
 
     @final
     def run(self, context: Context) -> Status:
@@ -120,7 +123,7 @@ class Composite(Action):
                 context.fabricate(step, "inputs")
                 self._counter += 1
 
-                if step.enabled(status, context):
+                if step.is_enabled(status, context):
                     stat = step.run(context)
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:
@@ -129,7 +132,8 @@ class Composite(Action):
 
                 status |= stat
 
-            outctx.fabricate(self, "outputs")
+            if status is Status.SUCCESS:
+                outctx.fabricate(self, "outputs")
 
         return status
 

+ 3 - 0
src/cipy/common.py

@@ -29,3 +29,6 @@ class Ref:
         self.path = pathstr.split(".")
         if not self.path:
             raise ValueError("References must be of the form A.B.C etc.")
+
+    def __repr__(self) -> str:
+        return "Ref({})".format(".".join(self.path))

+ 52 - 8
src/cipy/context.py

@@ -1,13 +1,15 @@
 """Classes for managing the context of a CI run"""
+
 import os
 
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from functools import reduce
-from types import SimpleNamespace
-from typing import Any, Callable, Iterator, Literal, overload
+from functools import partial, reduce
+from types import NoneType, SimpleNamespace, UnionType
+from typing import Any, Callable, Iterator, Literal, Protocol, get_args, overload
 
 from pydantic import BaseModel
+from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefined
 
 from cipy.common import Inputs, Outputs, Ref
@@ -48,6 +50,29 @@ class Results(SimpleNamespace):
         return self.__getattribute__(subscript)
 
 
+class _Stub:
+    pass
+
+
+def _chain_attrs(context: Any, ref: Ref) -> Any:
+    try:
+        for token in ref.path:
+            if (
+                isinstance(context, Results.Item)
+                and context.conclusion is not Status.SUCCESS
+            ):
+                return _Stub()
+
+            if isinstance(context, dict):
+                context = context[token]
+            else:
+                context = getattr(context, token)
+        return context
+    except KeyError, AttributeError:
+        reason = "NULL object" if context is None else "not found"
+        raise AttributeError(f'unable to find {ref} item "{token}": {reason}')
+
+
 class Context(SimpleNamespace):
     """Wrapper class for the context of the CI runtime"""
 
@@ -66,9 +91,7 @@ class Context(SimpleNamespace):
             assert len(arg.path) == 2
             return os.environ.get(arg.path[1])
 
-        return reduce(  # type: ignore[return-value]
-            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), arg.path, self
-        )
+        return _chain_attrs(self, arg)
 
     @overload
     def fabricate(
@@ -105,15 +128,36 @@ class Context(SimpleNamespace):
             fields = vars(model)
 
         for name, fld in annotation.__pydantic_fields__.items():
+            coerce = partial(self.__coerce, state, name, fld)
             if name in extra:
-                fields[name] = self(extra[name])
+                fields[name] = coerce(extra[name])
             elif fld.default is not PydanticUndefined:
-                fields[name] = self(fld.default)
+                fields[name] = coerce(fld.default)
 
         model = annotation(**fields)
         setattr(state, attr, model)
         return model
 
+    def __coerce(
+        self, state: BaseModel, name: str, fld: FieldInfo, arg: Value | None
+    ) -> Any:
+        value: Scalar | None = self(arg)
+        if value is None or not isinstance(value, _Stub):
+            return value
+
+        anno = fld.annotation
+
+        assert anno is not None
+        if isinstance(anno, UnionType):
+            anno = next(iter(t for t in get_args(anno) if t is not NoneType))
+
+        assert hasattr(state, "logger")
+        state.logger.warning(
+            'binding %s to "%s" failed: action was not successful', arg, name
+        )
+        state.logger.debug("coercing to %s", anno.__name__)
+        return anno()
+
     @contextmanager
     def extend(self, **kwargs: Any) -> Iterator[Context]:
         """Create a new context that inherits an extra property"""

+ 72 - 0
src/cipy/predicate.py

@@ -0,0 +1,72 @@
+import abc
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Callable, Final, Protocol, final
+
+
+@dataclass
+class Predicate(abc.ABC):
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([self, other])
+
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([self, other])
+
+    def __invert__(self) -> Predicate:
+        return NotPredicate(self)
+
+
+class CompoundPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        ...
+
+
+@dataclass
+class AndPredicate(CompoundPredicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " && ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return all(callback(p) for p in self.preds)
+
+    @final
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([*self.preds, other])
+
+
+@dataclass
+class OrPredicate(CompoundPredicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " || ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return any(callback(p) for p in self.preds)
+
+    @final
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([*self.preds, other])
+
+
+@dataclass
+class NotPredicate(CompoundPredicate):
+    pred: Predicate
+
+    def __repr__(self) -> str:
+        return f"!{self.pred}"
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return not callback(self.pred)
+
+    @final
+    def __not__(self) -> Predicate:
+        return self.pred

+ 48 - 0
src/cipy/status.py

@@ -1,5 +1,11 @@
 """Status enum"""
+
+import abc
+
 from enum import Enum, auto
+from typing import TypeVar, final
+
+from cipy.predicate import Predicate
 
 
 class Status(Enum):
@@ -13,3 +19,45 @@ class Status(Enum):
 
     def __ior__(self, other: Status) -> Status:
         return self if self.value > other.value else other
+
+
+class StatusPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, status: Status) -> bool:
+        ...
+
+
+class Success(StatusPredicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status.value <= Status.SUCCESS.value
+
+
+class Always(StatusPredicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return True
+
+
+class Cancelled(StatusPredicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.CANCELLED
+
+
+class Failure(StatusPredicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.FAILURE

+ 3 - 2
src/cipy/workflow.py

@@ -56,7 +56,7 @@ class Workflow(Action):
                 self.logger.info("Running Job: %s", job.id)
                 visited.add(job.id)
 
-                if job.action.enabled(status, context):
+                if job.action.is_enabled(status, context):
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
                         stat, job.action.outputs.validated()
@@ -68,7 +68,8 @@ class Workflow(Action):
                 status |= stat
                 self._finished(job.id)
 
-            outctx.fabricate(self, "outputs")
+            if status is Status.SUCCESS:
+                outctx.fabricate(self, "outputs")
 
         return status