Browse Source

refactor: make Ref into a dataclass

Sam Jaffe 1 month ago
parent
commit
277c3fad39
1 changed files with 12 additions and 18 deletions
  1. 12 18
      src/cipy/common.py

+ 12 - 18
src/cipy/common.py

@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from enum import Enum, auto
 from functools import reduce
 from functools import reduce
 from types import SimpleNamespace, NoneType
 from types import SimpleNamespace, NoneType
-from typing import Any, Callable, Iterator, Literal, overload
+from typing import Annotated, Any, Callable, Iterator, Literal, overload
 
 
 from pydantic import BaseModel, Field
 from pydantic import BaseModel, Field
 from pydantic_core import core_schema
 from pydantic_core import core_schema
@@ -36,20 +36,16 @@ class Outputs(BaseModel):
     """Stub class describing result data"""
     """Stub class describing result data"""
 
 
 
 
-class Ref(str):
+@dataclasses.dataclass
+class Ref:
     """Annotation class describing a reference into Context or another place"""
     """Annotation class describing a reference into Context or another place"""
+    path: list[Annotated[str, Field(pattern="\\w*(_\\w*)*")]]
 
 
-    @classmethod
-    # pylint: disable=unused-argument
-    def __get_pydantic_core_schema__(cls, source, handler) -> core_schema.CoreSchema:
-
-        def validate(s, _):
-            if s and all(t for t in s.split(".")):
-                return s
+    def __init__(self, pathstr: str) -> None:
+        self.path = pathstr.split(".")
+        if not self.path:
             raise ValueError("References must be of the form A.B.C etc.")
             raise ValueError("References must be of the form A.B.C etc.")
 
 
-        return core_schema.general_plain_validator_function(validate)
-
 
 
 @dataclasses.dataclass
 @dataclasses.dataclass
 class Factory:
 class Factory:
@@ -84,16 +80,14 @@ class Results(SimpleNamespace):
 class Context(SimpleNamespace):
 class Context(SimpleNamespace):
     """Wrapper class for the context of the CI runtime"""
     """Wrapper class for the context of the CI runtime"""
 
 
-    def access(self, ctx: str) -> Any:
+    def access(self, ref: Ref) -> Any:
         """Accessor for context state with a dot-separated path"""
         """Accessor for context state with a dot-separated path"""
-        path = ctx.split(".")
-
-        if path[0] == "env":
-            assert len(path) == 2
-            return os.environ.get(path[1])
+        if ref.path[0] == "env":
+            assert len(ref.path) == 2
+            return os.environ.get(ref.path[1])
 
 
         return reduce(
         return reduce(
-            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), path, self
+            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), ref.path, self
         )
         )
 
 
     @overload
     @overload