Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions push_notifications/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from django.contrib import admin, messages
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from django.http import HttpRequest
from django.db.models import QuerySet

from .exceptions import APNSServerError, GCMError, WebPushError
from .models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
Expand All @@ -20,9 +22,9 @@ class DeviceAdmin(admin.ModelAdmin):
if hasattr(User, "USERNAME_FIELD"):
search_fields = ("name", "device_id", "user__%s" % (User.USERNAME_FIELD))
else:
search_fields = ("name", "device_id")
search_fields = ("name", "device_id", "")

def send_messages(self, request, queryset, bulk=False):
def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None:
"""
Provides error handling for DeviceAdmin send_message and send_bulk_message methods.
"""
Expand Down Expand Up @@ -105,22 +107,22 @@ def send_messages(self, request, queryset, bulk=False):
msg = _("All messages were sent: %s" % (ret))
self.message_user(request, msg)

def send_message(self, request, queryset):
def send_message(self, request: HttpRequest, queryset: QuerySet) -> None:
self.send_messages(request, queryset)

send_message.short_description = _("Send test message")

def send_bulk_message(self, request, queryset):
def send_bulk_message(self, request: HttpRequest, queryset: QuerySet) -> None:
self.send_messages(request, queryset, True)

send_bulk_message.short_description = _("Send test message in bulk")

def enable(self, request, queryset):
def enable(self, request: HttpRequest, queryset: QuerySet) -> None:
queryset.update(active=True)

enable.short_description = _("Enable selected devices")

def disable(self, request, queryset):
def disable(self, request: HttpRequest, queryset: QuerySet) -> None:
queryset.update(active=False)

disable.short_description = _("Disable selected devices")
Expand All @@ -132,7 +134,7 @@ class GCMDeviceAdmin(DeviceAdmin):
)
list_filter = ("active", "cloud_message_type")

def send_messages(self, request, queryset, bulk=False):
def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None:
"""
Provides error handling for DeviceAdmin send_message and send_bulk_message methods.
"""
Expand Down Expand Up @@ -171,7 +173,7 @@ class WebPushDeviceAdmin(DeviceAdmin):
if hasattr(User, "USERNAME_FIELD"):
search_fields = ("name", "registration_id", "user__%s" % (User.USERNAME_FIELD))
else:
search_fields = ("name", "registration_id")
search_fields = ("name", "registration_id", "")


admin.site.register(APNSDevice, DeviceAdmin)
Expand Down
98 changes: 61 additions & 37 deletions push_notifications/api/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from rest_framework.serializers import ModelSerializer, Serializer, ValidationError
from rest_framework.viewsets import ModelViewSet

from ..fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
from ..models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
from push_notifications.fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
from push_notifications.models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
from typing import Any, Union, Dict, Optional


