Ver código fonte

Merge branch 'feat/predicate'

Sam Jaffe 3 semanas atrás
pai
commit
c36d20a034

+ 14 - 7
src/cipy/__init__.py

@@ -12,9 +12,10 @@ import pydantic
 from cipy.action import Action, Call, Composite
 from cipy.action import Action, Call, Composite
 from cipy.common import Inputs, Outputs, Ref
 from cipy.common import Inputs, Outputs, Ref
 from cipy.context import Context, Factory
 from cipy.context import Context, Factory
+from cipy.predicate import Predicate
 from cipy.script import NodeScript, Script
 from cipy.script import NodeScript, Script
 from cipy.shell import Shell
 from cipy.shell import Shell
-from cipy.status import Status
+from cipy.status import Status, Success, Always, Cancelled, Failure
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 from cipy.workflow import Job, Matrix, MatrixParams, Workflow
 
 
 from . import settings
 from . import settings
@@ -33,23 +34,29 @@ logging.basicConfig(
 )
 )
 
 
 __all__ = [
 __all__ = [
-    # Basic Datamodels
+    # Predicate functions
+    "Predicate",
+    "Always",
+    "Cancelled",
+    "Failure",
+    "Success",
+    # Common objects/stubs
+    "Status",
+    "Context",
     "Inputs",
     "Inputs",
     "Outputs",
     "Outputs",
-    "Context",
-    "Status",
     # Actions (Linear)
     # Actions (Linear)
     "Action",
     "Action",
     "Call",
     "Call",
     "Composite",
     "Composite",
-    "Matrix",
-    "MatrixParams",
     "NodeScript",
     "NodeScript",
     "Script",
     "Script",
     "Shell",
     "Shell",
     # Workflow (Non-Linear)
     # Workflow (Non-Linear)
-    "Workflow",
     "Job",
     "Job",
+    "Matrix",
+    "MatrixParams",
+    "Workflow",
     # Helpers
     # Helpers
     "compute",
     "compute",
     "context",
     "context",

+ 42 - 0
src/cipy/_predicate_engine.py

@@ -0,0 +1,42 @@
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+
+from cipy.predicate import Predicate, CompoundPredicate
+from cipy.status import Status, StatusPredicate, Success
+
+
+@dataclass
+class PredicateEngine:
+    status: Status
+    context: SimpleNamespace
+    logger: logging.Logger
+
+    _depth: int = -1
+    _has_visited_status: bool = False
+
+    def __call__(self, pred: Predicate) -> bool:
+        if not isinstance(pred, (StatusPredicate, CompoundPredicate)):
+            raise NotImplementedError(type(pred))
+
+        self._depth += 1
+        self.logger.debug("%sif { %r }", '\u2502 ' * self._depth, pred)
+
+        if isinstance(pred, StatusPredicate):
+            self._has_visited_status = True
+            rval = pred(self.status)
+
+        if isinstance(pred, CompoundPredicate):
+            rval = pred(self)
+
+        self.logger.debug("%s\u2500%s\u27A4 %s", "\u2514" if self._depth == 0 else "\u251C",
+                          '\u2500\u2500' * self._depth, rval)
+        self._depth -= 1
+
+        if rval and self._depth == -1 and not self._has_visited_status:
+            self.logger.debug("if success() [because no status predicate was requested]")
+            rval = Success()(self.status)
+            self.logger.debug("\u2514\u2500\u27A4 %s", rval)
+
+        return rval

+ 13 - 9
src/cipy/action.py

@@ -10,7 +10,9 @@ from pydantic import BaseModel, Field, PrivateAttr
 from cipy import settings
 from cipy import settings
 from cipy.common import Inputs, Outputs
 from cipy.common import Inputs, Outputs
 from cipy.context import Context, Results, Value
 from cipy.context import Context, Results, Value
-from cipy.status import Status
+from cipy.predicate import Predicate
+from cipy.status import Status, Success
+from cipy._predicate_engine import PredicateEngine
 
 
 
 
 class Action(BaseModel, abc.ABC):
 class Action(BaseModel, abc.ABC):
@@ -18,6 +20,7 @@ class Action(BaseModel, abc.ABC):
 
 
     name: str
     name: str
     id: str = Field(default="", exclude_if=lambda v: not v)
     id: str = Field(default="", exclude_if=lambda v: not v)
+    enabled: Predicate = Success()
     inputs: Inputs = Inputs()
     inputs: Inputs = Inputs()
     outputs: Outputs = Outputs()
     outputs: Outputs = Outputs()
 
 
@@ -26,10 +29,10 @@ class Action(BaseModel, abc.ABC):
         """Get this class's logger"""
         """Get this class's logger"""
         return logging.getLogger(self.__class__.__name__)
         return logging.getLogger(self.__class__.__name__)
 
 
-    # pylint: disable=unused-argument
-    def enabled(self, status: Status, context: Context) -> bool:
-        """Should this action even be run?"""
-        return status.value <= Status.SUCCESS.value
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        """Proxy function to test if this Action is permitted to run"""
+        with context.extend(inputs=self.inputs) as ctx:
+            return PredicateEngine(status, context, self.logger)(self.enabled)
 
 
     @abc.abstractmethod
     @abc.abstractmethod
     def run(self, context: Context) -> Status:
     def run(self, context: Context) -> Status:
@@ -80,8 +83,8 @@ class Call(Action, extra="allow"):
         self.name = using.name
         self.name = using.name
 
 
     @final
     @final
-    def enabled(self, status: Status, context: Context) -> bool:
-        return self.using.enabled(status, context)
+    def is_enabled(self, status: Status, context: Context) -> bool:
+        return self.using.is_enabled(status, context)
 
 
     @final
     @final
     def run(self, context: Context) -> Status:
     def run(self, context: Context) -> Status:
@@ -120,7 +123,7 @@ class Composite(Action):
                 context.fabricate(step, "inputs")
                 context.fabricate(step, "inputs")
                 self._counter += 1
                 self._counter += 1
 
 
-                if step.enabled(status, context):
+                if step.is_enabled(status, context):
                     stat = step.run(context)
                     stat = step.run(context)
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                     outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:
                 else:
@@ -129,7 +132,8 @@ class Composite(Action):
 
 
                 status |= stat
                 status |= stat
 
 
-            outctx.fabricate(self, "outputs")
+            if status is Status.SUCCESS:
+                outctx.fabricate(self, "outputs")
 
 
         return status
         return status
 
 

+ 3 - 0
src/cipy/common.py

@@ -29,3 +29,6 @@ class Ref:
         self.path = pathstr.split(".")
         self.path = pathstr.split(".")
         if not self.path:
         if not self.path:
             raise ValueError("References must be of the form A.B.C etc.")
             raise ValueError("References must be of the form A.B.C etc.")
+
+    def __repr__(self) -> str:
+        return "Ref({})".format(".".join(self.path))

+ 52 - 8
src/cipy/context.py

@@ -1,13 +1,15 @@
 """Classes for managing the context of a CI run"""
 """Classes for managing the context of a CI run"""
+
 import os
 import os
 
 
 from contextlib import contextmanager
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
-from functools import reduce
-from types import SimpleNamespace
-from typing import Any, Callable, Iterator, Literal, overload
+from functools import partial, reduce
+from types import NoneType, SimpleNamespace, UnionType
+from typing import Any, Callable, Iterator, Literal, Protocol, get_args, overload
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
+from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefined
 from pydantic_core import PydanticUndefined
 
 
 from cipy.common import Inputs, Outputs, Ref
 from cipy.common import Inputs, Outputs, Ref
@@ -48,6 +50,29 @@ class Results(SimpleNamespace):
         return self.__getattribute__(subscript)
         return self.__getattribute__(subscript)
 
 
 
 
+class _Stub:
+    pass
+
+
+def _chain_attrs(context: Any, ref: Ref) -> Any:
+    try:
+        for token in ref.path:
+            if (
+                isinstance(context, Results.Item)
+                and context.conclusion is not Status.SUCCESS
+            ):
+                return _Stub()
+
+            if isinstance(context, dict):
+                context = context[token]
+            else:
+                context = getattr(context, token)
+        return context
+    except KeyError, AttributeError:
+        reason = "NULL object" if context is None else "not found"
+        raise AttributeError(f'unable to find {ref} item "{token}": {reason}')
+
+
 class Context(SimpleNamespace):
 class Context(SimpleNamespace):
     """Wrapper class for the context of the CI runtime"""
     """Wrapper class for the context of the CI runtime"""
 
 
@@ -66,9 +91,7 @@ class Context(SimpleNamespace):
             assert len(arg.path) == 2
             assert len(arg.path) == 2
             return os.environ.get(arg.path[1])
             return os.environ.get(arg.path[1])
 
 
-        return reduce(  # type: ignore[return-value]
-            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), arg.path, self
-        )
+        return _chain_attrs(self, arg)
 
 
     @overload
     @overload
     def fabricate(
     def fabricate(
@@ -105,15 +128,36 @@ class Context(SimpleNamespace):
             fields = vars(model)
             fields = vars(model)
 
 
         for name, fld in annotation.__pydantic_fields__.items():
         for name, fld in annotation.__pydantic_fields__.items():
+            coerce = partial(self.__coerce, state, name, fld)
             if name in extra:
             if name in extra:
-                fields[name] = self(extra[name])
+                fields[name] = coerce(extra[name])
             elif fld.default is not PydanticUndefined:
             elif fld.default is not PydanticUndefined:
-                fields[name] = self(fld.default)
+                fields[name] = coerce(fld.default)
 
 
         model = annotation(**fields)
         model = annotation(**fields)
         setattr(state, attr, model)
         setattr(state, attr, model)
         return model
         return model
 
 
+    def __coerce(
+        self, state: BaseModel, name: str, fld: FieldInfo, arg: Value | None
+    ) -> Any:
+        value: Scalar | None = self(arg)
+        if value is None or not isinstance(value, _Stub):
+            return value
+
+        anno = fld.annotation
+
+        assert anno is not None
+        if isinstance(anno, UnionType):
+            anno = next(iter(t for t in get_args(anno) if t is not NoneType))
+
+        assert hasattr(state, "logger")
+        state.logger.warning(
+            'binding %s to "%s" failed: action was not successful', arg, name
+        )
+        state.logger.debug("coercing to %s", anno.__name__)
+        return anno()
+
     @contextmanager
     @contextmanager
     def extend(self, **kwargs: Any) -> Iterator[Context]:
     def extend(self, **kwargs: Any) -> Iterator[Context]:
         """Create a new context that inherits an extra property"""
         """Create a new context that inherits an extra property"""

+ 72 - 0
src/cipy/predicate.py

@@ -0,0 +1,72 @@
+import abc
+import logging
+
+from dataclasses import dataclass
+from types import SimpleNamespace
+from typing import Callable, Final, Protocol, final
+
+
+@dataclass
+class Predicate(abc.ABC):
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([self, other])
+
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([self, other])
+
+    def __invert__(self) -> Predicate:
+        return NotPredicate(self)
+
+
+class CompoundPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        ...
+
+
+@dataclass
+class AndPredicate(CompoundPredicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " && ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return all(callback(p) for p in self.preds)
+
+    @final
+    def __and__(self, other: Predicate) -> Predicate:
+        return AndPredicate([*self.preds, other])
+
+
+@dataclass
+class OrPredicate(CompoundPredicate):
+    preds: list[Predicate]
+
+    def __repr__(self) -> str:
+        return " || ".join([str(p) for p in self.preds])
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return any(callback(p) for p in self.preds)
+
+    @final
+    def __or__(self, other: Predicate) -> Predicate:
+        return OrPredicate([*self.preds, other])
+
+
+@dataclass
+class NotPredicate(CompoundPredicate):
+    pred: Predicate
+
+    def __repr__(self) -> str:
+        return f"!{self.pred}"
+
+    @final
+    def __call__(self, callback: Callable[[Predicate], bool]) -> bool:
+        return not callback(self.pred)
+
+    @final
+    def __not__(self) -> Predicate:
+        return self.pred

+ 48 - 0
src/cipy/status.py

@@ -1,5 +1,11 @@
 """Status enum"""
 """Status enum"""
+
+import abc
+
 from enum import Enum, auto
 from enum import Enum, auto
+from typing import TypeVar, final
+
+from cipy.predicate import Predicate
 
 
 
 
 class Status(Enum):
 class Status(Enum):
@@ -13,3 +19,45 @@ class Status(Enum):
 
 
     def __ior__(self, other: Status) -> Status:
     def __ior__(self, other: Status) -> Status:
         return self if self.value > other.value else other
         return self if self.value > other.value else other
+
+
+class StatusPredicate(Predicate):
+    @abc.abstractmethod
+    def __call__(self, status: Status) -> bool:
+        ...
+
+
+class Success(StatusPredicate):
+    def __repr__(self) -> str:
+        return "success()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status.value <= Status.SUCCESS.value
+
+
+class Always(StatusPredicate):
+    def __repr__(self) -> str:
+        return "always()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return True
+
+
+class Cancelled(StatusPredicate):
+    def __repr__(self) -> str:
+        return "cancelled()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.CANCELLED
+
+
+class Failure(StatusPredicate):
+    def __repr__(self) -> str:
+        return "failure()"
+
+    @final
+    def __call__(self, status: Status) -> bool:
+        return status is Status.FAILURE

+ 3 - 2
src/cipy/workflow.py

@@ -56,7 +56,7 @@ class Workflow(Action):
                 self.logger.info("Running Job: %s", job.id)
                 self.logger.info("Running Job: %s", job.id)
                 visited.add(job.id)
                 visited.add(job.id)
 
 
-                if job.action.enabled(status, context):
+                if job.action.is_enabled(status, context):
                     stat = job.action.run(context)
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
                     outctx.needs[job.id] = Results.Item(
                         stat, job.action.outputs.validated()
                         stat, job.action.outputs.validated()
@@ -68,7 +68,8 @@ class Workflow(Action):
                 status |= stat
                 status |= stat
                 self._finished(job.id)
                 self._finished(job.id)
 
 
-            outctx.fabricate(self, "outputs")
+            if status is Status.SUCCESS:
+                outctx.fabricate(self, "outputs")
 
 
         return status
         return status