diff --git a/push_notifications/admin.py b/push_notifications/admin.py index 1242656b..b0ff5cc3 100644 --- a/push_notifications/admin.py +++ b/push_notifications/admin.py @@ -2,6 +2,8 @@ from django.contrib import admin, messages from django.utils.encoding import force_str from django.utils.translation import gettext_lazy as _ +from django.http import HttpRequest +from django.db.models import QuerySet from .exceptions import APNSServerError, GCMError, WebPushError from .models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice @@ -20,9 +22,9 @@ class DeviceAdmin(admin.ModelAdmin): if hasattr(User, "USERNAME_FIELD"): search_fields = ("name", "device_id", "user__%s" % (User.USERNAME_FIELD)) else: - search_fields = ("name", "device_id") + search_fields = ("name", "device_id", "") - def send_messages(self, request, queryset, bulk=False): + def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None: """ Provides error handling for DeviceAdmin send_message and send_bulk_message methods. """ @@ -105,22 +107,22 @@ def send_messages(self, request, queryset, bulk=False): msg = _("All messages were sent: %s" % (ret)) self.message_user(request, msg) - def send_message(self, request, queryset): + def send_message(self, request: HttpRequest, queryset: QuerySet) -> None: self.send_messages(request, queryset) send_message.short_description = _("Send test message") - def send_bulk_message(self, request, queryset): + def send_bulk_message(self, request: HttpRequest, queryset: QuerySet) -> None: self.send_messages(request, queryset, True) send_bulk_message.short_description = _("Send test message in bulk") - def enable(self, request, queryset): + def enable(self, request: HttpRequest, queryset: QuerySet) -> None: queryset.update(active=True) enable.short_description = _("Enable selected devices") - def disable(self, request, queryset): + def disable(self, request: HttpRequest, queryset: QuerySet) -> None: queryset.update(active=False) disable.short_description = _("Disable selected devices") @@ -132,7 +134,7 @@ class GCMDeviceAdmin(DeviceAdmin): ) list_filter = ("active", "cloud_message_type") - def send_messages(self, request, queryset, bulk=False): + def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None: """ Provides error handling for DeviceAdmin send_message and send_bulk_message methods. """ @@ -171,7 +173,7 @@ class WebPushDeviceAdmin(DeviceAdmin): if hasattr(User, "USERNAME_FIELD"): search_fields = ("name", "registration_id", "user__%s" % (User.USERNAME_FIELD)) else: - search_fields = ("name", "registration_id") + search_fields = ("name", "registration_id", "") admin.site.register(APNSDevice, DeviceAdmin) diff --git a/push_notifications/api/rest_framework.py b/push_notifications/api/rest_framework.py index bfbeff55..082e9dc3 100644 --- a/push_notifications/api/rest_framework.py +++ b/push_notifications/api/rest_framework.py @@ -4,9 +4,10 @@ from rest_framework.serializers import ModelSerializer, Serializer, ValidationError from rest_framework.viewsets import ModelViewSet -from ..fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re -from ..models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice -from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from push_notifications.fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re +from push_notifications.models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice +from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from typing import Any, Union, Dict, Optional # Fields @@ -15,16 +16,16 @@ class HexIntegerField(IntegerField): Store an integer represented as a hex string of form "0x01". """ - def to_internal_value(self, data): + def to_internal_value(self, data: Union[str, int]) -> int: # validate hex string and convert it to the unsigned # integer representation for internal use try: - data = int(data, 16) if type(data) != int else data + data = int(data, 16) if not isinstance(data, int) else data except ValueError: raise ValidationError("Device ID is not a valid hex number") return super().to_internal_value(data) - def to_representation(self, value): + def to_representation(self, value: int) -> int: return value @@ -32,8 +33,13 @@ def to_representation(self, value): class DeviceSerializerMixin(ModelSerializer): class Meta: fields = ( - "id", "name", "application_id", "registration_id", "device_id", - "active", "date_created" + "id", + "name", + "application_id", + "registration_id", + "device_id", + "active", + "date_created", ) read_only_fields = ("date_created",) @@ -45,8 +51,7 @@ class APNSDeviceSerializer(ModelSerializer): class Meta(DeviceSerializerMixin.Meta): model = APNSDevice - def validate_registration_id(self, value): - + def validate_registration_id(self, value: str) -> str: # https://developer.apple.com/documentation/uikit/uiapplicationdelegate/1622958-application # As of 02/2023 APNS tokens (registration_id) "are of variable length. Do not hard-code their size." if hex_re.match(value) is None: @@ -56,10 +61,10 @@ def validate_registration_id(self, value): class UniqueRegistrationSerializerMixin(Serializer): - def validate(self, attrs): - devices = None - primary_key = None - request_method = None + def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: + devices: Optional[Any] = None + primary_key: Optional[Any] = None + request_method: Optional[str] = None if self.initial_data.get("registration_id", None): if self.instance: @@ -76,9 +81,10 @@ def validate(self, attrs): Device = self.Meta.model if request_method == "update": - reg_id = attrs.get("registration_id", self.instance.registration_id) - devices = Device.objects.filter(registration_id=reg_id) \ - .exclude(id=primary_key) + reg_id: str = attrs.get("registration_id", self.instance.registration_id) + devices = Device.objects.filter(registration_id=reg_id).exclude( + id=primary_key + ) elif request_method == "create": devices = Device.objects.filter(registration_id=attrs["registration_id"]) @@ -92,20 +98,26 @@ class GCMDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer): help_text="ANDROID_ID / TelephonyManager.getDeviceId() (e.g: 0x01)", style={"input_type": "text"}, required=False, - allow_null=True + allow_null=True, ) class Meta(DeviceSerializerMixin.Meta): model = GCMDevice fields = ( - "id", "name", "registration_id", "device_id", "active", "date_created", - "cloud_message_type", "application_id", + "id", + "name", + "registration_id", + "device_id", + "active", + "date_created", + "cloud_message_type", + "application_id", ) extra_kwargs = {"id": {"read_only": False, "required": False}} - def validate_device_id(self, value): + def validate_device_id(self, value: Optional[int] = None) -> Optional[int]: # device ids are 64 bit unsigned values - if value > UNSIGNED_64BIT_INT_MAX_VALUE: + if value is not None and value > UNSIGNED_64BIT_INT_MAX_VALUE: raise ValidationError("Device ID is out of range") return value @@ -119,26 +131,36 @@ class WebPushDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer class Meta(DeviceSerializerMixin.Meta): model = WebPushDevice fields = ( - "id", "name", "registration_id", "active", "date_created", - "p256dh", "auth", "browser", "application_id", + "id", + "name", + "registration_id", + "active", + "date_created", + "p256dh", + "auth", + "browser", + "application_id", ) # Permissions class IsOwner(permissions.BasePermission): - def has_object_permission(self, request, view, obj): + def has_object_permission(self, request: Any, view: Any, obj: Any) -> bool: # must be the owner to view the object return obj.user == request.user # Mixins class DeviceViewSetMixin: - lookup_field = "registration_id" - - def create(self, request, *args, **kwargs): - serializer = None - is_update = False - if SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID") and self.lookup_field in request.data: + lookup_field: str = "registration_id" + + def create(self, request: Any, *args: Any, **kwargs: Any) -> Response: + serializer: Optional[Any] = None + is_update: bool = False + if ( + SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID") + and self.lookup_field in request.data + ): instance = self.queryset.model.objects.filter( registration_id=request.data[self.lookup_field] ).first() @@ -155,23 +177,25 @@ def create(self, request, *args, **kwargs): else: self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) - def perform_create(self, serializer): + def perform_create(self, serializer: Serializer) -> Any: if self.request.user.is_authenticated: serializer.save(user=self.request.user) return super().perform_create(serializer) - def perform_update(self, serializer): + def perform_update(self, serializer: Serializer) -> Any: if self.request.user.is_authenticated: serializer.save(user=self.request.user) return super().perform_update(serializer) class AuthorizedMixin: - permission_classes = (permissions.IsAuthenticated, IsOwner) + permission_classes: tuple = (permissions.IsAuthenticated, IsOwner) - def get_queryset(self): + def get_queryset(self) -> Any: # filter all devices to only those belonging to the current user return self.queryset.filter(user=self.request.user) @@ -207,7 +231,7 @@ class WNSDeviceAuthorizedViewSet(AuthorizedMixin, WNSDeviceViewSet): class WebPushDeviceViewSet(DeviceViewSetMixin, ModelViewSet): queryset = WebPushDevice.objects.all() serializer_class = WebPushDeviceSerializer - lookup_value_regex = '.+' + lookup_value_regex: str = ".+" class WebPushDeviceAuthorizedViewSet(AuthorizedMixin, WebPushDeviceViewSet): diff --git a/push_notifications/apns.py b/push_notifications/apns.py index 766062a7..77dcc9a9 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -5,7 +5,7 @@ """ import time - +from typing import Optional, Dict, Any, List, Union from apns2 import client as apns2_client from apns2 import credentials as apns2_credentials from apns2 import errors as apns2_errors @@ -13,10 +13,10 @@ from . import models from .conf import get_manager -from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError +from .exceptions import APNSUnsupportedPriority, APNSServerError -def _apns_create_socket(creds=None, application_id=None): +def _apns_create_socket(creds: Optional[apns2_credentials.Credentials] = None, application_id: Optional[str] = None) -> apns2_client.APNsClient: if creds is None: if not get_manager().has_auth_token_creds(application_id): cert = get_manager().get_apns_certificate(application_id) @@ -39,31 +39,48 @@ def _apns_create_socket(creds=None, application_id=None): def _apns_prepare( - token, alert, application_id=None, badge=None, sound=None, category=None, - content_available=False, action_loc_key=None, loc_key=None, loc_args=[], - extra={}, mutable_content=False, thread_id=None, url_args=None): - if action_loc_key or loc_key or loc_args: - apns2_alert = apns2_payload.PayloadAlert( - body=alert if alert else {}, body_localized_key=loc_key, - body_localized_args=loc_args, action_localized_key=action_loc_key) - else: - apns2_alert = alert - - if callable(badge): - badge = badge(token) - - return apns2_payload.Payload( - alert=apns2_alert, badge=badge, sound=sound, category=category, - url_args=url_args, custom=extra, thread_id=thread_id, - content_available=content_available, mutable_content=mutable_content) + token: str, + alert: Optional[str], + application_id: Optional[str] = None, + badge: Optional[int] = None, + sound: Optional[str] = None, + category: Optional[str] = None, + content_available: bool = False, + action_loc_key: Optional[str] = None, + loc_key: Optional[str] = None, + loc_args: List[Any] = [], + extra: Dict[str, Any] = {}, + mutable_content: bool = False, + thread_id: Optional[str] = None, + url_args: Optional[list] = None +) -> apns2_payload.Payload: + if action_loc_key or loc_key or loc_args: + apns2_alert = apns2_payload.PayloadAlert( + body=alert if alert else {}, body_localized_key=loc_key, + body_localized_args=loc_args, action_localized_key=action_loc_key) + else: + apns2_alert = alert + + if callable(badge): + badge = badge(token) + + return apns2_payload.Payload( + alert=apns2_alert, badge=badge, sound=sound, category=category, + url_args=url_args, custom=extra, thread_id=thread_id, + content_available=content_available, mutable_content=mutable_content) def _apns_send( - registration_id, alert, batch=False, application_id=None, creds=None, **kwargs -): + registration_id: Union[str, List[str]], + alert: Optional[str] = None, + batch: bool = False, + application_id: Optional[str] = None, + creds: Optional[apns2_credentials.Credentials] = None, + **kwargs: Any +) -> Optional[Dict[str, str]]: client = _apns_create_socket(creds=creds, application_id=application_id) - notification_kwargs = {} + notification_kwargs: Dict[str, Any] = {} # if expiration isn"t specified use 1 month from now notification_kwargs["expiration"] = kwargs.pop("expiration", None) @@ -97,7 +114,13 @@ def _apns_send( ) -def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs): +def apns_send_message( + registration_id: str, + alert: Optional[str] = None, + application_id: Optional[str] = None, + creds: Optional[apns2_credentials.Credentials] = None, + **kwargs: Any +) -> None: """ Sends an APNS notification to a single registration_id. This will send the notification as form data. @@ -122,8 +145,12 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, * def apns_send_bulk_message( - registration_ids, alert, application_id=None, creds=None, **kwargs -): + registration_ids: List[str], + alert: Optional[str] = None, + application_id: Optional[str] = None, + creds: Optional[apns2_credentials.Credentials] = None, + **kwargs: Any +) -> Optional[Dict[str, str]]: """ Sends an APNS notification to one or more registration_ids. The registration_ids argument needs to be a list. diff --git a/push_notifications/apns_async.py b/push_notifications/apns_async.py index bf6f3b29..811b0f90 100644 --- a/push_notifications/apns_async.py +++ b/push_notifications/apns_async.py @@ -2,7 +2,7 @@ import time from dataclasses import asdict, dataclass -from typing import Awaitable, Callable, Dict, Optional, Union +from typing import Awaitable, Callable, Dict, Optional, Union, Any, Tuple, List from aioapns import APNs, ConnectionError, NotificationRequest from aioapns.common import NotificationResult @@ -16,7 +16,7 @@ class NotSet: - def __init__(self): + def __init__(self) -> None: raise RuntimeError("NotSet cannot be instantiated") @@ -94,16 +94,18 @@ class Alert: An array of strings containing replacement values for variables in your message text. Each %@ character in the string specified by loc-key is replaced by a value from this array. The first item in the array replaces the first instance of the %@ character in the string, the second item replaces the second instance, and so on. """ - sound: Union[str, any] = NotSet + sound:Union[str, Any] = ( + NotSet + ) """ string - The name of a sound file in your app’s main bundle or in the Library/Sounds folder of your app’s container directory. Specify the string “default” to play the system sound. Use this key for regular notifications. For critical alerts, use the sound dictionary instead. For information about how to prepare sounds, see UNNotificationSound. + The name of a sound file in your app's main bundle or in the Library/Sounds folder of your app's container directory. Specify the string "default" to play the system sound. Use this key for regular notifications. For critical alerts, use the sound dictionary instead. For information about how to prepare sounds, see UNNotificationSound. dictionary A dictionary that contains sound information for critical alerts. For regular notifications, use the sound string instead. """ - def asDict(self) -> dict[str, any]: + def asDict(self) -> Dict[str, Any]: python_dict = asdict(self) return { key.replace("_", "-"): value @@ -115,18 +117,18 @@ def asDict(self) -> dict[str, any]: def _create_notification_request_from_args( registration_id: str, alert: Union[str, Alert], - badge: int = None, - sound: str = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - aps_kwargs: dict = {}, - message_kwargs: dict = {}, - notification_request_kwargs: dict = {}, -): + badge: Optional[int] = None, + sound: Optional[str] = None, + extra: Optional[Dict[str, Any]] = None, + expiration: Optional[int] = None, + thread_id: Optional[str] = None, + loc_key: Optional[str] = None, + priority: Optional[int] = None, + collapse_id: Optional[str] = None, + aps_kwargs: Dict[str, Any] = {}, + message_kwargs: Dict[str, Any] = {}, + notification_request_kwargs: Dict[str, Any] = {}, +) -> NotificationRequest: if alert is None: alert = Alert(body="") @@ -148,6 +150,9 @@ def _create_notification_request_from_args( if collapse_id is not None: notification_request_kwargs_out["collapse_key"] = collapse_id + if extra is None: + extra = {} + request = NotificationRequest( device_token=registration_id, message={ @@ -168,9 +173,9 @@ def _create_notification_request_from_args( def _create_client( - creds: Credentials = None, - application_id: str = None, - topic=None, + creds: Optional[Credentials] = None, + application_id: Optional[str] = None, + topic: Optional[str] = None, err_func: ErrFunc = None, ) -> APNs: use_sandbox = get_manager().get_apns_use_sandbox(application_id) @@ -179,8 +184,15 @@ def _create_client( if creds is None: creds = _get_credentials(application_id) + # Convert credentials to dict based on type + creds_dict = {} + if isinstance(creds, TokenCredentials): + creds_dict = asdict(creds) + elif isinstance(creds, CertificateCredentials): + creds_dict = asdict(creds) + client = APNs( - **asdict(creds), + **creds_dict, topic=topic, # Bundle ID use_sandbox=use_sandbox, err_func=err_func, @@ -188,7 +200,7 @@ def _create_client( return client -def _get_credentials(application_id): +def _get_credentials(application_id: Optional[str] = None) -> Credentials: if not get_manager().has_auth_token_creds(application_id): # TLS certificate authentication cert = get_manager().get_apns_certificate(application_id) @@ -209,22 +221,22 @@ def _get_credentials(application_id): def apns_send_message( registration_id: str, alert: Union[str, Alert], - application_id: str = None, - creds: Credentials = None, - topic: str = None, - badge: int = None, - sound: str = None, - content_available: bool = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, + application_id: Optional[str] = None, + creds: Optional[Credentials] = None, + topic: Optional[str] = None, + badge: Optional[int] = None, + sound: Optional[str] = None, + content_available: Optional[bool] = None, + extra: Dict[str, Any] = {}, + expiration: Optional[int] = None, + thread_id: Optional[str] = None, + loc_key: Optional[str] = None, + priority: Optional[int] = None, + collapse_id: Optional[str] = None, mutable_content: bool = False, - category: str = None, - err_func: ErrFunc = None, -): + category: Optional[str] = None, + err_func: Optional[ErrFunc] = None, +) -> Dict[str, List[Union[str, Dict[str, str]]]]: """ Sends an APNS notification to a single registration_id. If sending multiple notifications, it is more efficient to use @@ -240,12 +252,12 @@ def apns_send_message( :param application_id: The application_id to use :param creds: The credentials to use :param mutable_content: If True, the "mutable-content" flag will be set to 1. - This allows the app's Notification Service Extension to modify - the notification before it is displayed. + This allows the app's Notification Service Extension to modify + the notification before it is displayed. :param category: The category identifier for actionable notifications. - This should match a category identifier defined in the app's - Notification Content Extension or UNNotificationCategory configuration. - It allows the app to display custom actions with the notification. + This should match a category identifier defined in the app's + Notification Content Extension or UNNotificationCategory configuration. + It allows the app to display custom actions with the notification. :param content_available: If True the `content-available` flag will be set to 1, allowing the app to be woken up in the background """ results = apns_send_bulk_message( @@ -278,22 +290,22 @@ def apns_send_message( def apns_send_bulk_message( registration_ids: list[str], alert: Union[str, Alert], - application_id: str = None, - creds: Credentials = None, - topic: str = None, - badge: int = None, - sound: str = None, - content_available: bool = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - mutable_content: bool = False, - category: str = None, - err_func: ErrFunc = None, -): + application_id: Optional[str] = None, + creds: Optional[Credentials] = None, + topic: Optional[str] = None, + badge: Optional[int] = None, + sound: Optional[str] = None, + content_available: Optional[bool] = None, + extra: Optional[dict] = None, + expiration: Optional[int] = None, + thread_id: Optional[str] = None, + loc_key: Optional[str] = None, + priority: Optional[int] = None, + collapse_id: Optional[str] = None, + mutable_content: Optional[bool] = False, + category: Optional[str] = None, + err_func: Optional[ErrFunc] = None, +) -> Dict[str, str]: """ Sends an APNS notification to one or more registration_ids. The registration_ids argument needs to be a list. @@ -307,12 +319,12 @@ def apns_send_bulk_message( :param application_id: The application_id to use :param creds: The credentials to use :param mutable_content: If True, the "mutable-content" flag will be set to 1. - This allows the app's Notification Service Extension to modify - the notification before it is displayed. + This allows the app's Notification Service Extension to modify + the notification before it is displayed. :param category: The category identifier for actionable notifications. - This should match a category identifier defined in the app's - Notification Content Extension or UNNotificationCategory configuration. - It allows the app to display custom actions with the notification. + This should match a category identifier defined in the app's + Notification Content Extension or UNNotificationCategory configuration. + It allows the app to display custom actions with the notification. :param content_available: If True the `content-available` flag will be set to 1, allowing the app to be woken up in the background """ try: @@ -375,22 +387,22 @@ def apns_send_bulk_message( async def _send_bulk_request( registration_ids: list[str], alert: Union[str, Alert], - application_id: str = None, - creds: Credentials = None, - topic: str = None, - badge: int = None, - sound: str = None, - content_available: bool = None, - extra: dict = {}, - expiration: int = None, - thread_id: str = None, - loc_key: str = None, - priority: int = None, - collapse_id: str = None, - mutable_content: bool = False, - category: str = None, - err_func: ErrFunc = None, -): + application_id: Optional[str] = None, + creds: Optional[Credentials] = None, + topic: Optional[str] = None, + badge: Optional[int] = None, + sound: Optional[str] = None, + content_available: Optional[bool] = None, + extra: Optional[dict] = None, + expiration: Optional[int] = None, + thread_id: Optional[str] = None, + loc_key: Optional[str] = None, + priority: Optional[int] = None, + collapse_id: Optional[str] = None, + mutable_content: Optional[bool] = False, + category: Optional[str] = None, + err_func: Optional[ErrFunc] = None, +) -> List[Tuple[str, NotificationResult]]: client = _create_client( creds=creds, application_id=application_id, topic=topic, err_func=err_func ) @@ -424,19 +436,26 @@ async def _send_bulk_request( return await asyncio.gather(*send_requests) -async def _send_request(apns, request): +async def _send_request( + apns: APNs, + request: NotificationRequest, +) -> Tuple[str, NotificationResult]: try: res = await asyncio.wait_for(apns.send_notification(request), timeout=1) return request.device_token, res + except asyncio.TimeoutError: return request.device_token, NotificationResult( notification_id=request.notification_id, status="failed", description="TimeoutError", ) - except: + + except Exception as e: + # Catch any other communication errors (network issues, APNs errors) + # Return a failed result with the exception message for easier debugging return request.device_token, NotificationResult( notification_id=request.notification_id, status="failed", - description="CommunicationError", + description=f"CommunicationError: {e}", ) diff --git a/push_notifications/conf/__init__.py b/push_notifications/conf/__init__.py index e11f28fb..b0106241 100644 --- a/push_notifications/conf/__init__.py +++ b/push_notifications/conf/__init__.py @@ -1,20 +1,21 @@ from django.utils.module_loading import import_string +from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from typing import Union, Optional +from .app import AppConfig +from .appmodel import AppModelConfig +from .legacy import LegacyConfig -from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS # noqa: I001 -from .app import AppConfig # noqa: F401 -from .appmodel import AppModelConfig # noqa: F401 -from .legacy import LegacyConfig # noqa: F401 +# ManagerType is an alias for the possible configuration manager classes +# that can be loaded dynamically via SETTINGS["CONFIG"]. +ManagerType = Union[AppConfig, AppModelConfig, LegacyConfig] +manager: Optional[ManagerType] = None -manager = None - -def get_manager(reload=False): +def get_manager(reload: bool = False) -> ManagerType: global manager - - if not manager or reload is True: + if not manager or reload: manager = import_string(SETTINGS["CONFIG"])() - return manager diff --git a/push_notifications/conf/app.py b/push_notifications/conf/app.py index 4e055372..38bde083 100644 --- a/push_notifications/conf/app.py +++ b/push_notifications/conf/app.py @@ -1,12 +1,11 @@ from django.core.exceptions import ImproperlyConfigured -from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS from .base import BaseConfig, check_apns_certificate +from typing import Dict, List, Any, Optional, Tuple -SETTING_MISMATCH = ( - "Application '{application_id}' ({platform}) does not support the setting '{setting}'." -) +SETTING_MISMATCH = "Application '{application_id}' ({platform}) does not support the setting '{setting}'." # code can be "missing" or "invalid" BAD_PLATFORM = ( @@ -14,13 +13,9 @@ "Must be one of: {platforms}." ) -UNKNOWN_PLATFORM = ( - "Unknown Platform: {platform}. Must be one of: {platforms}." -) +UNKNOWN_PLATFORM = "Unknown Platform: {platform}. Must be one of: {platforms}." -MISSING_SETTING = ( - 'PUSH_NOTIFICATIONS_SETTINGS.APPLICATIONS["{application_id}"]["{setting}"] is missing.' -) +MISSING_SETTING = 'PUSH_NOTIFICATIONS_SETTINGS.APPLICATIONS["{application_id}"]["{setting}"] is missing.' PLATFORMS = [ "APNS", @@ -37,9 +32,7 @@ # Settings that an application may have to enable optional features # these settings are stubs for registry support and have no effect on the operation # of the application at this time. -OPTIONAL_SETTINGS = [ - "APPLICATION_GROUP", "APPLICATION_SECRET" -] +OPTIONAL_SETTINGS = ["APPLICATION_GROUP", "APPLICATION_SECRET"] # Since we can have an auth key, combined with a auth key id and team id *or* # a certificate, we make these all optional, and then make sure we have one or @@ -50,14 +43,10 @@ APNS_AUTH_CREDS_REQUIRED = ["AUTH_KEY_PATH", "AUTH_KEY_ID", "TEAM_ID"] APNS_AUTH_CREDS_OPTIONAL = ["CERTIFICATE", "ENCRYPTION_ALGORITHM", "TOKEN_LIFETIME"] -APNS_OPTIONAL_SETTINGS = [ - "USE_SANDBOX", "USE_ALTERNATIVE_PORT", "TOPIC" -] +APNS_OPTIONAL_SETTINGS = ["USE_SANDBOX", "USE_ALTERNATIVE_PORT", "TOPIC"] FCM_REQUIRED_SETTINGS = [] -FCM_OPTIONAL_SETTINGS = [ - "MAX_RECIPIENTS", "FIREBASE_APP" -] +FCM_OPTIONAL_SETTINGS = ["MAX_RECIPIENTS", "FIREBASE_APP"] WNS_REQUIRED_SETTINGS = ["PACKAGE_SECURITY_ID", "SECRET_KEY"] WNS_OPTIONAL_SETTINGS = ["WNS_ACCESS_URL"] @@ -71,7 +60,7 @@ class AppConfig(BaseConfig): Supports any number of push notification enabled applications. """ - def __init__(self, settings=None): + def __init__(self, settings: Optional[Dict[str, Any]] = None) -> None: # supports overriding the settings to be loaded. Will load from ..settings by default. self._settings = settings or SETTINGS @@ -81,14 +70,16 @@ def __init__(self, settings=None): # validate application configurations self._validate_applications(self._settings["APPLICATIONS"]) - def _validate_applications(self, apps): + def _validate_applications(self, apps: Dict[str, Dict[str, Any]]) -> None: """Validate the application collection""" for application_id, application_config in apps.items(): self._validate_config(application_id, application_config) application_config["APPLICATION_ID"] = application_id - def _validate_config(self, application_id, application_config): + def _validate_config( + self, application_id: str, application_config: Dict[str, Any] + ) -> None: platform = application_config.get("PLATFORM", None) # platform is not present @@ -97,7 +88,7 @@ def _validate_config(self, application_id, application_config): BAD_PLATFORM.format( application_id=application_id, code="required", - platforms=", ".join(PLATFORMS) + platforms=", ".join(PLATFORMS), ) ) @@ -107,7 +98,7 @@ def _validate_config(self, application_id, application_config): BAD_PLATFORM.format( application_id=application_id, code="invalid", - platforms=", ".join(PLATFORMS) + platforms=", ".join(PLATFORMS), ) ) @@ -118,23 +109,26 @@ def _validate_config(self, application_id, application_config): else: raise ImproperlyConfigured( UNKNOWN_PLATFORM.format( - platform=platform, - platforms=", ".join(PLATFORMS) + platform=platform, platforms=", ".join(PLATFORMS) ) ) - def _validate_apns_config(self, application_id, application_config): - allowed = REQUIRED_SETTINGS + OPTIONAL_SETTINGS + \ - APNS_AUTH_CREDS_REQUIRED + \ - APNS_AUTH_CREDS_OPTIONAL + \ - APNS_OPTIONAL_SETTINGS + def _validate_apns_config( + self, application_id: str, application_config: Dict[str, Any] + ) -> None: + allowed = ( + REQUIRED_SETTINGS + + OPTIONAL_SETTINGS + + APNS_AUTH_CREDS_REQUIRED + + APNS_AUTH_CREDS_OPTIONAL + + APNS_OPTIONAL_SETTINGS + ) self._validate_allowed_settings(application_id, application_config, allowed) # We have two sets of settings, certificate and JWT auth key. # Auth Key requires 3 values, so if that is set, that will take # precedence. If None are set, we will throw an error. - has_cert_creds = APNS_SETTINGS_CERT_CREDS in \ - application_config.keys() + has_cert_creds = APNS_SETTINGS_CERT_CREDS in application_config.keys() self.has_token_creds = True for token_setting in APNS_AUTH_CREDS_REQUIRED: if token_setting not in application_config.keys(): @@ -145,17 +139,23 @@ def _validate_apns_config(self, application_id, application_config): raise ImproperlyConfigured( MISSING_SETTING.format( application_id=application_id, - setting=(APNS_SETTINGS_CERT_CREDS, APNS_AUTH_CREDS_REQUIRED))) + setting=(APNS_SETTINGS_CERT_CREDS, APNS_AUTH_CREDS_REQUIRED), + ) + ) cert_path = None if has_cert_creds: cert_path = "CERTIFICATE" elif self.has_token_creds: cert_path = "AUTH_KEY_PATH" - allowed_tokens = APNS_AUTH_CREDS_REQUIRED + \ - APNS_AUTH_CREDS_OPTIONAL + \ - APNS_OPTIONAL_SETTINGS + \ - REQUIRED_SETTINGS - self._validate_allowed_settings(application_id, application_config, allowed_tokens) + allowed_tokens = ( + APNS_AUTH_CREDS_REQUIRED + + APNS_AUTH_CREDS_OPTIONAL + + APNS_OPTIONAL_SETTINGS + + REQUIRED_SETTINGS + ) + self._validate_allowed_settings( + application_id, application_config, allowed_tokens + ) self._validate_required_settings( application_id, application_config, APNS_AUTH_CREDS_REQUIRED ) @@ -166,7 +166,7 @@ def _validate_apns_config(self, application_id, application_config): application_config.setdefault("USE_ALTERNATIVE_PORT", False) application_config.setdefault("TOPIC", None) - def _validate_apns_certificate(self, certfile): + def _validate_apns_certificate(self, certfile: str) -> None: """Validate the APNS certificate at startup.""" try: @@ -175,12 +175,19 @@ def _validate_apns_certificate(self, certfile): check_apns_certificate(content) except Exception as e: raise ImproperlyConfigured( - "The APNS certificate file at {!r} is not readable: {}".format(certfile, e) + "The APNS certificate file at {!r} is not readable: {}".format( + certfile, e + ) ) - def _validate_fcm_config(self, application_id, application_config): + def _validate_fcm_config( + self, application_id: str, application_config: Dict[str, Any] + ) -> None: allowed = ( - REQUIRED_SETTINGS + OPTIONAL_SETTINGS + FCM_REQUIRED_SETTINGS + FCM_OPTIONAL_SETTINGS + REQUIRED_SETTINGS + + OPTIONAL_SETTINGS + + FCM_REQUIRED_SETTINGS + + FCM_OPTIONAL_SETTINGS ) self._validate_allowed_settings(application_id, application_config, allowed) @@ -191,9 +198,14 @@ def _validate_fcm_config(self, application_id, application_config): application_config.setdefault("FIREBASE_APP", None) application_config.setdefault("MAX_RECIPIENTS", 1000) - def _validate_wns_config(self, application_id, application_config): + def _validate_wns_config( + self, application_id: str, application_config: Dict[str, Any] + ) -> None: allowed = ( - REQUIRED_SETTINGS + OPTIONAL_SETTINGS + WNS_REQUIRED_SETTINGS + WNS_OPTIONAL_SETTINGS + REQUIRED_SETTINGS + + OPTIONAL_SETTINGS + + WNS_REQUIRED_SETTINGS + + WNS_OPTIONAL_SETTINGS ) self._validate_allowed_settings(application_id, application_config, allowed) @@ -201,27 +213,42 @@ def _validate_wns_config(self, application_id, application_config): application_id, application_config, WNS_REQUIRED_SETTINGS ) - application_config.setdefault("WNS_ACCESS_URL", "https://login.live.com/accesstoken.srf") + application_config.setdefault( + "WNS_ACCESS_URL", "https://login.live.com/accesstoken.srf" + ) - def _validate_wp_config(self, application_id, application_config): + def _validate_wp_config( + self, application_id: str, application_config: Dict[str, Any] + ) -> None: allowed = ( - REQUIRED_SETTINGS + OPTIONAL_SETTINGS + WP_REQUIRED_SETTINGS + WP_OPTIONAL_SETTINGS + REQUIRED_SETTINGS + + OPTIONAL_SETTINGS + + WP_REQUIRED_SETTINGS + + WP_OPTIONAL_SETTINGS ) self._validate_allowed_settings(application_id, application_config, allowed) self._validate_required_settings( application_id, application_config, WP_REQUIRED_SETTINGS ) - application_config.setdefault("POST_URL", { - "CHROME": "https://fcm.googleapis.com/fcm/send", - "OPERA": "https://fcm.googleapis.com/fcm/send", - "EDGE": "https://wns2-par02p.notify.windows.com/w", - "FIREFOX": "https://updates.push.services.mozilla.com/wpush/v2", - "SAFARI": "https://web.push.apple.com", - }) + application_config.setdefault( + "POST_URL", + { + "CHROME": "https://fcm.googleapis.com/fcm/send", + "OPERA": "https://fcm.googleapis.com/fcm/send", + "EDGE": "https://wns2-par02p.notify.windows.com/w", + "FIREFOX": "https://updates.push.services.mozilla.com/wpush/v2", + "SAFARI": "https://web.push.apple.com", + }, + ) application_config.setdefault("ERROR_TIMEOUT", 1) - def _validate_allowed_settings(self, application_id, application_config, allowed_settings): + def _validate_allowed_settings( + self, + application_id: str, + application_config: Dict[str, Any], + allowed_settings: List[str], + ) -> None: """Confirm only allowed settings are present.""" for setting_key in application_config.keys(): @@ -233,9 +260,12 @@ def _validate_allowed_settings(self, application_id, application_config, allowed ) def _validate_required_settings( - self, application_id, application_config, required_settings, - should_throw=True - ): + self, + application_id: str, + application_config: Dict[str, Any], + required_settings: List[str], + should_throw: bool = True, + ) -> bool: """All required keys must be present""" for setting_key in required_settings: @@ -250,7 +280,9 @@ def _validate_required_settings( return False return True - def _get_application_settings(self, application_id, platform, settings_key): + def _get_application_settings( + self, application_id: Optional[str], platform: str, settings_key: str + ) -> Any: """ Walks through PUSH_NOTIFICATIONS_SETTINGS to find the correct setting value or raises ImproperlyConfigured. @@ -259,14 +291,18 @@ def _get_application_settings(self, application_id, platform, settings_key): if not application_id: conf_cls = "push_notifications.conf.AppConfig" raise ImproperlyConfigured( - "{} requires the application_id be specified at all times.".format(conf_cls) + "{} requires the application_id be specified at all times.".format( + conf_cls + ) ) # verify that the application config exists app_config = self._settings.get("APPLICATIONS").get(application_id, None) if app_config is None: raise ImproperlyConfigured( - "No application configured with application_id: {}.".format(application_id) + "No application configured with application_id: {}.".format( + application_id + ) ) # fetch a setting for the incorrect type of platform @@ -275,7 +311,7 @@ def _get_application_settings(self, application_id, platform, settings_key): SETTING_MISMATCH.format( application_id=application_id, platform=app_config.get("PLATFORM"), - setting=settings_key + setting=settings_key, ) ) @@ -289,16 +325,16 @@ def _get_application_settings(self, application_id, platform, settings_key): return app_config.get(settings_key) - def get_firebase_app(self, application_id=None): + def get_firebase_app(self, application_id: Optional[str] = None) -> Any: return self._get_application_settings(application_id, "FCM", "FIREBASE_APP") - def has_auth_token_creds(self, application_id=None): + def has_auth_token_creds(self, application_id: Optional[str] = None) -> bool: return self.has_token_creds - def get_max_recipients(self, application_id=None): + def get_max_recipients(self, application_id: Optional[str] = None) -> int: return self._get_application_settings(application_id, "FCM", "MAX_RECIPIENTS") - def get_apns_certificate(self, application_id=None): + def get_apns_certificate(self, application_id: Optional[str] = None) -> str: r = self._get_application_settings(application_id, "APNS", "CERTIFICATE") if not isinstance(r, str): # probably the (Django) file, and file path should be got @@ -313,44 +349,53 @@ def get_apns_certificate(self, application_id=None): ) return r - def get_apns_auth_creds(self, application_id=None): - return \ - (self._get_apns_auth_key_path(application_id), + def get_apns_auth_creds( + self, application_id: Optional[str] = None + ) -> Tuple[str, str, str]: + return ( + self._get_apns_auth_key_path(application_id), self._get_apns_auth_key_id(application_id), - self._get_apns_team_id(application_id)) + self._get_apns_team_id(application_id), + ) - def _get_apns_auth_key_path(self, application_id=None): + def _get_apns_auth_key_path(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS", "AUTH_KEY_PATH") - def _get_apns_auth_key_id(self, application_id=None): + def _get_apns_auth_key_id(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS", "AUTH_KEY_ID") - def _get_apns_team_id(self, application_id=None): + def _get_apns_team_id(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS", "TEAM_ID") - def get_apns_use_sandbox(self, application_id=None): + def get_apns_use_sandbox(self, application_id: Optional[str] = None) -> bool: return self._get_application_settings(application_id, "APNS", "USE_SANDBOX") - def get_apns_use_alternative_port(self, application_id=None): - return self._get_application_settings(application_id, "APNS", "USE_ALTERNATIVE_PORT") + def get_apns_use_alternative_port( + self, application_id: Optional[str] = None + ) -> bool: + return self._get_application_settings( + application_id, "APNS", "USE_ALTERNATIVE_PORT" + ) - def get_apns_topic(self, application_id=None): + def get_apns_topic(self, application_id: Optional[str] = None) -> Optional[str]: return self._get_application_settings(application_id, "APNS", "TOPIC") - def get_wns_package_security_id(self, application_id=None): - return self._get_application_settings(application_id, "WNS", "PACKAGE_SECURITY_ID") + def get_wns_package_security_id(self, application_id: Optional[str] = None) -> str: + return self._get_application_settings( + application_id, "WNS", "PACKAGE_SECURITY_ID" + ) - def get_wns_secret_key(self, application_id=None): + def get_wns_secret_key(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "WNS", "SECRET_KEY") - def get_wp_post_url(self, application_id, browser): + def get_wp_post_url(self, application_id: str, browser: str) -> str: return self._get_application_settings(application_id, "WP", "POST_URL")[browser] - def get_wp_private_key(self, application_id=None): + def get_wp_private_key(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "WP", "PRIVATE_KEY") - def get_wp_claims(self, application_id=None): + def get_wp_claims(self, application_id: Optional[str] = None) -> Dict[str, Any]: return self._get_application_settings(application_id, "WP", "CLAIMS") - def get_wp_error_timeout(self, application_id=None): + def get_wp_error_timeout(self, application_id: Optional[str] = None) -> int: return self._get_application_settings(application_id, "WP", "ERROR_TIMEOUT") diff --git a/push_notifications/conf/base.py b/push_notifications/conf/base.py index bb1b633a..3e1c3b66 100644 --- a/push_notifications/conf/base.py +++ b/push_notifications/conf/base.py @@ -1,36 +1,37 @@ from django.core.exceptions import ImproperlyConfigured +from typing import Optional, Any, Collection class BaseConfig: - def get_firebase_app(self, application_id=None): + def get_firebase_app(self, application_id: Optional[str] = None) -> Any: raise NotImplementedError - def has_auth_token_creds(self, application_id=None): + def has_auth_token_creds(self, application_id: Optional[str] = None) -> bool: raise NotImplementedError - def get_apns_certificate(self, application_id=None): + def get_apns_certificate(self, application_id: Optional[str] = None) -> str: raise NotImplementedError - def get_apns_auth_creds(self, application_id=None): + def get_apns_auth_creds(self, application_id: Optional[str] = None) -> Any: raise NotImplementedError - def get_apns_use_sandbox(self, application_id=None): + def get_apns_use_sandbox(self, application_id: Optional[str] = None) -> bool: raise NotImplementedError - def get_apns_use_alternative_port(self, application_id=None): + def get_apns_use_alternative_port(self, application_id: Optional[str] = None) -> bool: raise NotImplementedError - def get_wns_package_security_id(self, application_id=None): + def get_wns_package_security_id(self, application_id: Optional[str] = None) -> str: raise NotImplementedError - def get_wns_secret_key(self, application_id=None): + def get_wns_secret_key(self, application_id: Optional[str] = None) -> str: raise NotImplementedError - def get_max_recipients(self, application_id=None): + def get_max_recipients(self, application_id: Optional[str] = None) -> int: raise NotImplementedError - def get_applications(self): + def get_applications(self) -> Collection[str]: """Returns a collection containing the configured applications.""" raise NotImplementedError @@ -38,7 +39,7 @@ def get_applications(self): # This works for both the certificate and the auth key (since that's just # a certificate). -def check_apns_certificate(ss): +def check_apns_certificate(ss: str) -> None: mode = "start" for s in ss.split("\n"): if mode == "start": diff --git a/push_notifications/conf/legacy.py b/push_notifications/conf/legacy.py index 0118b5a7..598e74b4 100644 --- a/push_notifications/conf/legacy.py +++ b/push_notifications/conf/legacy.py @@ -1,6 +1,8 @@ from django.core.exceptions import ImproperlyConfigured -from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS +from typing import Any, Optional, Tuple, Dict + from .base import BaseConfig @@ -16,7 +18,7 @@ class empty: class LegacyConfig(BaseConfig): msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" - def _get_application_settings(self, application_id, settings_key, error_message): + def _get_application_settings(self, application_id: Optional[str], settings_key: str, error_message: str) -> Any: """Legacy behaviour""" if not application_id: @@ -31,21 +33,21 @@ def _get_application_settings(self, application_id, settings_key, error_message) ) raise ImproperlyConfigured(msg) - def get_firebase_app(self, application_id=None): + def get_firebase_app(self, application_id: Optional[str] = None) -> Any: key = "FIREBASE_APP" msg = ( 'Set PUSH_NOTIFICATIONS_SETTINGS["{}"] to send messages through FCM.'.format(key) ) return self._get_application_settings(application_id, key, msg) - def get_max_recipients(self, application_id=None): + def get_max_recipients(self, application_id: Optional[str] = None) -> int: key = "FCM_MAX_RECIPIENTS" msg = ( 'Set PUSH_NOTIFICATIONS_SETTINGS["{}"] to send messages through FCM.'.format(key) ) return self._get_application_settings(application_id, key, msg) - def has_auth_token_creds(self, application_id=None): + def has_auth_token_creds(self, application_id: Optional[str] = None) -> bool: try: self._get_apns_auth_key(application_id) self._get_apns_auth_key_id(application_id) @@ -55,7 +57,7 @@ def has_auth_token_creds(self, application_id=None): return True - def get_apns_certificate(self, application_id=None): + def get_apns_certificate(self, application_id: Optional[str] = None) -> str: r = self._get_application_settings( application_id, "APNS_CERTIFICATE", "You need to setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" @@ -74,61 +76,61 @@ def get_apns_certificate(self, application_id=None): raise ImproperlyConfigured(msg) return r - def get_apns_auth_creds(self, application_id=None): + def get_apns_auth_creds(self, application_id: Optional[str] = None) -> Tuple[str, str, str]: return ( self._get_apns_auth_key(application_id), self._get_apns_auth_key_id(application_id), self._get_apns_team_id(application_id)) - def _get_apns_auth_key(self, application_id=None): + def _get_apns_auth_key(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_AUTH_KEY_PATH", self.msg) - def _get_apns_team_id(self, application_id=None): + def _get_apns_team_id(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_TEAM_ID", self.msg) - def _get_apns_auth_key_id(self, application_id=None): + def _get_apns_auth_key_id(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_AUTH_KEY_ID", self.msg) - def get_apns_use_sandbox(self, application_id=None): + def get_apns_use_sandbox(self, application_id: Optional[str] = None) -> bool: return self._get_application_settings(application_id, "APNS_USE_SANDBOX", self.msg) - def get_apns_use_alternative_port(self, application_id=None): + def get_apns_use_alternative_port(self, application_id: Optional[str] = None) -> bool: return self._get_application_settings(application_id, "APNS_USE_ALTERNATIVE_PORT", self.msg) - def get_apns_topic(self, application_id=None): + def get_apns_topic(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_TOPIC", self.msg) - def get_apns_host(self, application_id=None): + def get_apns_host(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_HOST", self.msg) - def get_apns_port(self, application_id=None): + def get_apns_port(self, application_id: Optional[str] = None) -> int: return self._get_application_settings(application_id, "APNS_PORT", self.msg) - def get_apns_feedback_host(self, application_id=None): + def get_apns_feedback_host(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "APNS_FEEDBACK_HOST", self.msg) - def get_apns_feedback_port(self, application_id=None): + def get_apns_feedback_port(self, application_id: Optional[str] = None) -> int: return self._get_application_settings(application_id, "APNS_FEEDBACK_PORT", self.msg) - def get_wns_package_security_id(self, application_id=None): + def get_wns_package_security_id(self, application_id: Optional[str] = None) -> str: return self._get_application_settings(application_id, "WNS_PACKAGE_SECURITY_ID", self.msg) - def get_wns_secret_key(self, application_id=None): + def get_wns_secret_key(self, application_id: Optional[str] = None) -> str: msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" return self._get_application_settings(application_id, "WNS_SECRET_KEY", msg) - def get_wp_post_url(self, application_id, browser): + def get_wp_post_url(self, application_id: str, browser: str) -> str: msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" return self._get_application_settings(application_id, "WP_POST_URL", msg)[browser] - def get_wp_private_key(self, application_id=None): + def get_wp_private_key(self, application_id: Optional[str] = None) -> str: msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" return self._get_application_settings(application_id, "WP_PRIVATE_KEY", msg) - def get_wp_claims(self, application_id=None): + def get_wp_claims(self, application_id: Optional[str] = None) -> Dict[str, Any] : msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to send messages" return self._get_application_settings(application_id, "WP_CLAIMS", msg) - def get_wp_error_timeout(self, application_id=None): + def get_wp_error_timeout(self, application_id: Optional[str] = None) -> int: msg = "Setup PUSH_NOTIFICATIONS_SETTINGS properly to set a timeout" return self._get_application_settings(application_id, "WP_ERROR_TIMEOUT", msg) diff --git a/push_notifications/exceptions.py b/push_notifications/exceptions.py index 33fb4659..966eb990 100644 --- a/push_notifications/exceptions.py +++ b/push_notifications/exceptions.py @@ -1,9 +1,7 @@ class NotificationError(Exception): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(message) self.message = message - pass - # APNS class APNSError(NotificationError): @@ -15,7 +13,7 @@ class APNSUnsupportedPriority(APNSError): class APNSServerError(APNSError): - def __init__(self, status): + def __init__(self, status: str) -> None: super().__init__(status) self.status = status diff --git a/push_notifications/fields.py b/push_notifications/fields.py index ef1425d0..22fdae96 100644 --- a/push_notifications/fields.py +++ b/push_notifications/fields.py @@ -1,6 +1,6 @@ import re import struct - +from typing import Optional, Any from django import forms from django.core.validators import MaxValueValidator, MinValueValidator, RegexValidator from django.db import connection, models @@ -10,7 +10,7 @@ __all__ = ["HexadecimalField", "HexIntegerField"] UNSIGNED_64BIT_INT_MIN_VALUE = 0 -UNSIGNED_64BIT_INT_MAX_VALUE = 2 ** 64 - 1 +UNSIGNED_64BIT_INT_MAX_VALUE = 2**64 - 1 hex_re = re.compile(r"^(0x)?([0-9a-f])+$", re.I) @@ -20,23 +20,23 @@ ] -def _using_signed_storage(): +def _using_signed_storage() -> bool: return connection.vendor in signed_integer_vendors -def _signed_to_unsigned_integer(value): +def _signed_to_unsigned_integer(value: int) -> int: return struct.unpack("Q", struct.pack("q", value))[0] -def _unsigned_to_signed_integer(value): +def _unsigned_to_signed_integer(value: int) -> int: return struct.unpack("q", struct.pack("Q", value))[0] -def _hex_string_to_unsigned_integer(value): +def _hex_string_to_unsigned_integer(value: str) -> int: return int(value, 16) -def _unsigned_integer_to_hex_string(value): +def _unsigned_integer_to_hex_string(value: int) -> str: return hex(value).rstrip("L") @@ -44,16 +44,20 @@ class HexadecimalField(forms.CharField): """ A form field that accepts only hexadecimal numbers """ - def __init__(self, *args, **kwargs): + + def __init__(self, *args: Any, **kwargs: Any) -> None: self.default_validators = [ RegexValidator(hex_re, _("Enter a valid hexadecimal number"), "invalid") ] super().__init__(*args, **kwargs) - def prepare_value(self, value): + def prepare_value(self, value: Optional[Any]) -> str: # converts bigint from db to hex before it is displayed in admin - if value and not isinstance(value, str) \ - and connection.vendor in ("mysql", "sqlite"): + if ( + value + and not isinstance(value, str) + and connection.vendor in ("mysql", "sqlite") + ): value = _unsigned_integer_to_hex_string(value) return super(forms.CharField, self).prepare_value(value) @@ -73,10 +77,10 @@ class HexIntegerField(models.BigIntegerField): validators = [ MinValueValidator(UNSIGNED_64BIT_INT_MIN_VALUE), - MaxValueValidator(UNSIGNED_64BIT_INT_MAX_VALUE) + MaxValueValidator(UNSIGNED_64BIT_INT_MAX_VALUE), ] - def db_type(self, connection): + def db_type(self, connection: Any) -> str: if "mysql" == connection.vendor: return "bigint unsigned" elif "sqlite" == connection.vendor: @@ -84,8 +88,8 @@ def db_type(self, connection): else: return super().db_type(connection=connection) - def get_prep_value(self, value): - """ Return the integer value to be stored from the hex string """ + def get_prep_value(self, value: Optional[Any]) -> Optional[int]: + """Return the integer value to be stored from the hex string""" if value is None or value == "": return None if isinstance(value, str): @@ -94,29 +98,31 @@ def get_prep_value(self, value): value = _unsigned_to_signed_integer(value) return value - def from_db_value(self, value, *args): - """ Return an unsigned int representation from all db backends """ + def from_db_value(self, value: Optional[int], *args: Any) -> Optional[int]: + """Return an unsigned int representation from all db backends""" if value is None: return value if _using_signed_storage(): value = _signed_to_unsigned_integer(value) return value - def to_python(self, value): - """ Return a str representation of the hexadecimal """ + def to_python(self, value: Optional[Any]) -> Optional[str]: + """Return a str representation of the hexadecimal""" if isinstance(value, str): return value if value is None: return value return _unsigned_integer_to_hex_string(value) - def formfield(self, **kwargs): + def formfield(self, **kwargs: Any) -> HexadecimalField: defaults = {"form_class": HexadecimalField} defaults.update(kwargs) # yes, that super call is right return super(models.IntegerField, self).formfield(**defaults) - def run_validators(self, value): + def run_validators(self, value: Any) -> None: # make sure validation is performed on integer value not string value + # removed `return` since run_validators() never returns anything, + # it only raises ValidationError on failure value = _hex_string_to_unsigned_integer(value) - return super(models.BigIntegerField, self).run_validators(value) + super(models.BigIntegerField, self).run_validators(value) diff --git a/push_notifications/gcm.py b/push_notifications/gcm.py index 923322e9..3b8478a0 100644 --- a/push_notifications/gcm.py +++ b/push_notifications/gcm.py @@ -6,7 +6,7 @@ """ from copy import copy -from typing import List, Union +from typing import List, Union, Dict, Any, Generator, Optional from firebase_admin import messaging from firebase_admin.exceptions import FirebaseError, InvalidArgumentError @@ -22,7 +22,7 @@ ] -def dict_to_fcm_message(data: dict, dry_run=False, **kwargs) -> messaging.Message: +def dict_to_fcm_message(data: Dict[str, Any], dry_run: bool = False, **kwargs: Any) -> messaging.Message: """ Constructs a messaging.Message from the old dictionary. @@ -87,12 +87,12 @@ def dict_to_fcm_message(data: dict, dry_run=False, **kwargs) -> messaging.Messag return message -def _chunks(l, n): +def _chunks(lst: List[Any], n: int) -> Generator[List[Any], None, None]: """ - Yield successive chunks from list \a l with a maximum size \a n + Yield successive chunks from list \a lst with a maximum size \a n """ - for i in range(0, len(l), n): - yield l[i:i + n] + for i in range(0, len(lst), n): + yield lst[i:i + n] # Error codes: https://firebase.google.com/docs/reference/fcm/rest/v1/ErrorCode @@ -108,12 +108,11 @@ def _chunks(l, n): def _validate_exception_for_deactivation(exc: Union[FirebaseError]) -> bool: if not exc: return False - exc_type = type(exc) - if exc_type == str: + if isinstance(exc, str): return exc in fcm_error_list_str return ( - exc_type == InvalidArgumentError and exc.cause == "Invalid registration" - ) or (exc_type in fcm_error_list) + isinstance(exc, InvalidArgumentError) and exc.cause == "Invalid registration" + ) or (type(exc) in fcm_error_list) def _deactivate_devices_with_error_results( @@ -139,18 +138,18 @@ def _deactivate_devices_with_error_results( return deactivated_ids -def _prepare_message(message: messaging.Message, token: str): +def _prepare_message(message: messaging.Message, token: str) -> messaging.Message: message.token = token return copy(message) def send_message( - registration_ids, + registration_ids: Union[List[str], str, None], message: messaging.Message, - application_id=None, - dry_run=False, - **kwargs -): + application_id: Optional[str] = None, + dry_run: bool = False, + **kwargs: Any +) -> Optional[messaging.BatchResponse]: """ Sends an FCM notification to one or more registration_ids. The registration_ids can be a list or a single string. diff --git a/push_notifications/models.py b/push_notifications/models.py index b9194d78..6549516e 100644 --- a/push_notifications/models.py +++ b/push_notifications/models.py @@ -1,6 +1,6 @@ from django.db import models from django.utils.translation import gettext_lazy as _ - +from typing import List, Optional, Dict, Any from .fields import HexIntegerField from .settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS @@ -43,7 +43,7 @@ class Device(models.Model): class Meta: abstract = True - def __str__(self): + def __str__(self) -> str: return ( self.name or str(self.device_id or "") or @@ -52,12 +52,12 @@ def __str__(self): class GCMDeviceManager(models.Manager): - def get_queryset(self): + def get_queryset(self) -> "GCMDeviceQuerySet": return GCMDeviceQuerySet(self.model) class GCMDeviceQuerySet(models.query.QuerySet): - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> Any: if self.exists(): from .gcm import dict_to_fcm_message, messaging from .gcm import send_message as fcm_send_message @@ -69,9 +69,9 @@ def send_message(self, message, **kwargs): # transform legacy data to new message object message = dict_to_fcm_message(data, **kwargs) - app_ids = self.filter(active=True).order_by( + app_ids = list(self.filter(active=True).order_by( "application_id" - ).values_list("application_id", flat=True).distinct() + ).values_list("application_id", flat=True).distinct()) responses = [] for app_id in app_ids: @@ -107,7 +107,7 @@ class GCMDevice(Device): class Meta: verbose_name = _("FCM device") - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> Optional[Any]: from .gcm import dict_to_fcm_message, messaging from .gcm import send_message as fcm_send_message @@ -116,7 +116,7 @@ def send_message(self, message, **kwargs): return if not isinstance(message, messaging.Message): - data = kwargs.pop("extra", {}) + data: Dict[str, Any] = kwargs.pop("extra", {}) if message is not None: data["message"] = message # transform legacy data to new message object @@ -127,14 +127,13 @@ def send_message(self, message, **kwargs): application_id=self.application_id, **kwargs ) - class APNSDeviceManager(models.Manager): - def get_queryset(self): + def get_queryset(self) -> "APNSDeviceQuerySet": return APNSDeviceQuerySet(self.model) class APNSDeviceQuerySet(models.query.QuerySet): - def send_message(self, message, creds=None, **kwargs): + def send_message(self, message: Any, creds: Optional[Any] = None, **kwargs: Any) -> List[Any]: if self.exists(): try: from .apns_async import apns_send_bulk_message @@ -175,7 +174,7 @@ class APNSDevice(Device): class Meta: verbose_name = _("APNS device") - def send_message(self, message, creds=None, **kwargs): + def send_message(self, message: Any, creds: Optional[Any] = None, **kwargs: Any) -> Any: try: from .apns_async import apns_send_message except ImportError: @@ -190,23 +189,23 @@ def send_message(self, message, creds=None, **kwargs): class WNSDeviceManager(models.Manager): - def get_queryset(self): + def get_queryset(self) -> "WNSDeviceQuerySet": return WNSDeviceQuerySet(self.model) class WNSDeviceQuerySet(models.query.QuerySet): - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> List[Any]: from .wns import wns_send_bulk_message - app_ids = self.filter(active=True).order_by("application_id").values_list( + app_ids = list(self.filter(active=True).order_by("application_id").values_list( "application_id", flat=True - ).distinct() + ).distinct()) res = [] for app_id in app_ids: - reg_ids = self.filter(active=True, application_id=app_id).values_list( + reg_ids = list(self.filter(active=True, application_id=app_id).values_list( "registration_id", flat=True - ) - r = wns_send_bulk_message(uri_list=list(reg_ids), message=message, **kwargs) + )) + r = wns_send_bulk_message(uri_list=reg_ids, message=message, **kwargs) if hasattr(r, "keys"): res += [r] elif hasattr(r, "__getitem__"): @@ -227,7 +226,7 @@ class WNSDevice(Device): class Meta: verbose_name = _("WNS device") - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> str: from .wns import wns_send_message return wns_send_message( @@ -237,14 +236,14 @@ def send_message(self, message, **kwargs): class WebPushDeviceManager(models.Manager): - def get_queryset(self): + def get_queryset(self) -> "WebPushDeviceQuerySet": return WebPushDeviceQuerySet(self.model) class WebPushDeviceQuerySet(models.query.QuerySet): - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> List[Any]: devices = self.filter(active=True).order_by("application_id").distinct() - res = [] + res: List[Any] = [] for device in devices: res.append(device.send_message(message)) @@ -270,10 +269,10 @@ class Meta: verbose_name = _("WebPush device") @property - def device_id(self): + def device_id(self) -> None: return None - def send_message(self, message, **kwargs): + def send_message(self, message: Any, **kwargs: Any) -> Any: from .webpush import webpush_send_message return webpush_send_message(self, message, **kwargs) diff --git a/push_notifications/py.typed b/push_notifications/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/push_notifications/webpush.py b/push_notifications/webpush.py index 0a938823..7b57edf5 100644 --- a/push_notifications/webpush.py +++ b/push_notifications/webpush.py @@ -1,16 +1,22 @@ import warnings from pywebpush import WebPushException, webpush - +from typing import Dict, Any from .conf import get_manager from .exceptions import WebPushError -def get_subscription_info(application_id, uri, browser, auth, p256dh): +def get_subscription_info( + application_id: str, uri: str, browser: str, auth: str, p256dh: str +) -> Dict[str, Any]: if uri.startswith("https://"): endpoint = uri else: - url = get_manager().get_wp_post_url(application_id, browser) + manager = get_manager() + if hasattr(manager, "get_wp_post_url"): + url = manager.get_wp_post_url(application_id, browser) + else: + raise AttributeError("Manager does not support get_wp_post_url method") endpoint = "{}/{}".format(url, uri) warnings.warn( "registration_id should be the full endpoint returned from pushManager.subscribe", @@ -22,23 +28,41 @@ def get_subscription_info(application_id, uri, browser, auth, p256dh): "keys": { "auth": auth, "p256dh": p256dh, - } + }, } -def webpush_send_message(device, message, **kwargs): +def webpush_send_message(device: Any, message: str, **kwargs: Any) -> Dict[str, Any]: subscription_info = get_subscription_info( - device.application_id, device.registration_id, - device.browser, device.auth, device.p256dh) + device.application_id, + device.registration_id, + device.browser, + device.auth, + device.p256dh, + ) try: results = {"results": [{"original_registration_id": device.registration_id}]} + manager = get_manager() + + vapid_private_key = None + if hasattr(manager, "get_wp_private_key"): + vapid_private_key = manager.get_wp_private_key(device.application_id) + + vapid_claims = None + if hasattr(manager, "get_wp_claims"): + vapid_claims = manager.get_wp_claims(device.application_id).copy() + + timeout = None + if hasattr(manager, "get_wp_error_timeout"): + timeout = manager.get_wp_error_timeout(device.application_id) + response = webpush( subscription_info=subscription_info, data=message, - vapid_private_key=get_manager().get_wp_private_key(device.application_id), - vapid_claims=get_manager().get_wp_claims(device.application_id).copy(), - timeout=get_manager().get_wp_error_timeout(device.application_id), - **kwargs + vapid_private_key=vapid_private_key, + vapid_claims=vapid_claims, + timeout=timeout, + **kwargs, ) if response.ok: results["success"] = 1 diff --git a/push_notifications/wns.py b/push_notifications/wns.py index 93304719..ac6560dd 100644 --- a/push_notifications/wns.py +++ b/push_notifications/wns.py @@ -9,7 +9,7 @@ import xml.etree.ElementTree as ET from django.core.exceptions import ImproperlyConfigured - +from typing import Dict, List, Optional, Any from .compat import HTTPError, Request, urlencode, urlopen from .conf import get_manager from .exceptions import NotificationError @@ -28,7 +28,9 @@ class WNSNotificationResponseError(WNSError): pass -def _wns_authenticate(scope="notify.windows.com", application_id=None): +def _wns_authenticate( + scope: str = "notify.windows.com", application_id: Optional[str] = None +) -> str: """ Requests an Access token for WNS communication. @@ -64,7 +66,9 @@ def _wns_authenticate(scope="notify.windows.com", application_id=None): if err.code == 400: # One of your settings is probably jacked up. # https://msdn.microsoft.com/en-us/library/windows/apps/xaml/hh868245 - raise WNSAuthenticationError("Authentication failed, check your WNS settings.") + raise WNSAuthenticationError( + "Authentication failed, check your WNS settings." + ) raise err oauth_data = response.read().decode("utf-8") @@ -82,7 +86,12 @@ def _wns_authenticate(scope="notify.windows.com", application_id=None): return access_token -def _wns_send(uri, data, wns_type="wns/toast", application_id=None): +def _wns_send( + uri: str, + data: bytes, + wns_type: str = "wns/toast", + application_id: Optional[str] = None, +) -> Any: """ Sends a notification data and authentication to WNS. @@ -103,7 +112,7 @@ def _wns_send(uri, data, wns_type="wns/toast", application_id=None): "X-WNS-Type": wns_type, # wns/toast | wns/badge | wns/tile | wns/raw } - if type(data) is str: + if isinstance(data, str): data = data.encode("utf-8") request = Request(uri, data, headers) @@ -139,7 +148,7 @@ def _wns_send(uri, data, wns_type="wns/toast", application_id=None): return response.read().decode("utf-8") -def _wns_prepare_toast(data, **kwargs): +def _wns_prepare_toast(data: Dict[str, List[str]], **kwargs: Any) -> bytes: """ Creates the xml tree for a `toast` notification @@ -170,8 +179,13 @@ def _wns_prepare_toast(data, **kwargs): def wns_send_message( - uri, message=None, xml_data=None, raw_data=None, application_id=None, **kwargs -): + uri: str, + message: Optional[Any] = None, + xml_data: Optional[Dict[str, Any]] = None, + raw_data: Optional[str] = None, + application_id: Optional[str] = None, + **kwargs: Any, +) -> str: """ Sends a notification request to WNS. There are four notification types that WNS can send: toast, tile, badge and raw. @@ -211,7 +225,9 @@ def wns_send_message( wns_type = "wns/toast" if isinstance(message, str): message = { - "text": [message, ], + "text": [ + message, + ], } prepared_data = _wns_prepare_toast(data=message, **kwargs) # Create a toast/tile/badge notification from a dictionary @@ -235,8 +251,13 @@ def wns_send_message( def wns_send_bulk_message( - uri_list, message=None, xml_data=None, raw_data=None, application_id=None, **kwargs -): + uri_list: List[str], + message: Optional[Any] = None, + xml_data: Optional[Dict[str, Any]] = None, + raw_data: Optional[str] = None, + application_id: Optional[str] = None, + **kwargs: Any, +) -> List[str]: """ WNS doesn't support bulk notification, so we loop through each uri. @@ -249,14 +270,18 @@ def wns_send_bulk_message( if uri_list: for uri in uri_list: r = wns_send_message( - uri=uri, message=message, xml_data=xml_data, - raw_data=raw_data, application_id=application_id, **kwargs + uri=uri, + message=message, + xml_data=xml_data, + raw_data=raw_data, + application_id=application_id, + **kwargs, ) res.append(r) return res -def dict_to_xml_schema(data): +def dict_to_xml_schema(data: Dict[str, Any]) -> ET.Element: """ Input a dictionary to be converted to xml. There should be only one key at the top level. The value must be a dict with (required) `children` key and @@ -322,7 +347,7 @@ def dict_to_xml_schema(data): return root -def _add_sub_elements_from_dict(parent, sub_dict): +def _add_sub_elements_from_dict(parent: ET.Element, sub_dict: Dict[str, Any]) -> None: """ Add SubElements to the parent element. @@ -357,7 +382,7 @@ def _add_sub_elements_from_dict(parent, sub_dict): sub_element.text = children -def _add_element_attrs(elem, attrs): +def _add_element_attrs(elem: ET.Element, attrs: Dict[str, str]) -> ET.Element: """ Add attributes to the given element.