From 3f83045a247c8c7d4509e29adb2f66135f4bcc58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mikael=20Engstr=C3=B6m?= Date: Thu, 21 Nov 2024 08:38:56 +0100 Subject: [PATCH] Fix #744: Bulk messages with aoiapns not working properly Previous version implemented aioapns in a way that made it hang indefinetly. Especially when receiver list contaned a lot of bad tokens. Having a lot of bad tokens still affects the reliability and transfer speed of notification sets, which is why this fix also deactivate devices for error codes BadDeviceToken and DeviceTokenNotForTopic unlike previous versions. --- push_notifications/apns_async.py | 364 +++++++++++++++++-------------- 1 file changed, 195 insertions(+), 169 deletions(-) diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index a0710d85..390499c0 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -1,5 +1,6 @@ import asyncio import time + from dataclasses import asdict, dataclass from typing import Awaitable, Callable, Dict, Optional, Union @@ -111,132 +112,100 @@ def asDict(self) -> dict[str, any]: } -class APNsService: - __slots__ = ("client",) - - def __init__( - self, - application_id: str = None, - creds: Credentials = None, - topic: str = None, - err_func: ErrFunc = None, - ): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - self.client = self._create_client( - creds=creds, application_id=application_id, topic=topic, err_func=err_func - ) +def _create_notification_request_from_args( + registration_id: str, + alert: Union[str, Alert], + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + aps_kwargs: dict = {}, + message_kwargs: dict = {}, + notification_request_kwargs: dict = {}, +): + if alert is None: + alert = Alert(body="") - def send_message( - self, - request: NotificationRequest, - ): - loop = asyncio.get_event_loop() - routine = self.client.send_notification(request) - res = loop.run_until_complete(routine) - return res - - def _create_notification_request_from_args( - self, - registration_id: str, - alert: Union[str, Alert], - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - aps_kwargs: dict = {}, - message_kwargs: dict = {}, - notification_request_kwargs: dict = {}, - ): - if alert is None: - alert = Alert(body="") - - if loc_key: - if isinstance(alert, str): - alert = Alert(body=alert) - alert.loc_key = loc_key - - if isinstance(alert, Alert): - alert = alert.asDict() - - notification_request_kwargs_out = notification_request_kwargs.copy() - - if expiration is not None: - notification_request_kwargs_out["time_to_live"] = expiration - int( - time.time() - ) - if priority is not None: - notification_request_kwargs_out["priority"] = priority - - if collapse_id is not None: - notification_request_kwargs_out["collapse_key"] = collapse_id - - request = NotificationRequest( - device_token=registration_id, - message={ - "aps": { - "alert": alert, - "badge": badge, - "sound": sound, - "thread-id": thread_id, - **aps_kwargs, - }, - **extra, - **message_kwargs, - }, - **notification_request_kwargs_out, - ) + if loc_key: + if isinstance(alert, str): + alert = Alert(body=alert) + alert.loc_key = loc_key - return request - - def _create_client( - self, - creds: Credentials = None, - application_id: str = None, - topic=None, - err_func: ErrFunc = None, - ) -> APNs: - use_sandbox = get_manager().get_apns_use_sandbox(application_id) - if topic is None: - topic = get_manager().get_apns_topic(application_id) - if creds is None: - creds = self._get_credentials(application_id) - - client = APNs( - **asdict(creds), - topic=topic, # Bundle ID - use_sandbox=use_sandbox, - err_func=err_func, + if isinstance(alert, Alert): + alert = alert.asDict() + + notification_request_kwargs_out = notification_request_kwargs.copy() + + if expiration is not None: + notification_request_kwargs_out["time_to_live"] = expiration - int( + time.time() ) - return client - - def _get_credentials(self, application_id): - if not get_manager().has_auth_token_creds(application_id): - # TLS certificate authentication - cert = get_manager().get_apns_certificate(application_id) - return CertificateCredentials( - client_cert=cert, - ) - else: - # Token authentication - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) + if priority is not None: + notification_request_kwargs_out["priority"] = priority + + if collapse_id is not None: + notification_request_kwargs_out["collapse_key"] = collapse_id + + request = NotificationRequest( + device_token=registration_id, + message={ + "aps": { + "alert": alert, + "badge": badge, + "sound": sound, + "thread-id": thread_id, + **aps_kwargs, + }, + **extra, + **message_kwargs, + }, + **notification_request_kwargs_out, + ) + return request -# Public interface + +def _create_client( + creds: Credentials = None, + application_id: str = None, + topic=None, + err_func: ErrFunc = None, +) -> APNs: + use_sandbox = get_manager().get_apns_use_sandbox(application_id) + if topic is None: + topic = get_manager().get_apns_topic(application_id) + if creds is None: + creds = _get_credentials(application_id) + + client = APNs( + **asdict(creds), + topic=topic, # Bundle ID + use_sandbox=use_sandbox, + err_func=err_func, + ) + return client + + +def _get_credentials(application_id): + if not get_manager().has_auth_token_creds(application_id): + # TLS certificate authentication + cert = get_manager().get_apns_certificate(application_id) + return CertificateCredentials( + client_cert=cert, + ) + else: + # Token authentication + keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + return TokenCredentials(key=keyPath, key_id=keyId, team_id=teamId) def apns_send_message( @@ -270,33 +239,28 @@ def apns_send_message( :param application_id: The application_id to use :param creds: The credentials to use """ + results = apns_send_bulk_message( + registration_ids=[registration_id], + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + err_func=err_func, + ) - try: - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - - request = apns_service._create_notification_request_from_args( - registration_id, - alert, - badge=badge, - sound=sound, - extra=extra, - expiration=expiration, - thread_id=thread_id, - loc_key=loc_key, - priority=priority, - collapse_id=collapse_id, - ) - res = apns_service.send_message(request) - if not res.is_successful: - if res.description == "Unregistered": - models.APNSDevice.objects.filter( - registration_id=registration_id - ).update(active=False) - raise APNSServerError(status=res.description) - except ConnectionError as e: - raise APNSServerError(status=e.__class__.__name__) + for result in results.values(): + if result == "Success": + return {"results": [result]} + else: + return {"results": [{"error": result}]} def apns_send_bulk_message( @@ -328,17 +292,17 @@ def apns_send_bulk_message( :param application_id: The application_id to use :param creds: The credentials to use """ - - topic = get_manager().get_apns_topic(application_id) - results: Dict[str, str] = {} - inactive_tokens = [] - apns_service = APNsService( - application_id=application_id, creds=creds, topic=topic, err_func=err_func - ) - for registration_id in registration_ids: - request = apns_service._create_notification_request_from_args( - registration_id, - alert, + try: + topic = get_manager().get_apns_topic(application_id) + results: Dict[str, str] = {} + inactive_tokens = [] + + responses = asyncio.run(_send_bulk_request( + registration_ids=registration_ids, + alert=alert, + application_id=application_id, + creds=creds, + topic=topic, badge=badge, sound=sound, extra=extra, @@ -347,17 +311,79 @@ def apns_send_bulk_message( loc_key=loc_key, priority=priority, collapse_id=collapse_id, - ) + err_func=err_func, + )) + + results = {} + for registration_id, result in responses: + results[registration_id] = ( + "Success" if result.is_successful else result.description + ) + if not result.is_successful and result.description in ["Unregistered", "BadDeviceToken", + "DeviceTokenNotForTopic"]: + inactive_tokens.append(registration_id) + + if len(inactive_tokens) > 0: + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + + return results + + except ConnectionError as e: + raise APNSServerError(status=e.__class__.__name__) - result = apns_service.send_message(request) - results[registration_id] = ( - "Success" if result.is_successful else result.description - ) - if not result.is_successful and result.description == "Unregistered": - inactive_tokens.append(registration_id) - if len(inactive_tokens) > 0: - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( - active=False +async def _send_bulk_request( + registration_ids: list[str], + alert: Union[str, Alert], + application_id: str = None, + creds: Credentials = None, + topic: str = None, + badge: int = None, + sound: str = None, + extra: dict = {}, + expiration: int = None, + thread_id: str = None, + loc_key: str = None, + priority: int = None, + collapse_id: str = None, + err_func: ErrFunc = None, +): + client = _create_client( + creds=creds, application_id=application_id, topic=topic, err_func=err_func + ) + + requests = [_create_notification_request_from_args( + registration_id, + alert, + badge=badge, + sound=sound, + extra=extra, + expiration=expiration, + thread_id=thread_id, + loc_key=loc_key, + priority=priority, + collapse_id=collapse_id, + ) for registration_id in registration_ids] + + send_requests = [_send_request(client, request) for request in requests] + return await asyncio.gather(*send_requests) + + +async def _send_request(apns, request): + try: + res = await asyncio.wait_for(apns.send_notification(request), timeout=1) + return request.device_token, res + except asyncio.TimeoutError: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="TimeoutError" + ) + except: + return request.device_token, NotificationResult( + notification_id=request.notification_id, + status="failed", + description="CommunicationError" ) - return results