diff --git a/tokenlists/typing.py b/tokenlists/typing.py index 1b04c57..b85991d 100644 --- a/tokenlists/typing.py +++ b/tokenlists/typing.py @@ -55,25 +55,39 @@ class TokenInfo(BaseModel): tags: Optional[List[TagId]] = None extensions: Optional[Dict[str, Any]] = None - @validator("extensions") - def parse_extensions(cls, v: dict): - if "bridgeInfo" in v: + @validator("extensions", pre=True) + def parse_extensions(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + # 1. Check extension depth first + def extension_depth(obj: Optional[Dict[str, Any]]) -> int: + if not isinstance(obj, dict) or len(obj) == 0: + return 0 + + return 1 + max(extension_depth(v) for v in obj.values()) + + if (depth := extension_depth(v)) > 3: + raise ValueError(f"Extension depth is greater than 3: {depth}") + + # 2. Parse valid extensions + if v and "bridgeInfo" in v: v["bridgeInfo"] = BridgeInfo.parse_obj(v.pop("bridgeInfo")) return v @validator("extensions") - def check_extension_depth(cls, v: dict): - def extension_depth(obj: dict) -> int: - depth = 0 - for v in obj.values(): - if isinstance(v, dict): - depth = max(depth, extension_depth(v)) + def extensions_must_contain_allowed_types( + cls, d: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + if not d: + return d - return depth + 1 + # NOTE: `extensions` is mapping from `str` to either: + # - a parsed `dict` type (e.g. `BaseModel`) + # - a "simple" type (e.g. string, integer or boolean value) + for key, val in d.items(): + if not isinstance(val, (BaseModel, str, int, bool)) and val is not None: + raise ValueError(f"Incorrect extension field value: {val}") - assert extension_depth(v) < 3 - return v + return d @property def bridge_info(self) -> Optional[BridgeInfo]: @@ -101,21 +115,6 @@ def decimals_must_be_uint8(cls, v: TokenDecimals): return v - @validator("extensions") - def extensions_must_contain_simple_types(cls, d: Optional[dict]) -> Optional[dict]: - if not d: - return d - - # `extensions` is `Dict[str, Union[str, int, bool, None]]`, but pydantic mutates entries - for key, val in d.items(): - if key in "bridgeInfo": - continue # don't parse valid extensions - - if not isinstance(val, (str, int, bool)) and val is not None: - raise ValueError(f"Incorrect extension field value: {val}") - - return d - class Tag(BaseModel): name: str