Skip to content

Commit

Permalink
refactor: use pydantic to validate settings
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Cherng <[email protected]>
  • Loading branch information
jfcherng committed Nov 3, 2024
1 parent ad41ac1 commit d9bf523
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 57 deletions.
19 changes: 8 additions & 11 deletions plugin/rules/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..cache import clearable_lru_cache
from ..constants import PLUGIN_NAME, ST_PLATFORM
from ..logger import Logger
from ..snapshot import ViewSnapshot
from ..types import Optimizable, StConstraintRule
from ..utils import (
Expand Down Expand Up @@ -46,7 +47,7 @@ class ConstraintRule(Optimizable):
constraint_name: str = ""
args: tuple[Any, ...] = tuple()
kwargs: dict[str, Any] = field(default_factory=dict)
inverted: bool = False # whether the test result should be inverted
inverted: bool = False

def is_droppable(self) -> bool:
return not (self.constraint and not self.constraint.is_droppable())
Expand Down Expand Up @@ -75,20 +76,16 @@ def make(cls, constraint_rule: StConstraintRule) -> Self:
"""Build this object with the `constraint_rule`."""
obj = cls()

if args := constraint_rule.get("args"):
# make sure args is always a tuple
obj.args = tuple(args) if isinstance(args, list) else (args,)
obj.args = tuple(constraint_rule.args)
obj.kwargs = constraint_rule.kwargs
obj.inverted = constraint_rule.inverted

if kwargs := constraint_rule.get("kwargs"):
obj.kwargs = kwargs

if (inverted := constraint_rule.get("inverted")) is not None:
obj.inverted = bool(inverted)

if constraint := constraint_rule.get("constraint"):
if constraint := constraint_rule.constraint:
obj.constraint_name = constraint
if constraint_class := find_constraint(constraint):
obj.constraint = constraint_class(*obj.args, **obj.kwargs)
else:
Logger.log(f"Unsupported constraint rule: {constraint}")

return obj

Expand Down
23 changes: 10 additions & 13 deletions plugin/rules/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from typing_extensions import Self

from ..cache import clearable_lru_cache
from ..logger import Logger
from ..snapshot import ViewSnapshot
from ..types import Optimizable, StMatchRule
from ..types import Optimizable, StConstraintRule, StMatchRule
from ..utils import camel_to_snake, list_all_subclasses, remove_suffix
from .constraint import ConstraintRule

Expand All @@ -30,8 +31,6 @@ def list_matches() -> Generator[type[AbstractMatch], None, None]:

@dataclass
class MatchRule(Optimizable):
DEFAULT_MATCH_NAME = "any"

match: AbstractMatch | None = None
match_name: str = ""
args: tuple[Any, ...] = tuple()
Expand Down Expand Up @@ -63,24 +62,22 @@ def make(cls, match_rule: StMatchRule) -> Self:
"""Build this object with the `match_rule`."""
obj = cls()

if args := match_rule.get("args"):
# make sure args is always a tuple
obj.args = tuple(args) if isinstance(args, list) else (args,)

if kwargs := match_rule.get("kwargs"):
obj.kwargs = kwargs
obj.args = tuple(match_rule.args)
obj.kwargs = match_rule.kwargs

match = match_rule.get("match", cls.DEFAULT_MATCH_NAME)
match = match_rule.match
if match_class := find_match(match):
obj.match_name = match
obj.match = match_class(*obj.args, **obj.kwargs)
else:
Logger.log(f"Unsupported match rule: {match}")

rules_compiled: list[MatchableRule] = []
for rule in match_rule.get("rules", []):
for rule in match_rule.rules:
rule_class: type[MatchableRule] | None = None
if "constraint" in rule:
if isinstance(rule, StConstraintRule):
rule_class = ConstraintRule
elif "rules" in rule: # nested MatchRule
elif isinstance(rule, StMatchRule):
rule_class = MatchRule
if rule_class and (rule_compiled := rule_class.make(rule)): # type: ignore
rules_compiled.append(rule_compiled)
Expand Down
17 changes: 4 additions & 13 deletions plugin/rules/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,14 @@ def make(cls, syntax_rule: StSyntaxRule) -> Self:
"""Build this object with the `syntax_rule`."""
obj = cls()

if comment := syntax_rule.get("comment"):
obj.comment = str(comment)

syntaxes = syntax_rule.get("syntaxes", [])
if isinstance(syntaxes, str):
syntaxes = [syntaxes]
obj.syntaxes_name = tuple(syntaxes)
if target_syntax := find_syntax_by_syntax_likes(syntaxes):
obj.syntaxes_name = tuple(syntax_rule.syntaxes)
if target_syntax := find_syntax_by_syntax_likes(syntax_rule.syntaxes):
obj.syntax = target_syntax

# note that an empty string selector should match any scope
if (selector := syntax_rule.get("selector")) is not None:
obj.selector = selector
obj.selector = syntax_rule.selector

if (on_events := syntax_rule.get("on_events")) is not None:
if isinstance(on_events, str):
on_events = [on_events]
if (on_events := syntax_rule.on_events) is not None:
obj.on_events = set(drop_falsy(map(ListenerEvent.from_value, on_events)))

if match_rule_compiled := MatchRule.make(syntax_rule):
Expand Down
5 changes: 3 additions & 2 deletions plugin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from collections import ChainMap
from itertools import chain
from typing import Any, Callable, Mapping, MutableMapping
from typing import Any, Callable, List, Mapping, MutableMapping

import sublime
import sublime_plugin
from more_itertools import unique_everseen
from pydantic import TypeAdapter

from .types import StSyntaxRule
from .utils import drop_falsy
Expand Down Expand Up @@ -34,7 +35,7 @@ def get_st_settings() -> sublime.Settings:


def pref_syntax_rules(*, window: sublime.Window | None = None) -> list[StSyntaxRule]:
return get_merged_plugin_setting("syntax_rules", [], window=window)
return TypeAdapter(List[StSyntaxRule]).validate_python(get_merged_plugin_setting("syntax_rules", [], window=window))


def pref_trim_suffixes(*, window: sublime.Window | None = None) -> tuple[str]:
Expand Down
54 changes: 36 additions & 18 deletions plugin/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from collections import UserDict as BuiltinUserDict
from collections.abc import Generator, Hashable, Iterator, KeysView
from enum import Enum
from typing import Any, Generic, TypedDict, TypeVar, Union, overload
from typing import Any, Generic, TypeVar, Union, overload

import sublime
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self

SyntaxLike = Union[str, sublime.Syntax]
Expand Down Expand Up @@ -95,31 +96,48 @@ def optimize(self) -> Generator[Any, None, None]:
"""Does optimizations and returns a generator for dropped objects."""


class StConstraintRule(TypedDict):
"""Typed dict for corresponding ST settings."""
class StConstraintRule(BaseModel):
"""Model for a "constraint rule" in settings."""

constraint: str
args: list[Any] | Any | None
kwargs: dict[str, Any] | None
inverted: bool
"""The name of the "constraint"."""
args: list[Any] = Field(default_factory=list)
"""Positional arguments for the "constraint"."""
kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments for the "constraint"."""
inverted: bool = False
"""Whether the test result should be inverted."""


class StMatchRule(TypedDict):
"""Typed dict for corresponding ST settings."""
class StMatchRule(BaseModel):
"""Model for a "match rule" in settings."""

match: str
args: list[Any] | Any | None
kwargs: dict[str, Any] | None
rules: list[StMatchRule | StConstraintRule]
match: str = "any"
"""The name of the "match"."""
args: list[Any] = Field(default_factory=list)
"""Positional arguments for the "match"."""
kwargs: dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments for the "match"."""
rules: list[StConstraintRule | StMatchRule] = Field(default_factory=list)
"""Rules to match against."""


class StSyntaxRule(StMatchRule):
"""Typed dict for corresponding ST settings."""

comment: str
selector: str
syntaxes: str | list[str]
on_events: str | list[str] | None
"""Model for a "syntax rule" in settings."""

comment: str = ""
"""A comment for the rule."""
selector: str = "text.plain"
"""To constrain the syntax scope of the current view. An empty string matches any scope."""
syntaxes: list[str] = Field(default_factory=list)
"""Syntaxes to be used. The first available one will be used."""
on_events: list[str] | None = None
"""Events to listen to, or `None` for all events."""

@field_validator("syntaxes", "on_events", mode="before")
@classmethod
def str_to_list_str(cls, v: Any) -> list[str]:
return [v] if isinstance(v, str) else v


class WindowKeyedDict(UserDict[WindowIdAble, _T]):
Expand Down

0 comments on commit d9bf523

Please sign in to comment.