Ver Fonte

refactor: replace _validate() free function with Outputs.validated()

Sam Jaffe há 1 mês atrás
pai
commit
c79861c08f
4 ficheiros alterados com 11 adições e 24 exclusões
  1. 2 2
      src/cipy/action.py
  2. 6 18
      src/cipy/common.py
  3. 1 2
      src/cipy/runner.py
  4. 2 2
      src/cipy/workflow.py

+ 2 - 2
src/cipy/action.py

@@ -12,7 +12,7 @@ from pydantic import Field, PrivateAttr
 
 import cipy.runner
 from cipy import settings
-from cipy.common import Action, Context, Factory, Ref, Results, Status, _validate
+from cipy.common import Action, Context, Factory, Ref, Results, Status
 from cipy.shell import Shell
 
 from . import _io
@@ -161,7 +161,7 @@ class Composite(Action):
 
                 if step.enabled(status, context):
                     stat = step.run(context)
-                    outctx.steps[step.id] = Results.Item(stat, _validate(step.outputs))
+                    outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:
                     stat = Status.SKIPPED
                     outctx.steps[step.id] = Results.Item(Status.SKIPPED)

+ 6 - 18
src/cipy/common.py

@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from functools import reduce
 from types import SimpleNamespace, NoneType
-from typing import Annotated, Any, Callable, Iterator, Literal, overload
+from typing import Annotated, Any, Callable, Iterator, Literal, Self, final, overload
 
 from pydantic import BaseModel, Field
 from pydantic_core import PydanticUndefined
@@ -39,6 +39,11 @@ class Inputs(BaseModel):
 class Outputs(BaseModel):
     """Stub class describing result data"""
 
+    @final
+    def validated(self) -> Self:
+        """Validate this output object, affirm that it is properly constructed"""
+        return self.model_validate(self, extra="forbid")
+
 
 @dataclasses.dataclass
 class Ref:
@@ -144,7 +149,6 @@ class Context(SimpleNamespace):
                 fields[name] = self(fld.default)
 
         model = annotation(**fields)
-        _validate(model)
         setattr(state, attr, model)
         return model
 
@@ -181,19 +185,3 @@ class Action(BaseModel, abc.ABC):
 
     def cleanup(self, context: Context) -> None:
         """Optionally clean up after ourselves"""
-
-
-def _validate(model: BaseModel):
-    """Perform the actual model validation that we sabotaged w/ required() and similar functions"""
-    for k, fld in model.__pydantic_fields__.items():
-        attr = getattr(model, k)
-        if fld.annotation is None or isinstance(attr, fld.annotation):
-            continue
-
-        if isinstance(attr, (Ref, Factory, NoneType)):
-            raise TypeError(f"fld '{k}' in {type(model).__qualname__} is unset")
-
-        raise TypeError(
-            f"field '{k}' in {type(model).__qualname__} is of the wrong type "
-            f"(should be {fld.annotation})"
-        )

+ 1 - 2
src/cipy/runner.py

@@ -15,7 +15,7 @@ from dotenv import dotenv_values
 
 import cipy.common
 from cipy import settings
-from cipy.common import Context, Status, _validate
+from cipy.common import Context, Status
 
 Action = TypeVar("Action", bound=cipy.common.Action)
 type Run[Action] = Callable[[Action, Context], Status]
@@ -71,7 +71,6 @@ def ipc(func: Run[Action]) -> Run[Action]:
             assert annotation is not None
             self.outputs = annotation(**outdata)
 
-            _validate(self.outputs)
         return rval
 
     return wrapper

+ 2 - 2
src/cipy/workflow.py

@@ -7,7 +7,7 @@ from typing import Any, Iterable, final, override
 
 from pydantic import BaseModel, PrivateAttr
 
-from cipy.common import Action, Context, Results, Scalar, Status, Value, _validate
+from cipy.common import Action, Context, Results, Scalar, Status, Value
 
 
 class Job(BaseModel):
@@ -55,7 +55,7 @@ class Workflow(Action):
                 if job.action.enabled(status, context):
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
-                        stat, _validate(job.action.outputs)
+                        stat, job.action.outputs.validated()
                     )
                 else:
                     stat = Status.SKIPPED