|
|
@@ -1,13 +1,15 @@
|
|
|
"""Classes for managing the context of a CI run"""
|
|
|
+
|
|
|
import os
|
|
|
|
|
|
from contextlib import contextmanager
|
|
|
from dataclasses import dataclass, field
|
|
|
-from functools import reduce
|
|
|
-from types import SimpleNamespace
|
|
|
-from typing import Any, Callable, Iterator, Literal, overload
|
|
|
+from functools import partial, reduce
|
|
|
+from types import NoneType, SimpleNamespace, UnionType
|
|
|
+from typing import Any, Callable, Iterator, Literal, Protocol, get_args, overload
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
+from pydantic.fields import FieldInfo
|
|
|
from pydantic_core import PydanticUndefined
|
|
|
|
|
|
from cipy.common import Inputs, Outputs, Ref
|
|
|
@@ -48,6 +50,29 @@ class Results(SimpleNamespace):
|
|
|
return self.__getattribute__(subscript)
|
|
|
|
|
|
|
|
|
+class _Stub:
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+def _chain_attrs(context: Any, ref: Ref) -> Any:
|
|
|
+ try:
|
|
|
+ for token in ref.path:
|
|
|
+ if (
|
|
|
+ isinstance(context, Results.Item)
|
|
|
+ and context.conclusion is not Status.SUCCESS
|
|
|
+ ):
|
|
|
+ return _Stub()
|
|
|
+
|
|
|
+ if isinstance(context, dict):
|
|
|
+ context = context[token]
|
|
|
+ else:
|
|
|
+ context = getattr(context, token)
|
|
|
+ return context
|
|
|
+ except KeyError, AttributeError:
|
|
|
+ reason = "NULL object" if context is None else "not found"
|
|
|
+ raise AttributeError(f'unable to find {ref} item "{token}": {reason}')
|
|
|
+
|
|
|
+
|
|
|
class Context(SimpleNamespace):
|
|
|
"""Wrapper class for the context of the CI runtime"""
|
|
|
|
|
|
@@ -66,9 +91,7 @@ class Context(SimpleNamespace):
|
|
|
assert len(arg.path) == 2
|
|
|
return os.environ.get(arg.path[1])
|
|
|
|
|
|
- return reduce( # type: ignore[return-value]
|
|
|
- lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), arg.path, self
|
|
|
- )
|
|
|
+ return _chain_attrs(self, arg)
|
|
|
|
|
|
@overload
|
|
|
def fabricate(
|
|
|
@@ -105,15 +128,36 @@ class Context(SimpleNamespace):
|
|
|
fields = vars(model)
|
|
|
|
|
|
for name, fld in annotation.__pydantic_fields__.items():
|
|
|
+ coerce = partial(self.__coerce, state, name, fld)
|
|
|
if name in extra:
|
|
|
- fields[name] = self(extra[name])
|
|
|
+ fields[name] = coerce(extra[name])
|
|
|
elif fld.default is not PydanticUndefined:
|
|
|
- fields[name] = self(fld.default)
|
|
|
+ fields[name] = coerce(fld.default)
|
|
|
|
|
|
model = annotation(**fields)
|
|
|
setattr(state, attr, model)
|
|
|
return model
|
|
|
|
|
|
+ def __coerce(
|
|
|
+ self, state: BaseModel, name: str, fld: FieldInfo, arg: Value | None
|
|
|
+ ) -> Any:
|
|
|
+ value: Scalar | None = self(arg)
|
|
|
+ if value is None or not isinstance(value, _Stub):
|
|
|
+ return value
|
|
|
+
|
|
|
+ anno = fld.annotation
|
|
|
+
|
|
|
+ assert anno is not None
|
|
|
+ if isinstance(anno, UnionType):
|
|
|
+ anno = next(iter(t for t in get_args(anno) if t is not NoneType))
|
|
|
+
|
|
|
+ assert hasattr(state, "logger")
|
|
|
+ state.logger.warning(
|
|
|
+ 'binding %s to "%s" failed: action was not successful', arg, name
|
|
|
+ )
|
|
|
+ state.logger.debug("coercing to %s", anno.__name__)
|
|
|
+ return anno()
|
|
|
+
|
|
|
@contextmanager
|
|
|
def extend(self, **kwargs: Any) -> Iterator[Context]:
|
|
|
"""Create a new context that inherits an extra property"""
|