Skip to content

Commit

Permalink
Migrate to pydantic v2 (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamGagorik authored Nov 14, 2023
1 parent 84e7fc0 commit 76ba527
Show file tree
Hide file tree
Showing 10 changed files with 791 additions and 1,484 deletions.
2 changes: 1 addition & 1 deletion bany/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def main() -> int:
description=__doc__,
parents=[parent],
formatter_class=RawTextHelpFormatter,
epilog=f"environment:\n{env.json(indent=2)}",
epilog=f"environment:\n{env.model_dump_json(indent=2)}",
)

subparsers = parser.add_subparsers(title="commands")
Expand Down
48 changes: 27 additions & 21 deletions bany/cmd/extract/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
from datetime import date
from re import Match
from re import Pattern
from typing import Any

import yaml
from moneyed import Money
from moneyed import USD
from pydantic import BaseModel
from pydantic import parse_obj_as
from pydantic import validator
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from pydantic import TypeAdapter
from pydantic import ValidationInfo


CONTEXT = {
Expand All @@ -32,22 +34,17 @@

class Rule(BaseModel):
flags: int = re.I | re.X
regex: Pattern = SKIP
regex: Pattern = Field(default=SKIP, validate_default=True)
match: Match | None = None
group: str = "UNKNOWN"
transform: str = "{VALUE}"
pages: set[int] | None = None
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)

class Config:
extra = "forbid"
validate_assignment = True
arbitrary_types_allowed = True

# noinspection PyMethodParameters
@validator("regex", pre=True, always=True)
def _validate_regex(cls, value: str | Pattern, values: dict[str, Any]):
@field_validator("regex", mode="before")
def _validate_regex(cls, value: str | Pattern, values: ValidationInfo):
if isinstance(value, str):
return re.compile(value.format(**CONTEXT), values["flags"])
return re.compile(value.format(**CONTEXT), values.data["flags"])
else:
return value

Expand All @@ -61,10 +58,9 @@ class AmountRule(Rule):
total: bool = False
inflow: bool = False
transform: str = "{MONEY}"
value: Money | int | None = None
value: Money | int | None = Field(default=None, validate_default=True)

# noinspection PyMethodParameters
@validator("value", pre=True, always=True)
@field_validator("value", mode="before")
def _validate_value(cls, value: Money | int | None):
if isinstance(value, int):
return Money(amount=value, currency=USD)
Expand All @@ -81,15 +77,25 @@ class TransactionRule(BaseModel):
amount: int | str | tuple[int, int, str]
date: date | str | tuple[int, int, str]

# noinspection PyMethodParameters
@validator("date", pre=True, always=True)
@field_validator("budget", mode="before")
def _validate_budget(cls, value: str):
return str(value)

@field_validator("account", mode="before")
def _validate_account(cls, value: str):
return str(value)

@field_validator("memo", mode="before")
def _validate_memo(cls, value: str):
return str(value)

@field_validator("date", mode="before")
def _validate_date(cls, value: Money | int | None):
if isinstance(value, str):
return 0, 0, value
return value

# noinspection PyMethodParameters
@validator("amount", pre=True, always=True)
@field_validator("amount", mode="before")
def _validate_amount(cls, value: Money | int | None):
if isinstance(value, str):
return 0, 0, value
Expand All @@ -115,4 +121,4 @@ def get_rules(path: pathlib.Path) -> Rules:
def get_rules_from_yml(path: pathlib.Path) -> Rules:
with path.open("r") as stream:
data = yaml.safe_load(stream)
return parse_obj_as(Rules, data)
return TypeAdapter(Rules).validate_python(data)
35 changes: 19 additions & 16 deletions bany/cmd/split/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from moneyed import Money
from moneyed import USD
from pydantic import BaseModel
from pydantic import validator
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from pydantic import ValidationInfo

from bany.core.money import as_money

Expand All @@ -36,12 +39,12 @@ class Split(BaseModel):
#: This is a category for the transaction (Food, Cleaning, etc)
category: str = "Unknown"
#: This is a mapping from the amount owed for each payer
debtors: str | tuple[str, ...] | dict[str, int | float] = ()
debtors: str | tuple[str, ...] | dict[str, int | float] = Field(default=(), validate_default=True)
#: This is the person or persons who paid for the transaction
creditors: str | tuple[str, ...] | dict[str, int | float] = ()
creditors: str | tuple[str, ...] | dict[str, int | float] = Field(default=(), validate_default=True)

