diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index a0710d85..00b3c9d5 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -1,5 +1,6 @@ import asyncio import time + from dataclasses import asdict, dataclass from typing import Awaitable, Callable, Dict, Optional, Union @@ -8,7 +9,7 @@ from . import models from .conf import get_manager -from .exceptions import APNSServerError +from .exceptions import APNSServerError, APNSError ErrFunc = Optional[Callable[[NotificationRequest, NotificationResult], Awaitable[None]]] """function to proces errors from aioapns send_message""" @@ -111,132 +112,100 @@ def asDict(self) -> dict[str, any]: } -class APNsService: - __slots__ = ("client",) - - def __init__( - self, - application_id: str = None, - creds: Credentials = None, - topic: str = None, - err_func: ErrFunc = None, - ): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - self.client = self._create_client( - creds=creds, application_id=application_id, topic=topic, err_func=err_func - ) +def _create_notification_request_from_args( + registration_id: str, + alert: Union[str, Alert], + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + aps_kwargs: dict = {}, + message_kwargs: dict = {}, + notification_request_kwargs: dict = {}, +): + if alert is None: + alert = Alert(body="") - def send_message( - self, - request: NotificationRequest, - ): - loop = asyncio.get_event_loop() - routine = self.client.send_notification(request) - res = loop.run_until_complete(routine) - return res - - def _create_notification_request_from_args( - self, - registration_id: str, - alert: Union[str, Alert], - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - aps_kwargs: dict = {}, - message_kwargs: dict = {}, - notification_request_kwargs: dict = {}, - ): - if alert is None: - alert = Alert(body="") - - if loc_key: - if isinstance(alert, str): - alert = Alert(body=alert) - alert.loc_key = loc_key - - if isinstance(alert, Alert): - alert = alert.asDict() - - notification_request_kwargs_out = notification_request_kwargs.copy() - - if expiration is not None: - notification_request_kwargs_out["time_to_live"] = expiration - int( - time.time() - ) - if priority is not None: - notification_request_kwargs_out["priority"] = priority - - if collapse_id is not None: - notification_request_kwargs_out["collapse_key"] = collapse_id - - request = NotificationRequest( - device_token=registration_id, - message={ - "aps": { - "alert": alert, - "badge": badge, - "sound": sound, - "thread-id": thread_id, - **aps_kwargs, - }, - **extra, - **message_kwargs, - }, - **notification_request_kwargs_out, - ) + if loc_key: + if isinstance(alert, str): + alert = Alert(body=alert) + alert.loc_key = loc_key - return request - - def _create_client( - self, - creds: Credentials = None, - application_id: str = None, - topic=None, - err_func: ErrFunc = None, - ) -> APNs: - use_sandbox = get_manager().get_apns_use_sandbox(application_id) - if topic is None: - topic = get_manager().get_apns_topic(application_id) - if creds is None: - creds = self._get_credentials(application_id) - - client = APNs( - **asdict(creds), - topic=topic, # Bundle ID - use_sandbox=use_sandbox, - err_func=err_func, + if isinstance(alert, Alert): + alert = alert.asDict() + + notification_request_kwargs_out = notification_request_kwargs.copy() + + if expiration is not None: + notification_request_kwargs_out["time_to_live"] = expiration - int( + time.time() ) - return client - - def _get_credentials(self, application_id): - if not get_manager().has_auth_token_creds(application_id): - # TLS certificate authentication - cert = get_manager().get_apns_certificate(application_id) - return CertificateCredentials( - client_cert=cert, - ) - else: - # Token authentication - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) + if priority is not None: + notification_request_kwargs_out["priority"] = priority + + if collapse_id is not None: + notification_request_kwargs_out["collapse_key"] = collapse_id + + request = NotificationRequest( + device_token=registration_id, + message={ + "aps": { + "alert": alert, + "badge": badge, + "sound": sound, + "thread-id": thread_id, + **aps_kwargs, + }, + **extra, + **message_kwargs, + }, + **notification_request_kwargs_out, + ) + return request -# Public interface + +def _create_client( + creds: Credentials = None, + application_id: str = None, + topic=None, + err_func: ErrFunc = None, +) -> APNs: + use_sandbox = get_manager().get_apns_use_sandbox(application_id) + if topic is None: + topic = get_manager().get_apns_topic(application_id) + if creds is None: + creds = _get_credentials(application_id) + + client = APNs( + **asdict(creds), + topic=topic, # Bundle ID + use_sandbox=use_sandbox, + err_func=err_func, + ) + return client + + +def _get_credentials(application_id): + if not get_manager().has_auth_token_creds(application_id): + # TLS certificate authentication + cert = get_manager().get_apns_certificate(application_id) + return CertificateCredentials( + client_cert=cert, + ) + else: + # Token authentication + keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) def apns_send_message( @@ -270,33 +239,28 @@ def apns_send_message( :param application_id: The application_id to use :param creds: The credentials to use """ + results = apns_send_bulk_message( + registration_ids=[registration_id], + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + err_func=err_func, + ) - try: - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - - request = apns_service._create_notification_request_from_args( - registration_id, - alert, - badge=badge, - sound=sound, - extra=extra, - expiration=expiration, - thread_id=thread_id, - loc_key=loc_key, - priority=priority, - collapse_id=collapse_id, - ) - res = apns_service.send_message(request) - if not res.is_successful: - if res.description == "Unregistered": - models.APNSDevice.objects.filter( - registration_id=registration_id - ).update(active=False) - raise APNSServerError(status=res.description) - except ConnectionError as e: - raise APNSServerError(status=e.__class__.__name__) + for result in results.values(): + if result == "Success": + return {"results": [result]} + else: + return {"results": [{"error": result}]} def apns_send_bulk_message( @@ -328,17 +292,17 @@ def apns_send_bulk_message( :param application_id: The application_id to use :param creds: The credentials to use """ - - topic = get_manager().get_apns_topic(application_id) - results: Dict[str, str] = {} - inactive_tokens = [] - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - for registration_id in registration_ids: - request = apns_service._create_notification_request_from_args( - registration_id, - alert, + try: + topic = get_manager().get_apns_topic(application_id) + results: Dict[str, str] = {} + inactive_tokens = [] + + responses = asyncio.run(_send_bulk_request( + registration_ids=registration_ids, + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, badge=badge, sound=sound, extra=extra, @@ -347,17 +311,86 @@ def apns_send_bulk_message( loc_key=loc_key, priority=priority, collapse_id=collapse_id, - ) + err_func=err_func, + )) + + results = {} + errors = [] + for registration_id, result in responses: + results[registration_id] = ( + "Success" if result.is_successful else result.description + ) + if not result.is_successful: + errors.append(result.description) + if result.description in ["Unregistered", "BadDeviceToken", + "DeviceTokenNotForTopic"]: + inactive_tokens.append(registration_id) + + if len(inactive_tokens) > 0: + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + + if len(errors) > 0: + msg = "One or more errors failed with errors: {}".format(", ".join(errors)) + raise APNSError(msg) + + return results + + except ConnectionError as e: + raise APNSServerError(status=e.__class__.__name__) - result = apns_service.send_message(request) - results[registration_id] = ( - "Success" if result.is_successful else result.description - ) - if not result.is_successful and result.description == "Unregistered": - inactive_tokens.append(registration_id) - if len(inactive_tokens) > 0: - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( - active=False +async def _send_bulk_request( + registration_ids: list[str], + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, +): + client = _create_client( + creds=creds, application_id=application_id, topic=topic, err_func=err_func + ) + + requests = [_create_notification_request_from_args( + registration_id, + alert, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + ) for registration_id in registration_ids] + + send_requests = [_send_request(client, request) for request in requests] + return await asyncio.gather(*send_requests) + + +async def _send_request(apns, request): + try: + res = await asyncio.wait_for(apns.send_notification(request), timeout=1) + return request.device_token, res + except asyncio.TimeoutError: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="TimeoutError" + ) + except: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="CommunicationError" ) - return results diff --git a/push_notifications/exceptions.py b/push_notifications/exceptions.py index 7fc5cdee..33fb4659 100644 --- a/push_notifications/exceptions.py +++ b/push_notifications/exceptions.py @@ -1,4 +1,7 @@ class NotificationError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message pass diff --git a/tests/test_apns_async_models.py b/tests/test_apns_async_models.py index 291cc01b..b4588502 100644 --- a/tests/test_apns_async_models.py +++ b/tests/test_apns_async_models.py @@ -87,19 +87,19 @@ def test_apns_send_message_to_single_device_with_error(self, mock_apns): mock_apns.return_value.send_notification.return_value = NotificationResult( status="400", notification_id="abc", - description="Unregistered", + description="PayloadTooLarge", ) device = APNSDevice.objects.get(registration_id="abc") with self.assertRaises(APNSError) as ae: device.send_message("Hello World!") - self.assertEqual(ae.exception.status, "Unregistered") - self.assertFalse(APNSDevice.objects.get(registration_id="abc").active) + self.assertTrue("PayloadTooLarge" in ae.exception.message) + self.assertTrue(APNSDevice.objects.get(registration_id="abc").active) @mock.patch("push_notifications.apns_async.APNs", autospec=True) def test_apns_send_message_to_several_devices_with_error(self, mock_apns): # these errors are device specific, device.active will be set false devices = ["abc", "def", "ghi"] - expected_exceptions_statuses = ["PayloadTooLarge", "BadTopic", "Unregistered"] + expected_exceptions_statuses = ["PayloadTooLarge", "DeviceTokenNotForTopic", "Unregistered"] self._create_devices(devices) mock_apns.return_value.send_notification.side_effect = [ @@ -111,7 +111,7 @@ def test_apns_send_message_to_several_devices_with_error(self, mock_apns): NotificationResult( status="400", notification_id="def", - description="BadTopic", + description="DeviceTokenNotForTopic", ), NotificationResult( status="400", @@ -124,12 +124,12 @@ def test_apns_send_message_to_several_devices_with_error(self, mock_apns): device = APNSDevice.objects.get(registration_id=token) with self.assertRaises(APNSError) as ae: device.send_message("Hello World!") - self.assertEqual(ae.exception.status, expected_exceptions_statuses[idx]) + self.assertTrue(expected_exceptions_statuses[idx] in ae.exception.message) - if idx == 2: - self.assertFalse(APNSDevice.objects.get(registration_id=token).active) - else: + if idx == 0: self.assertTrue(APNSDevice.objects.get(registration_id=token).active) + else: + self.assertFalse(APNSDevice.objects.get(registration_id=token).active) @mock.patch("push_notifications.apns_async.APNs", autospec=True) def test_apns_send_message_to_bulk_devices_with_error(self, mock_apns): @@ -144,7 +144,7 @@ def test_apns_send_message_to_bulk_devices_with_error(self, mock_apns): NotificationResult( status="400", notification_id="def", - description="BadTopic", + description="DeviceTokenNotForTopic", ), NotificationResult( status="400", @@ -156,13 +156,14 @@ def test_apns_send_message_to_bulk_devices_with_error(self, mock_apns): mock_apns.return_value.send_notification.side_effect = results - results = APNSDevice.objects.all().send_message("Hello World!") + with self.assertRaises(APNSError): + APNSDevice.objects.all().send_message("Hello World!") for idx, token in enumerate(devices): - if idx == 2: - self.assertFalse(APNSDevice.objects.get(registration_id=token).active) - else: + if idx == 0: self.assertTrue(APNSDevice.objects.get(registration_id=token).active) + else: + self.assertFalse(APNSDevice.objects.get(registration_id=token).active) @mock.patch("push_notifications.apns_async.APNs", autospec=True) def test_apns_send_messages_different_priority(self, mock_apns): diff --git a/tests/test_apns_async_push_payload.py b/tests/test_apns_async_push_payload.py index ebb11416..78750570 100644 --- a/tests/test_apns_async_push_payload.py +++ b/tests/test_apns_async_push_payload.py @@ -8,7 +8,7 @@ try: from aioapns.common import NotificationResult - from push_notifications.apns_async import TokenCredentials, apns_send_message + from push_notifications.apns_async import TokenCredentials, apns_send_message, CertificateCredentials except ModuleNotFoundError: # skipping because apns2 is not supported on python 3.10 # it uses hyper that imports from collections which were changed in 3.10 @@ -157,8 +157,7 @@ def test_collapse_id(self, mock_apns): self.assertEqual(req.collapse_key, "456789") @mock.patch("aioapns.client.APNsCertConnectionPool", autospec=True) - @mock.patch("aioapns.client.APNsKeyConnectionPool", autospec=True) - def test_aioapns_err_func(self, mock_cert_pool, mock_key_pool): + def test_aioapns_err_func(self, mock_cert_pool): mock_cert_pool.return_value.send_notification = mock.AsyncMock() result = NotificationResult( "123", "400" @@ -169,10 +168,8 @@ def test_aioapns_err_func(self, mock_cert_pool, mock_key_pool): apns_send_message( "123", "sample", - creds=TokenCredentials( - key="aaa", - key_id="bbb", - team_id="ccc", + creds=CertificateCredentials( + client_cert="dummy/path.pem", ), topic="default", err_func=err_func,