|
@@ -8,7 +8,7 @@ import os
|
|
|
from contextlib import contextmanager
|
|
from contextlib import contextmanager
|
|
|
from enum import Enum, auto
|
|
from enum import Enum, auto
|
|
|
from functools import reduce
|
|
from functools import reduce
|
|
|
-from types import SimpleNamespace
|
|
|
|
|
|
|
+from types import SimpleNamespace, NoneType
|
|
|
from typing import Any, Callable, Iterator, Literal, overload
|
|
from typing import Any, Callable, Iterator, Literal, overload
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
from pydantic import BaseModel, Field
|
|
@@ -47,6 +47,7 @@ class Ref(str):
|
|
|
if s and all(t for t in s.split(".")):
|
|
if s and all(t for t in s.split(".")):
|
|
|
return s
|
|
return s
|
|
|
raise ValueError("References must be of the form A.B.C etc.")
|
|
raise ValueError("References must be of the form A.B.C etc.")
|
|
|
|
|
+
|
|
|
return core_schema.general_plain_validator_function(validate)
|
|
return core_schema.general_plain_validator_function(validate)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -91,16 +92,31 @@ class Context(SimpleNamespace):
|
|
|
assert len(path) == 2
|
|
assert len(path) == 2
|
|
|
return os.environ.get(path[1])
|
|
return os.environ.get(path[1])
|
|
|
|
|
|
|
|
- return reduce(getattr, path, self)
|
|
|
|
|
|
|
+ return reduce(
|
|
|
|
|
+ lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), path, self
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
@overload
|
|
@overload
|
|
|
- def fabricate(self, state: BaseModel, attr: Literal["inputs"]) -> Inputs: ...
|
|
|
|
|
|
|
+ def fabricate(
|
|
|
|
|
+ self,
|
|
|
|
|
+ state: BaseModel,
|
|
|
|
|
+ attr: Literal["inputs"],
|
|
|
|
|
+ extra: dict[str, Any] = {},
|
|
|
|
|
+ ) -> Inputs: ...
|
|
|
|
|
|
|
|
@overload
|
|
@overload
|
|
|
- def fabricate(self, state: BaseModel, attr: Literal["outputs"]) -> Outputs: ...
|
|
|
|
|
|
|
+ def fabricate(
|
|
|
|
|
+ self,
|
|
|
|
|
+ state: BaseModel,
|
|
|
|
|
+ attr: Literal["outputs"],
|
|
|
|
|
+ extra: dict[str, Any] = {},
|
|
|
|
|
+ ) -> Outputs: ...
|
|
|
|
|
|
|
|
def fabricate(
|
|
def fabricate(
|
|
|
- self, state: BaseModel, attr: Literal["inputs"] | Literal["outputs"]
|
|
|
|
|
|
|
+ self,
|
|
|
|
|
+ state: BaseModel,
|
|
|
|
|
+ attr: Literal["inputs"] | Literal["outputs"],
|
|
|
|
|
+ extra: dict[str, Ref | Factory] = {},
|
|
|
) -> Inputs | Outputs:
|
|
) -> Inputs | Outputs:
|
|
|
"""Fabricate and validate an Inputs or Outputs object"""
|
|
"""Fabricate and validate an Inputs or Outputs object"""
|
|
|
model = getattr(state, attr)
|
|
model = getattr(state, attr)
|
|
@@ -109,15 +125,27 @@ class Context(SimpleNamespace):
|
|
|
assert annotation is not None
|
|
assert annotation is not None
|
|
|
model = annotation()
|
|
model = annotation()
|
|
|
|
|
|
|
|
- for k, field in model.__pydantic_fields__.items():
|
|
|
|
|
- value = getattr(model, k)
|
|
|
|
|
- if field.annotation is None:
|
|
|
|
|
- continue
|
|
|
|
|
-
|
|
|
|
|
|
|
+ for k, value in extra.items():
|
|
|
if isinstance(value, Ref):
|
|
if isinstance(value, Ref):
|
|
|
setattr(model, k, self.access(value))
|
|
setattr(model, k, self.access(value))
|
|
|
elif isinstance(value, Factory):
|
|
elif isinstance(value, Factory):
|
|
|
setattr(model, k, value(self))
|
|
setattr(model, k, value(self))
|
|
|
|
|
+ else:
|
|
|
|
|
+ setattr(model, k, value)
|
|
|
|
|
+
|
|
|
|
|
+ for k, fld in model.__pydantic_fields__.items():
|
|
|
|
|
+ value = getattr(model, k)
|
|
|
|
|
+ if fld.annotation is None:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ if k in extra:
|
|
|
|
|
+ pass
|
|
|
|
|
+ elif isinstance(value, Ref):
|
|
|
|
|
+ value = self.access(value)
|
|
|
|
|
+ elif isinstance(value, Factory):
|
|
|
|
|
+ value = value(self)
|
|
|
|
|
+ if value is not None:
|
|
|
|
|
+ setattr(model, k, fld.annotation(value))
|
|
|
|
|
|
|
|
_validate(model)
|
|
_validate(model)
|
|
|
setattr(state, attr, model)
|
|
setattr(state, attr, model)
|
|
@@ -167,14 +195,14 @@ class Action(BaseModel, abc.ABC):
|
|
|
|
|
|
|
|
def _validate(model: BaseModel):
|
|
def _validate(model: BaseModel):
|
|
|
"""Perform the actual model validation that we sabotaged w/ required() and similar functions"""
|
|
"""Perform the actual model validation that we sabotaged w/ required() and similar functions"""
|
|
|
- for k, field in model.__pydantic_fields__.items():
|
|
|
|
|
|
|
+ for k, fld in model.__pydantic_fields__.items():
|
|
|
attr = getattr(model, k)
|
|
attr = getattr(model, k)
|
|
|
- if field.annotation is None:
|
|
|
|
|
|
|
+ if fld.annotation is None:
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- if isinstance(attr, (Ref, Factory)):
|
|
|
|
|
- raise TypeError(f"field '{k}' in {type(model)} is unset")
|
|
|
|
|
- if not isinstance(attr, field.annotation):
|
|
|
|
|
|
|
+ if isinstance(attr, (Ref, Factory, NoneType)):
|
|
|
|
|
+ raise TypeError(f"fld '{k}' in {type(model).__qualname__} is unset")
|
|
|
|
|
+ if not isinstance(attr, fld.annotation):
|
|
|
raise TypeError(
|
|
raise TypeError(
|
|
|
- f"field '{k}' in {type(model)} is of the wrong type (should be {field.annotation})"
|
|
|
|
|
|
|
+ f"field '{k}' in {type(model).__qualname__} is of the wrong type (should be {fld.annotation})"
|
|
|
)
|
|
)
|