瀏覽代碼

refactor: make Ref into a dataclass

Sam Jaffe 1 月之前
父節點
當前提交
277c3fad39
共有 1 個文件被更改,包括 12 次插入18 次删除
  1. 12 18
      src/cipy/common.py

+ 12 - 18
src/cipy/common.py

@@ -9,7 +9,7 @@ from contextlib import contextmanager
 from enum import Enum, auto
 from functools import reduce
 from types import SimpleNamespace, NoneType
-from typing import Any, Callable, Iterator, Literal, overload
+from typing import Annotated, Any, Callable, Iterator, Literal, overload
 
 from pydantic import BaseModel, Field
 from pydantic_core import core_schema
@@ -36,20 +36,16 @@ class Outputs(BaseModel):
     """Stub class describing result data"""
 
 
-class Ref(str):
+@dataclasses.dataclass
+class Ref:
     """Annotation class describing a reference into Context or another place"""
+    path: list[Annotated[str, Field(pattern="\\w*(_\\w*)*")]]
 
-    @classmethod
-    # pylint: disable=unused-argument
-    def __get_pydantic_core_schema__(cls, source, handler) -> core_schema.CoreSchema:
-
-        def validate(s, _):
-            if s and all(t for t in s.split(".")):
-                return s
+    def __init__(self, pathstr: str) -> None:
+        self.path = pathstr.split(".")
+        if not self.path:
             raise ValueError("References must be of the form A.B.C etc.")
 
-        return core_schema.general_plain_validator_function(validate)
-
 
 @dataclasses.dataclass
 class Factory:
@@ -84,16 +80,14 @@ class Results(SimpleNamespace):
 class Context(SimpleNamespace):
     """Wrapper class for the context of the CI runtime"""
 
-    def access(self, ctx: str) -> Any:
+    def access(self, ref: Ref) -> Any:
         """Accessor for context state with a dot-separated path"""
-        path = ctx.split(".")
-
-        if path[0] == "env":
-            assert len(path) == 2
-            return os.environ.get(path[1])
+        if ref.path[0] == "env":
+            assert len(ref.path) == 2
+            return os.environ.get(ref.path[1])
 
         return reduce(
-            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), path, self
+            lambda o, a: o[a] if isinstance(o, dict) else getattr(o, a), ref.path, self
         )
 
     @overload