Przeglądaj źródła

refactor: replace _validate() free function with Outputs.validated()

Sam Jaffe 1 miesiąc temu
rodzic
commit
c79861c08f
4 zmienionych plików z 11 dodań i 24 usunięć
  1. 2 2
      src/cipy/action.py
  2. 6 18
      src/cipy/common.py
  3. 1 2
      src/cipy/runner.py
  4. 2 2
      src/cipy/workflow.py

+ 2 - 2
src/cipy/action.py

@@ -12,7 +12,7 @@ from pydantic import Field, PrivateAttr
 
 
 import cipy.runner
 import cipy.runner
 from cipy import settings
 from cipy import settings
-from cipy.common import Action, Context, Factory, Ref, Results, Status, _validate
+from cipy.common import Action, Context, Factory, Ref, Results, Status
 from cipy.shell import Shell
 from cipy.shell import Shell
 
 
 from . import _io
 from . import _io
@@ -161,7 +161,7 @@ class Composite(Action):
 
 
                 if step.enabled(status, context):
                 if step.enabled(status, context):
                     stat = step.run(context)
                     stat = step.run(context)
-                    outctx.steps[step.id] = Results.Item(stat, _validate(step.outputs))
+                    outctx.steps[step.id] = Results.Item(stat, step.outputs.validated())
                 else:
                 else:
                     stat = Status.SKIPPED
                     stat = Status.SKIPPED
                     outctx.steps[step.id] = Results.Item(Status.SKIPPED)
                     outctx.steps[step.id] = Results.Item(Status.SKIPPED)

+ 6 - 18
src/cipy/common.py

@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from enum import Enum, auto
 from functools import reduce
 from functools import reduce
 from types import SimpleNamespace, NoneType
 from types import SimpleNamespace, NoneType
-from typing import Annotated, Any, Callable, Iterator, Literal, overload
+from typing import Annotated, Any, Callable, Iterator, Literal, Self, final, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from pydantic_core import PydanticUndefined
 from pydantic_core import PydanticUndefined
@@ -39,6 +39,11 @@ class Inputs(BaseModel):
 class Outputs(BaseModel):
 class Outputs(BaseModel):
     """Stub class describing result data"""
     """Stub class describing result data"""
 
 
+    @final
+    def validated(self) -> Self:
+        """Validate this output object, affirm that it is properly constructed"""
+        return self.model_validate(self, extra="forbid")
+
 
 
 @dataclasses.dataclass
 @dataclasses.dataclass
 class Ref:
 class Ref:
@@ -144,7 +149,6 @@ class Context(SimpleNamespace):
                 fields[name] = self(fld.default)
                 fields[name] = self(fld.default)
 
 
         model = annotation(**fields)
         model = annotation(**fields)
-        _validate(model)
         setattr(state, attr, model)
         setattr(state, attr, model)
         return model
         return model
 
 
@@ -181,19 +185,3 @@ class Action(BaseModel, abc.ABC):
 
 
     def cleanup(self, context: Context) -> None:
     def cleanup(self, context: Context) -> None:
         """Optionally clean up after ourselves"""
         """Optionally clean up after ourselves"""
-
-
-def _validate(model: BaseModel):
-    """Perform the actual model validation that we sabotaged w/ required() and similar functions"""
-    for k, fld in model.__pydantic_fields__.items():
-        attr = getattr(model, k)
-        if fld.annotation is None or isinstance(attr, fld.annotation):
-            continue
-
-        if isinstance(attr, (Ref, Factory, NoneType)):
-            raise TypeError(f"fld '{k}' in {type(model).__qualname__} is unset")
-
-        raise TypeError(
-            f"field '{k}' in {type(model).__qualname__} is of the wrong type "
-            f"(should be {fld.annotation})"
-        )

+ 1 - 2
src/cipy/runner.py

@@ -15,7 +15,7 @@ from dotenv import dotenv_values
 
 
 import cipy.common
 import cipy.common
 from cipy import settings
 from cipy import settings
-from cipy.common import Context, Status, _validate
+from cipy.common import Context, Status
 
 
 Action = TypeVar("Action", bound=cipy.common.Action)
 Action = TypeVar("Action", bound=cipy.common.Action)
 type Run[Action] = Callable[[Action, Context], Status]
 type Run[Action] = Callable[[Action, Context], Status]
@@ -71,7 +71,6 @@ def ipc(func: Run[Action]) -> Run[Action]:
             assert annotation is not None
             assert annotation is not None
             self.outputs = annotation(**outdata)
             self.outputs = annotation(**outdata)
 
 
-            _validate(self.outputs)
         return rval
         return rval
 
 
     return wrapper
     return wrapper

+ 2 - 2
src/cipy/workflow.py

@@ -7,7 +7,7 @@ from typing import Any, Iterable, final, override
 
 
 from pydantic import BaseModel, PrivateAttr
 from pydantic import BaseModel, PrivateAttr
 
 
-from cipy.common import Action, Context, Results, Scalar, Status, Value, _validate
+from cipy.common import Action, Context, Results, Scalar, Status, Value
 
 
 
 
 class Job(BaseModel):
 class Job(BaseModel):
@@ -55,7 +55,7 @@ class Workflow(Action):
                 if job.action.enabled(status, context):
                 if job.action.enabled(status, context):
                     stat = job.action.run(context)
                     stat = job.action.run(context)
                     outctx.needs[job.id] = Results.Item(
                     outctx.needs[job.id] = Results.Item(
-                        stat, _validate(job.action.outputs)
+                        stat, job.action.outputs.validated()
                     )
                     )
                 else:
                 else:
                     stat = Status.SKIPPED
                     stat = Status.SKIPPED