Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: allow unexpected fields in responses #69

Merged
merged 4 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions littlepay/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,57 @@
from dataclasses import dataclass
from inspect import signature
import logging
from typing import Generator, Protocol, TypeVar

from authlib.integrations.requests_client import OAuth2Session


logger = logging.getLogger(__name__)


# Generic type parameter, used to represent the result of an API call.
TResponse = TypeVar("TResponse")


def from_kwargs(cls, **kwargs):
"""
Helper function meant to be used as a @classmethod
for instantiating a dataclass and allowing unexpected fields

See https://stackoverflow.com/a/55101438
"""
# fetch the constructor's signature
class_fields = {field for field in signature(cls).parameters}

# split the kwargs into native ones and new ones
native_args, new_args = {}, {}
for name, val in kwargs.items():
if name in class_fields:
native_args[name] = val
else:
new_args[name] = val

# use the native ones to create the class ...
instance = cls(**native_args)

# ... and log any unexpected args
for new_name, new_val in new_args.items():
logger.info(f"Ran into an unexpected arg: {new_name} = {new_val}")

return instance


@dataclass
class ListResponse:
"""An API response with list and total_count attributes."""

list: list
total_count: int

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class ClientProtocol(Protocol):
"""Protocol describing key functionality for an API connection."""
Expand Down
3 changes: 2 additions & 1 deletion littlepay/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def _delete(self, endpoint: str) -> bool:
def _get(self, endpoint: str, response_cls: TResponse, **kwargs) -> TResponse:
response = self.oauth.get(endpoint, headers=self.headers, params=kwargs)
response.raise_for_status()
return response_cls(**response.json())

return response_cls.from_kwargs(**response.json())

def _get_list(self, endpoint: str, **kwargs) -> Generator[dict, None, None]:
params = dict(page=1, per_page=100)
Expand Down
27 changes: 27 additions & 0 deletions littlepay/api/funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from littlepay.api import ClientProtocol

from . import from_kwargs


@dataclass
class FundingSourceResponse:
Expand All @@ -17,6 +19,7 @@ class FundingSourceResponse:
participant_id: str
is_fpan: bool
related_funding_sources: List[dict]
created_date: datetime | None = None
card_category: Optional[str] = None
issuer_country_code: Optional[str] = None
issuer_country_numeric_code: Optional[str] = None
Expand All @@ -25,6 +28,26 @@ class FundingSourceResponse:
token_key_id: Optional[str] = None
icc_hash: Optional[str] = None

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)

def __post_init__(self):
"""Parses any date parameters into Python datetime objects.

For @dataclasses with a generated __init__ function, this function is called automatically.

Includes a workaround for Python 3.10 where datetime.fromisoformat() can only parse the format output
by datetime.isoformat(), i.e. without a trailing 'Z' offset character and with UTC offset expressed
as +/-HH:mm

https://docs.python.org/3.11/library/datetime.html#datetime.datetime.fromisoformat
"""
if self.created_date:
self.created_date = datetime.fromisoformat(self.created_date.replace("Z", "+00:00", 1))
else:
self.created_date = None
thekaveman marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class FundingSourceDateFields:
Expand Down Expand Up @@ -65,6 +88,10 @@ class FundingSourceGroupResponse(FundingSourceDateFields):
group_id: str
label: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class FundingSourcesMixin(ClientProtocol):
"""Mixin implements APIs for funding sources."""
Expand Down
10 changes: 9 additions & 1 deletion littlepay/api/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime, timezone
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.funding_sources import FundingSourceDateFields, FundingSourcesMixin


