Explorar el Código

refactor: better fabrication of Inputs/Outputs from Context object(s)

Sam Jaffe hace 1 mes
padre
commit
77535fe6e3
Se han modificado 1 ficheros con 12 adiciones y 19 borrados
  1. 12 19
      src/cipy/common.py

+ 12 - 19
src/cipy/common.py

@@ -12,6 +12,7 @@ from types import SimpleNamespace, NoneType
 from typing import Annotated, Any, Callable, Iterator, Literal, overload
 from typing import Annotated, Any, Callable, Iterator, Literal, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
+from pydantic_core import PydanticUndefined
 
 
 type Scalar = bool | int | float | str
 type Scalar = bool | int | float | str
 type Computed = Ref | Factory
 type Computed = Ref | Factory
@@ -129,28 +130,20 @@ class Context(SimpleNamespace):
         if extra is None:
         if extra is None:
             extra = {}
             extra = {}
 
 
-        model = getattr(state, attr)
-        if model is None:
-            annotation = state.__pydantic_fields__[attr].annotation
-            assert annotation is not None
-            model = annotation()
+        fields: dict[str, Any] = {}
+        if (model := getattr(state, attr)) is not None:
+            fields = vars(model)
 
 
-        value: Value | None = None
-        for k, value in extra.items():
-            setattr(model, k, self(value))
+        annotation = state.__pydantic_fields__[attr].annotation
+        assert annotation is not None
 
 
-        for k, fld in model.__pydantic_fields__.items():
-            value = getattr(model, k)
-            if fld.annotation is None:
-                continue
-
-            if k in extra:
-                pass
-
-            value = self(value)
-            if value is not None:
-                setattr(model, k, fld.annotation(value))
+        for name, fld in annotation.__pydantic_fields__.items():
+            if name in extra:
+                fields[name] = self(extra[name])
+            elif fld.default is not PydanticUndefined:
+                fields[name] = self(fld.default)
 
 
+        model = annotation(**fields)
         _validate(model)
         _validate(model)
         setattr(state, attr, model)
         setattr(state, attr, model)
         return model
         return model