# Fields
Expand All @@ -15,25 +16,30 @@ class HexIntegerField(IntegerField):
Store an integer represented as a hex string of form "0x01".
"""

def to_internal_value(self, data):
def to_internal_value(self, data: Union[str, int]) -> int:
# validate hex string and convert it to the unsigned
# integer representation for internal use
try:
data = int(data, 16) if type(data) != int else data
data = int(data, 16) if not isinstance(data, int) else data
except ValueError:
raise ValidationError("Device ID is not a valid hex number")
return super().to_internal_value(data)

def to_representation(self, value):
def to_representation(self, value: int) -> int:
return value


# Serializers
class DeviceSerializerMixin(ModelSerializer):
class Meta:
fields = (
"id", "name", "application_id", "registration_id", "device_id",
"active", "date_created"
"id",
"name",
"application_id",
"registration_id",
"device_id",
"active",
"date_created",
)
read_only_fields = ("date_created",)

Expand All @@ -45,8 +51,7 @@ class APNSDeviceSerializer(ModelSerializer):
class Meta(DeviceSerializerMixin.Meta):
model = APNSDevice

def validate_registration_id(self, value):

def validate_registration_id(self, value: str) -> str:
# https://developer.apple.com/documentation/uikit/uiapplicationdelegate/1622958-application
# As of 02/2023 APNS tokens (registration_id) "are of variable length. Do not hard-code their size."
if hex_re.match(value) is None:
Expand All @@ -56,10 +61,10 @@ def validate_registration_id(self, value):


class UniqueRegistrationSerializerMixin(Serializer):
def validate(self, attrs):
devices = None
primary_key = None
request_method = None
def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
devices: Optional[Any] = None
primary_key: Optional[Any] = None
request_method: Optional[str] = None

if self.initial_data.get("registration_id", None):
if self.instance:
Expand All @@ -76,9 +81,10 @@ def validate(self, attrs):

Device = self.Meta.model
if request_method == "update":
reg_id = attrs.get("registration_id", self.instance.registration_id)
devices = Device.objects.filter(registration_id=reg_id) \
.exclude(id=primary_key)
reg_id: str = attrs.get("registration_id", self.instance.registration_id)
devices = Device.objects.filter(registration_id=reg_id).exclude(
id=primary_key
)
elif request_method == "create":
devices = Device.objects.filter(registration_id=attrs["registration_id"])

Expand All @@ -92,20 +98,26 @@ class GCMDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer):
help_text="ANDROID_ID / TelephonyManager.getDeviceId() (e.g: 0x01)",
style={"input_type": "text"},
required=False,
allow_null=True
allow_null=True,
)

class Meta(DeviceSerializerMixin.Meta):
model = GCMDevice
fields = (
"id", "name", "registration_id", "device_id", "active", "date_created",
"cloud_message_type", "application_id",
"id",
"name",
"registration_id",
"device_id",
"active",
"date_created",
"cloud_message_type",
"application_id",
)
extra_kwargs = {"id": {"read_only": False, "required": False}}

def validate_device_id(self, value):
def validate_device_id(self, value: Optional[int] = None) -> Optional[int]:
# device ids are 64 bit unsigned values
if value > UNSIGNED_64BIT_INT_MAX_VALUE:
if value is not None and value > UNSIGNED_64BIT_INT_MAX_VALUE:
raise ValidationError("Device ID is out of range")
return value

Expand All @@ -119,26 +131,36 @@ class WebPushDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer
class Meta(DeviceSerializerMixin.Meta):
model = WebPushDevice
fields = (
"id", "name", "registration_id", "active", "date_created",
"p256dh", "auth", "browser", "application_id",
"id",
"name",
"registration_id",
"active",
"date_created",
"p256dh",
"auth",
"browser",
"application_id",
)


# Permissions
class IsOwner(permissions.BasePermission):
def has_object_permission(self, request, view, obj):
def has_object_permission(self, request: Any, view: Any, obj: Any) -> bool:
# must be the owner to view the object
return obj.user == request.user


# Mixins
class DeviceViewSetMixin:
lookup_field = "registration_id"

def create(self, request, *args, **kwargs):
serializer = None
is_update = False
if SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID") and self.lookup_field in request.data:
lookup_field: str = "registration_id"

def create(self, request: Any, *args: Any, **kwargs: Any) -> Response:
serializer: Optional[Any] = None
is_update: bool = False
if (
SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID")
and self.lookup_field in request.data
):
instance = self.queryset.model.objects.filter(
registration_id=request.data[self.lookup_field]
).first()
Expand All @@ -155,23 +177,25 @@ def create(self, request, *args, **kwargs):
else:
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)

def perform_create(self, serializer):
def perform_create(self, serializer: Serializer) -> Any:
if self.request.user.is_authenticated:
serializer.save(user=self.request.user)
return super().perform_create(serializer)

def perform_update(self, serializer):
def perform_update(self, serializer: Serializer) -> Any:
if self.request.user.is_authenticated:
serializer.save(user=self.request.user)
return super().perform_update(serializer)


class AuthorizedMixin:
permission_classes = (permissions.IsAuthenticated, IsOwner)
permission_classes: tuple = (permissions.IsAuthenticated, IsOwner)

def get_queryset(self):
def get_queryset(self) -> Any:
# filter all devices to only those belonging to the current user
return self.queryset.filter(user=self.request.user)

Expand Down Expand Up @@ -207,7 +231,7 @@ class WNSDeviceAuthorizedViewSet(AuthorizedMixin, WNSDeviceViewSet):
class WebPushDeviceViewSet(DeviceViewSetMixin, ModelViewSet):
queryset = WebPushDevice.objects.all()
serializer_class = WebPushDeviceSerializer
lookup_value_regex = '.+'
lookup_value_regex: str = ".+"


class WebPushDeviceAuthorizedViewSet(AuthorizedMixin, WebPushDeviceViewSet):
Expand Down
79 changes: 53 additions & 26 deletions push_notifications/apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
"""