Expand All @@ -24,11 +24,19 @@ def csv_header() -> str:
instance = GroupResponse("", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@dataclass(kw_only=True)
class GroupFundingSourceResponse(FundingSourceDateFields):
id: str

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class GroupsMixin(ClientProtocol):
"""Mixin implements APIs for concession groups."""
Expand Down
6 changes: 5 additions & 1 deletion littlepay/api/products.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Generator

from littlepay.api import ClientProtocol
from littlepay.api import ClientProtocol, from_kwargs
from littlepay.api.groups import GroupsMixin


Expand All @@ -26,6 +26,10 @@ def csv_header() -> str:
instance = ProductResponse("", "", "", "", "", "")
return ",".join(vars(instance).keys())

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


class ProductsMixin(GroupsMixin, ClientProtocol):
"""Mixin implements APIs for products."""
Expand Down
38 changes: 37 additions & 1 deletion tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from requests import HTTPError

from littlepay.api import ListResponse
from littlepay.api import ListResponse, from_kwargs
from littlepay.api.client import _client_from_active_config, _fix_bearer_token_header, _json_post_credentials, Client
from littlepay.config import Config

Expand Down Expand Up @@ -48,12 +48,21 @@ class SampleResponse:
two: str
three: int

@classmethod
def from_kwargs(cls, **kwargs):
return from_kwargs(cls, **kwargs)


@pytest.fixture
def SampleResponse_json():
return {"one": "single", "two": "double", "three": 3}


@pytest.fixture
def SampleResponse_json_with_unexpected_field():
return {"one": "single", "two": "double", "three": 3, "four": "4"}


@pytest.fixture
def default_list_params():
return dict(page=1, per_page=100)
Expand Down Expand Up @@ -232,6 +241,26 @@ def test_Client_get_params(mocker, make_client: ClientFunc, url, SampleResponse_
assert result.three == 3


def test_Client_get_response_has_unexpected_fields(
mocker, make_client: ClientFunc, url, SampleResponse_json_with_unexpected_field
):
client = make_client()
mock_response = mocker.Mock(
raise_for_status=mocker.Mock(return_value=False),
json=mocker.Mock(return_value=SampleResponse_json_with_unexpected_field),
)
req_spy = mocker.patch.object(client.oauth, "get", return_value=mock_response)

result = client._get(url, SampleResponse)

req_spy.assert_called_once_with(url, headers=client.headers, params={})
assert isinstance(result, SampleResponse)
assert result.one == "single"
assert result.two == "double"
assert result.three == 3
assert not hasattr(result, "four")


def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
client = make_client()
mock_response = mocker.Mock(raise_for_status=mocker.Mock(side_effect=HTTPError))
Expand All @@ -243,6 +272,13 @@ def test_Client_get_error_status(mocker, make_client: ClientFunc, url):
req_spy.assert_called_once_with(url, headers=client.headers, params={})


def test_ListResponse_unexpected_fields():
response_json = {"list": [1, 2, 3], "total_count": 3, "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
ListResponse.from_kwargs(**response_json)


def test_Client_get_list(mocker, make_client: ClientFunc, url, default_list_params, ListResponse_sample):
client = make_client()
req_spy = mocker.patch.object(client, "_get", return_value=ListResponse_sample)
Expand Down
66 changes: 65 additions & 1 deletion tests/api/test_funding_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ListResponse_FundingSourceGroups(expected_expiry_str):


@pytest.fixture
def mock_ClientProtocol_get_FundingResource(mocker):
def mock_ClientProtocol_get_FundingResource(mocker, expected_expiry_str):
funding_source = FundingSourceResponse(
id="0",
card_first_digits="0000",
Expand All @@ -47,6 +47,7 @@ def mock_ClientProtocol_get_FundingResource(mocker):
participant_id="cst",
is_fpan=True,
related_funding_sources=[],
created_date=expected_expiry_str,
)
return mocker.patch("littlepay.api.ClientProtocol._get", return_value=funding_source)

Expand All @@ -59,6 +60,62 @@ def mock_ClientProtocol_get_list_FundingSourceGroup(mocker, ListResponse_Funding
)


def test_FundingSourceResponse_unexpected_fields():
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
"unexpected_field": "test value",
}

# this test will fail if any error occurs from instantiating the class
FundingSourceResponse.from_kwargs(**response_json)


def test_FundingSourceResponse_no_date_field():
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
}

funding_source = FundingSourceResponse.from_kwargs(**response_json)
assert funding_source.created_date is None


def test_FundingSourceResponse_with_date_field(expected_expiry_str, expected_expiry):
response_json = {
"id": "0",
"card_first_digits": "0000",
"card_last_digits": "0000",
"card_expiry_month": "11",
"card_expiry_year": "24",
"card_scheme": "Visa",
"form_factor": "unknown",
"participant_id": "cst",
"is_fpan": True,
"related_funding_sources": [],
"created_date": expected_expiry_str,
}

funding_source = FundingSourceResponse.from_kwargs(**response_json)
assert funding_source.created_date == expected_expiry


def test_FundingSourceDateFields(expected_expiry_str, expected_expiry):
fields = FundingSourceDateFields(
created_date=expected_expiry_str, updated_date=expected_expiry_str, expiry_date=expected_expiry_str
Expand All @@ -69,6 +126,13 @@ def test_FundingSourceDateFields(expected_expiry_str, expected_expiry):
assert fields.expiry_date == expected_expiry


def test_FundingSourceGroupResponse_unexpected_fields():
response_json = {"id": "id", "group_id": "group_id", "label": "label", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
FundingSourceGroupResponse.from_kwargs(**response_json)


def test_FundingSourceGroupResponse_no_dates():
response = FundingSourceGroupResponse(id="id", group_id="group_id", label="label")

Expand Down
14 changes: 14 additions & 0 deletions tests/api/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ def mock_ClientProtocol_put_update_concession_group_funding_source(mocker):
return mocker.patch("littlepay.api.ClientProtocol._put", side_effect=lambda *args, **kwargs: response)


def test_GroupResponse_unexpected_fields():
response_json = {"id": "id", "label": "label", "participant_id": "participant", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupResponse.from_kwargs(**response_json)


def test_GroupResponse_csv():
group = GroupResponse("id", "label", "participant")
assert group.csv() == "id,label,participant"
Expand All @@ -81,6 +88,13 @@ def test_GroupResponse_csv_header():
assert GroupResponse.csv_header() == "id,label,participant_id"


def test_GroupFundingSourceResponse_unexpected_fields():
response_json = {"id": "id", "unexpected_field": "test value"}

# this test will fail if any error occurs from instantiating the class
GroupFundingSourceResponse.from_kwargs(**response_json)


def test_GroupFundingSourceResponse_no_dates():
response = GroupFundingSourceResponse(id="id")

Expand Down
15 changes: 15 additions & 0 deletions tests/api/test_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ def mock_ClientProtocol_post(mocker):
return mocker.patch("littlepay.api.ClientProtocol._post", side_effect=lambda *args, **kwargs: response)


def test_ProductResponse_unexpected_fields():
response_json = {
"id": "id",
"code": "code",
"status": "status",
"type": "type",
"description": "description",
"participant_id": "participant",
"unexpected_field": "test value",
}

# this test will fail if any error occurs from instantiating the class
ProductResponse.from_kwargs(**response_json)


def test_ProductResponse_csv():
product = ProductResponse("id", "code", "status", "type", "description", "participant")
assert product.csv() == "id,code,status,type,description,participant"
Expand Down
Loading