@validator("debtors", pre=True, always=True)
def _validate_debtors(cls, value: str | tuple[str, ...], values: dict) -> dict[str:int]:
@field_validator("debtors", mode="before")
def _validate_debtors(cls, value: str | tuple[str, ...], info: ValidationInfo) -> dict[str:int]:
"""
Validate creation of debtors field.
"""
Expand All @@ -53,11 +56,11 @@ def _validate_debtors(cls, value: str | tuple[str, ...], values: dict) -> dict[s
raise TypeError

total = sum(value.values())
amount = values["amount"].get_amount_in_sub_unit()
amount = info.data["amount"].get_amount_in_sub_unit()
return {k: v / total * amount for k, v in value.items()}

@validator("creditors", pre=True, always=True)
def _validate_creditors(cls, value: str | tuple[str, ...], values: dict) -> dict[str:int]:
@field_validator("creditors", mode="before")
def _validate_creditors(cls, value: str | tuple[str, ...], info: ValidationInfo) -> dict[str:int]:
"""
Validate creation of creditors field.
"""
Expand All @@ -69,20 +72,17 @@ def _validate_creditors(cls, value: str | tuple[str, ...], values: dict) -> dict
raise TypeError

total = sum(value.values())
amount = values["amount"].get_amount_in_sub_unit()
amount = info.data["amount"].get_amount_in_sub_unit()
return {k: v / total * amount for k, v in value.items()}

@validator("amount", pre=True, always=True)
@field_validator("amount", mode="before")
def _validate_amounts(cls, value: int | float | Money) -> Money:
"""
Validate creation of Money Field.
"""
return as_money(value)

class Config:
extra = "forbid"
validate_assignment = True
arbitrary_types_allowed = True
model_config = ConfigDict(extra="forbid", validate_assignment=True, arbitrary_types_allowed=True)


@dataclasses.dataclass(slots=True)
Expand Down Expand Up @@ -167,7 +167,10 @@ def _compute_weights_for_payers(self, table: pd.DataFrame) -> pd.DataFrame:
count = pd.DataFrame(iter(table.debtors.apply(Counter))).fillna(0).astype(int)
total = count.sum(axis=1)
for name in self.names:
count[f"{name}.w"] = count[name] / total
try:
count[f"{name}.w"] = count[name] / total
except KeyError:
count[f"{name}.w"] = 0.0
return pd.concat([table, count], axis=1)

@staticmethod
Expand Down Expand Up @@ -267,7 +270,7 @@ def split(self, split: Split, *objs: Tax | Tip) -> int:
Add a group of splits to the tracked splits.
"""
group = len(self.splits)
split = split.copy(update=dict(group=group))
split = split.model_copy(update=dict(group=group))
self.splits[group] = [split, *self._extract_tax_and_tip_for_split(split, *objs)]
return group

Expand Down
9 changes: 3 additions & 6 deletions bany/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from pydantic import AnyUrl
from pydantic import BaseSettings
from pydantic import SecretStr
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict


class Settings(BaseSettings):
YNAB_API_URL: AnyUrl = "https://api.youneedabudget.com/v1"
YNAB_API_KEY: SecretStr = ""

class Config:
env_file = ".env"
env_prefix = ""
env_file_encoding = "utf-8"
model_config = SettingsConfigDict(env_file=".env", env_prefix="", env_file_encoding="utf-8")
6 changes: 3 additions & 3 deletions bany/ynab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import requests
from pydantic import AnyUrl
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from requests import HTTPError
from requests import Response

Expand All @@ -32,15 +32,15 @@ class YNAB:

def _make_url(self, *components: AnyUrl | str) -> AnyUrl:
url = posixpath.join(*(str(c).lstrip("/") for c in itertools.chain((self.environ.YNAB_API_URL,), components)))
return parse_obj_as(AnyUrl, url)
return TypeAdapter(AnyUrl).validate_python(url)

def _make_headers(self, **kwargs):
return dict(Authorization=f"Bearer {self.environ.YNAB_API_KEY.get_secret_value()}") | kwargs

def _make_request(self, method: str, endpoint: str, **kwargs) -> Response:
url = self._make_url(endpoint)
headers = self._make_headers(**kwargs.pop("headers", {}))
response = requests.request(method, url, headers=headers, **kwargs)
response = requests.request(method, str(url), headers=headers, **kwargs)
try:
response.raise_for_status()
return response
Expand Down
2 changes: 1 addition & 1 deletion bany/ynab/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _make_request(self, method: str, endpoint: str, **kwargs) -> Response:
with responses.RequestsMock() as mocker:
mocker.add(
method=method,
url=self._make_url(endpoint),
url=str(self._make_url(endpoint)),
headers=self._make_headers(**kwargs.pop("headers", {})),
**self.mockdata.get(endpoint, {}),
)
Expand Down
22 changes: 9 additions & 13 deletions bany/ynab/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Literal

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import validator
from pydantic import field_validator
from pydantic import ValidationInfo


NS = uuid.UUID("b9b024c9-e918-4447-9b75-2b340535d49e")
Expand Down Expand Up @@ -32,24 +34,18 @@ class Transaction(BaseModel):
approved: bool = False

import_index: int = Field(default=0, exclude=True, repr=False)
import_id: str | None = Field(repr=False)
import_id: str | None = Field(None, repr=False, validate_default=True)
model_config = ConfigDict(extra="forbid", frozen=True)

class Config:
extra = "forbid"
allow_mutation = False

# noinspection PyMethodParameters
@validator("import_id", pre=True, always=True)
def _set_import_id(cls, v, values):
@field_validator("import_id", mode="before")
def _set_import_id(cls, v, values: ValidationInfo):
v = v if v is not None else "{account_id}:{date}:{amount}:{payee_name}:{import_index}"
return str(uuid.uuid5(NS, v.format(**values)))
return str(uuid.uuid5(NS, v.format(**values.data)))

def __hash__(self):
return hash(self.import_id)


class Transactions(BaseModel):
transactions: list[Transaction]

class Config:
allow_mutation = False
model_config = ConfigDict(frozen=True)
Loading

0 comments on commit 76ba527

Please sign in to comment.