Skip to content

Commit

Permalink
refactor: have Client._get use from_kwargs to hydrate response classes
Browse files Browse the repository at this point in the history
this helper method allows for the presence of unexpected fields in the
response JSON.

this establishes the convention that response classes (at least the ones
that are used for GET responses) must define this class method.
  • Loading branch information
angela-tran committed Sep 16, 2024
1 parent b3321de commit 397e0e4
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 9 deletions.
33 changes: 33 additions & 0 deletions littlepay/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from inspect import signature
from typing import Generator, Protocol, TypeVar

from authlib.integrations.requests_client import OAuth2Session
Expand All @@ -8,13 +9,45 @@
TResponse = TypeVar("TResponse")


def from_kwargs(cls, **kwargs):
"""
Helper function meant to be used as a @classmethod
for instantiating a dataclass and allowing unexpected fields
See https://stackoverflow.com/a/55101438
"""
# fetch the constructor's signature
class_fields = {field for field in signature(cls).parameters}

# split the kwargs into native ones and new ones
native_args, new_args = {}, {}
for name, val in kwargs.items():
if name in class_fields:
native_args[name] = val
else:
new_args[name] = val

# use the native ones to create the class ...
instance = cls(**native_args)

# ... and add the new ones by hand
for new_name, new_val in new_args.items():
setattr(instance, new_name, new_val)

return instance


@dataclass
class ListResponse:
"""An API response with list and total_count attributes."""

list: list
total_count: int

@classmethod
def from_kwargs(cls, **kwargs):
from_kwargs(cls, **kwargs)


class ClientProtocol(Protocol):
"""Protocol describing key functionality for an API connection."""
Expand Down
3 changes: 2 additions & 1 deletion littlepay/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def _delete(self, endpoint: str) -> bool:
def _get(self, endpoint: str, response_cls: TResponse, **kwargs) -> TResponse:
response = self.oauth.get(endpoint, headers=self.headers, params=kwargs)
response.raise_for_status()
return response_cls(**response.json())

return response_cls.from_kwargs(**response.json())

def _get_list(self, endpoint: str, **kwargs) -> Generator[dict, None, None]:
params = dict(page=1, per_page=100)
Expand Down
10 changes: 10 additions & 0 deletions littlepay/api/funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from littlepay.api import ClientProtocol

from . import from_kwargs


@dataclass
class FundingSourceResponse:
Expand All @@ -25,6 +27,10 @@ class FundingSourceResponse:
token_key_id: Optional[str] = None
icc_hash: Optional[str] = None

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@dataclass
class FundingSourceDateFields:
Expand Down Expand Up @@ -65,6 +71,10 @@ class FundingSourceGroupResponse(FundingSourceDateFields):
group_id: str
label: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class FundingSourcesMixin(ClientProtocol):
"""Mixin implements APIs for funding sources."""
Expand Down
10 changes: 9 additions & 1 deletion littlepay/api/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, timezone
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.funding_sources import FundingSourceDateFields, FundingSourcesMixin


Expand All @@ -24,11 +24,19 @@ def csv_header() -> str:
instance = GroupResponse("", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@dataclass(kw_only=True)
class GroupFundingSourceResponse(FundingSourceDateFields):
id: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class GroupsMixin(ClientProtocol):
"""Mixin implements APIs for concession groups."""
Expand Down
6 changes: 5 additions & 1 deletion littlepay/api/products.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.groups import GroupsMixin


Expand All @@ -26,6 +26,10 @@ def csv_header() -> str:
instance = ProductResponse("", "", "", "", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class ProductsMixin(GroupsMixin, ClientProtocol):
"""Mixin implements APIs for products."""
Expand Down
38 changes: 37 additions & 1 deletion tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from requests import HTTPError

from littlepay.api import ListResponse
from littlepay.api import ListResponse, from_kwargs
from littlepay.api.client import _client_from_active_config, _fix_bearer_token_header, _json_post_credentials, Client
from littlepay.config import Config

Expand Down Expand Up @@ -48,12 +48,21 @@ class SampleResponse:
two: str
three: int

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@pytest.fixture
def SampleResponse_json():
return {"one": "single", "two": "double", "three": 3}


@pytest.fixture
def SampleResponse_json_with_unexpected_field():
return {"one": "single", "two": "double", "three": 3, "four": "4"}


@pytest.fixture
def default_list_params():
return dict(page=1, per_page=100)
Expand Down Expand Up @@ -232,6 +241,26 @@ def test_Client_get_params(mocker, make_client: ClientFunc, url, SampleResponse_
assert result.three == 3


def test_Client_get_response_has_unexpected_fields(
mocker, make_client: ClientFunc, url, SampleResponse_json_with_unexpected_field
):
client = make_client()
mock_response = mocker.Mock(
raise_for_status=mocker.Mock(return_value=False),
json=mocker.Mock(return_value=SampleResponse_json_with_unexpected_field),
)
req_spy = mocker.patch.object(client.oauth, "get", return_value=mock_response)

result = client._get(url, SampleResponse)

req_spy.assert_called_once_with(url, headers=client.headers, params={})
assert isinstance(result, SampleResponse)
assert result.one == "single"
assert result.two == "double"
assert result.three == 3
assert result.four == "4"


def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
client = make_client()
mock_response = mocker.Mock(raise_for_status=mocker.Mock(side_effect=HTTPError))
Expand All @@ -243,6 +272,13 @@ def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
req_spy.assert_called_once_with(url, headers=client.headers, params={})


def test_ListResponse_unexpected_fields():
response_json = {"list": [1, 2, 3], "total_count": 3, "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
ListResponse.from_kwargs(**response_json)


def test_Client_get_list(mocker, make_client: ClientFunc, url, default_list_params, ListResponse_sample):
client = make_client()
req_spy = mocker.patch.object(client, "_get", return_value=ListResponse_sample)
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_FundingSourceResponse_unexpected_fields():
}

# this test will fail if any error occurs from instantiating the class
FundingSourceResponse(**response_json)
FundingSourceResponse.from_kwargs(**response_json)


def test_FundingSourceDateFields(expected_expiry_str, expected_expiry):
Expand All @@ -92,7 +92,7 @@ def test_FundingSourceGroupResponse_unexpected_fields():
response_json = {"id": "id", "group_id": "group_id", "label": "label", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
FundingSourceGroupResponse(**response_json)
FundingSourceGroupResponse.from_kwargs(**response_json)


def test_FundingSourceGroupResponse_no_dates():
Expand Down
4 changes: 2 additions & 2 deletions tests/api/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_GroupResponse_unexpected_fields():
response_json = {"id": "id", "label": "label", "participant_id": "participant", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupResponse(**response_json)
GroupResponse.from_kwargs(**response_json)


def test_GroupResponse_csv():
Expand All @@ -92,7 +92,7 @@ def test_GroupFundingSourceResponse_unexpected_fields():
response_json = {"id": "id", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupFundingSourceResponse(**response_json)
GroupFundingSourceResponse.from_kwargs(**response_json)


def test_GroupFundingSourceResponse_no_dates():
Expand Down
2 changes: 1 addition & 1 deletion tests/api/test_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_ProductResponse_unexpected_fields():
}

# this test will fail if any error occurs from instantiating the class
ProductResponse(**response_json)
ProductResponse.from_kwargs(**response_json)


def test_ProductResponse_csv():
Expand Down

0 comments on commit 397e0e4

Please sign in to comment.