import time

from typing import Optional, Dict, Any, List, Union
from apns2 import client as apns2_client
from apns2 import credentials as apns2_credentials
from apns2 import errors as apns2_errors
from apns2 import payload as apns2_payload

from . import models
from .conf import get_manager
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError
from .exceptions import APNSUnsupportedPriority, APNSServerError


def _apns_create_socket(creds=None, application_id=None):
def _apns_create_socket(creds: Optional[apns2_credentials.Credentials] = None, application_id: Optional[str] = None) -> apns2_client.APNsClient:
if creds is None:
if not get_manager().has_auth_token_creds(application_id):
cert = get_manager().get_apns_certificate(application_id)
Expand All @@ -39,31 +39,48 @@ def _apns_create_socket(creds=None, application_id=None):


def _apns_prepare(
token, alert, application_id=None, badge=None, sound=None, category=None,
content_available=False, action_loc_key=None, loc_key=None, loc_args=[],
extra={}, mutable_content=False, thread_id=None, url_args=None):
if action_loc_key or loc_key or loc_args:
apns2_alert = apns2_payload.PayloadAlert(
body=alert if alert else {}, body_localized_key=loc_key,
body_localized_args=loc_args, action_localized_key=action_loc_key)
else:
apns2_alert = alert

if callable(badge):
badge = badge(token)

return apns2_payload.Payload(
alert=apns2_alert, badge=badge, sound=sound, category=category,
url_args=url_args, custom=extra, thread_id=thread_id,
content_available=content_available, mutable_content=mutable_content)
token: str,
alert: Optional[str],
application_id: Optional[str] = None,
badge: Optional[int] = None,
sound: Optional[str] = None,
category: Optional[str] = None,
content_available: bool = False,
action_loc_key: Optional[str] = None,
loc_key: Optional[str] = None,
loc_args: List[Any] = [],
extra: Dict[str, Any] = {},
mutable_content: bool = False,
thread_id: Optional[str] = None,
url_args: Optional[list] = None
) -> apns2_payload.Payload:
if action_loc_key or loc_key or loc_args:
apns2_alert = apns2_payload.PayloadAlert(
body=alert if alert else {}, body_localized_key=loc_key,
body_localized_args=loc_args, action_localized_key=action_loc_key)
else:
apns2_alert = alert

if callable(badge):
badge = badge(token)

return apns2_payload.Payload(
alert=apns2_alert, badge=badge, sound=sound, category=category,
url_args=url_args, custom=extra, thread_id=thread_id,
content_available=content_available, mutable_content=mutable_content)


def _apns_send(
registration_id, alert, batch=False, application_id=None, creds=None, **kwargs
):
registration_id: Union[str, List[str]],
alert: Optional[str] = None,
batch: bool = False,
application_id: Optional[str] = None,
creds: Optional[apns2_credentials.Credentials] = None,
**kwargs: Any
) -> Optional[Dict[str, str]]:
client = _apns_create_socket(creds=creds, application_id=application_id)

notification_kwargs = {}
notification_kwargs: Dict[str, Any] = {}

# if expiration isn"t specified use 1 month from now
notification_kwargs["expiration"] = kwargs.pop("expiration", None)
Expand Down Expand Up @@ -97,7 +114,13 @@ def _apns_send(
)


def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs):
def apns_send_message(
registration_id: str,
alert: Optional[str] = None,
application_id: Optional[str] = None,
creds: Optional[apns2_credentials.Credentials] = None,
**kwargs: Any
) -> None:
"""
Sends an APNS notification to a single registration_id.
This will send the notification as form data.
Expand All @@ -122,8 +145,12 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, *


def apns_send_bulk_message(
registration_ids, alert, application_id=None, creds=None, **kwargs
):
registration_ids: List[str],
alert: Optional[str] = None,
application_id: Optional[str] = None,
creds: Optional[apns2_credentials.Credentials] = None,
**kwargs: Any
) -> Optional[Dict[str, str]]:
"""
Sends an APNS notification to one or more registration_ids.
The registration_ids argument needs to be a list.
Expand Down
Loading
Loading