Skip to content

Commit

Permalink
🌷 add type safety and type hinting to portal
Browse files Browse the repository at this point in the history
  • Loading branch information
ashleyzhang01 committed Nov 5, 2024
1 parent 7bf5144 commit fc30c55
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 39 deletions.
32 changes: 21 additions & 11 deletions backend/portal/serializers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Any, Dict, List, TypeAlias

from rest_framework import serializers

from portal.logic import check_targets, get_user_clubs, get_user_populations
from portal.models import Poll, PollOption, PollVote, Post, TargetPopulation


ClubCode: TypeAlias = str
ValidationData: TypeAlias = Dict[str, Any]


class TargetPopulationSerializer(serializers.ModelSerializer):
class Meta:
model = TargetPopulation
Expand All @@ -28,8 +34,8 @@ class Meta:
)
read_only_fields = ("id", "created_date")

def create(self, validated_data):
club_code = validated_data["club_code"]
def create(self, validated_data: ValidationData) -> Poll:
club_code: ClubCode = validated_data["club_code"]
# ensures user is part of club
if club_code not in [
x["club"]["code"] for x in get_user_clubs(self.context["request"].user)
Expand Down Expand Up @@ -78,7 +84,7 @@ def create(self, validated_data):

return super().create(validated_data)

def update(self, instance, validated_data):
def update(self, instance: Poll, validated_data: ValidationData) -> Poll:
# if Poll is updated, then approve should be false
if not self.context["request"].user.is_superuser:
validated_data["status"] = Poll.STATUS_DRAFT
Expand All @@ -96,7 +102,7 @@ class Meta:
)
read_only_fields = ("id", "vote_count")

def create(self, validated_data):
def create(self, validated_data: ValidationData) -> PollOption:
poll_options_count = PollOption.objects.filter(poll=validated_data["poll"]).count()
if poll_options_count >= 5:
raise serializers.ValidationError(
Expand Down Expand Up @@ -142,7 +148,7 @@ class Meta:
"created_date",
)

def create(self, validated_data):
def create(self, validated_data: ValidationData) -> PollVote:

options = validated_data["poll_options"]
id_hash = validated_data["id_hash"]
Expand Down Expand Up @@ -209,7 +215,7 @@ class PostSerializer(serializers.ModelSerializer):
image = serializers.ImageField(write_only=True, required=False, allow_null=True)
image_url = serializers.SerializerMethodField("get_image_url")

def get_image_url(self, obj):
def get_image_url(self, obj: Post) -> str | None:
# use thumbnail if exists
image = obj.image

Expand Down Expand Up @@ -243,7 +249,9 @@ class Meta:
)
read_only_fields = ("id", "created_date", "target_populations")

def parse_target_populations(self, raw_target_populations):
def parse_target_populations(
self, raw_target_populations: List[int] | str
) -> List[TargetPopulation]:
if isinstance(raw_target_populations, list):
ids = raw_target_populations
else:
Expand All @@ -254,7 +262,9 @@ def parse_target_populations(self, raw_target_populations):
)
return TargetPopulation.objects.filter(id__in=ids)

def update_target_populations(self, target_populations):
def update_target_populations(
self, target_populations: List[TargetPopulation]
) -> List[TargetPopulation]:
year = False
major = False
school = False
Expand All @@ -281,8 +291,8 @@ def update_target_populations(self, target_populations):

return target_populations

def create(self, validated_data):
club_code = validated_data["club_code"]
def create(self, validated_data: ValidationData) -> Post:
club_code: ClubCode = validated_data["club_code"]
# Ensures user is part of club
if club_code not in [
x["club"]["code"] for x in get_user_clubs(self.context["request"].user)
Expand All @@ -309,7 +319,7 @@ def create(self, validated_data):

return instance

def update(self, instance, validated_data):
def update(self, instance: Post, validated_data: ValidationData) -> Post:
# if post is updated, then approved should be false
if not self.context["request"].user.is_superuser:
validated_data["status"] = Post.STATUS_DRAFT
Expand Down
71 changes: 43 additions & 28 deletions backend/portal/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any, Dict, List, TypeAlias

from django.contrib.auth import get_user_model
from django.db.models import Count, Q
from django.db.models import Count, Q, QuerySet
from django.db.models.functions import Trunc
from django.utils import timezone
from rest_framework import generics, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView

Expand Down Expand Up @@ -35,6 +38,14 @@
)


PollQuerySet: TypeAlias = QuerySet[Poll]
PostQuerySet: TypeAlias = QuerySet[Post]
PollVoteQuerySet: TypeAlias = QuerySet[PollVote]
ClubData: TypeAlias = List[Dict[str, Any]]
PollOptionQuerySet: TypeAlias = QuerySet[PollOption]
TimeSeriesData: TypeAlias = Dict[str, Any]
VoteStatistics: TypeAlias = Dict[str, Any]

User = get_user_model()


Expand All @@ -43,7 +54,7 @@ class UserInfo(APIView):

permission_classes = [IsAuthenticated]

def get(self, request):
def get(self, request: Request) -> Response:
return Response({"user": get_user_info(request.user)})


Expand All @@ -52,10 +63,11 @@ class UserClubs(APIView):

permission_classes = [IsAuthenticated]

def get(self, request):
club_data = []
for club in get_user_clubs(request.user):
club_data.append(get_club_info(request.user, club["club"]["code"]))
def get(self, request: Request) -> Response:
club_data: ClubData = [
get_club_info(request.user, club["club"]["code"])
for club in get_user_clubs(request.user)
]
return Response({"clubs": club_data})


Expand Down Expand Up @@ -90,7 +102,7 @@ class Polls(viewsets.ModelViewSet):
permission_classes = [PollOwnerPermission | IsSuperUser]
serializer_class = PollSerializer

def get_queryset(self):
def get_queryset(self) -> PollQuerySet:
# all polls if superuser, polls corresponding to club for regular user
return (
Poll.objects.all()
Expand All @@ -101,7 +113,7 @@ def get_queryset(self):
)

@action(detail=False, methods=["post"])
def browse(self, request):
def browse(self, request: Request) -> Response:
"""Returns list of all possible polls user can answer but has yet to
For admins, returns list of all polls they have not voted for and have yet to expire
"""
Expand Down Expand Up @@ -156,14 +168,14 @@ def browse(self, request):
)

@action(detail=False, methods=["get"], permission_classes=[IsSuperUser])
def review(self, request):
def review(self, request: Request) -> Response:
"""Returns list of all Polls that admins still need to approve of"""
return Response(
RetrievePollSerializer(Poll.objects.filter(status=Poll.STATUS_DRAFT), many=True).data
)

@action(detail=True, methods=["get"])
def option_view(self, request, pk=None):
def option_view(self, request: Request, pk: int = None) -> Response:
"""Returns information on specific poll, including options and vote counts"""
return Response(RetrievePollSerializer(Poll.objects.filter(id=pk).first(), many=False).data)

Expand All @@ -184,7 +196,7 @@ class PollOptions(viewsets.ModelViewSet):
permission_classes = [OptionOwnerPermission | IsSuperUser]
serializer_class = PollOptionSerializer

def get_queryset(self):
def get_queryset(self) -> PollOptionQuerySet:
# if user is admin, they can update anything
# if user is not admin, they can only update their own options
return (
Expand All @@ -207,26 +219,26 @@ class PollVotes(viewsets.ModelViewSet):
permission_classes = [PollOwnerPermission | IsSuperUser]
serializer_class = PollVoteSerializer

def get_queryset(self):
def get_queryset(self) -> PollVoteQuerySet:
return PollVote.objects.none()

@action(detail=False, methods=["post"])
def recent(self, request):
def recent(self, request: Request) -> Response:

id_hash = request.data["id_hash"]

poll_votes = PollVote.objects.filter(id_hash=id_hash).order_by("-created_date").first()
return Response(RetrievePollVoteSerializer(poll_votes).data)

@action(detail=False, methods=["post"])
def all(self, request):
def all(self, request: Request) -> Response:

id_hash = request.data["id_hash"]

poll_votes = PollVote.objects.filter(id_hash=id_hash).order_by("-created_date")
return Response(RetrievePollVoteSerializer(poll_votes, many=True).data)

def create(self, request, *args, **kwargs):
def create(self, request: Request, *args: Any, **kwargs: Any) -> Response:
record_analytics(Metric.PORTAL_POLL_VOTED, request.user.username)
return super().create(request, *args, **kwargs)

Expand All @@ -236,18 +248,21 @@ class PollVoteStatistics(APIView):

permission_classes = [TimeSeriesPermission | IsSuperUser]

def get(self, request, poll_id):
return Response(
{
"time_series": PollVote.objects.filter(poll__id=poll_id)
.annotate(date=Trunc("created_date", "day"))
.values("date")
.annotate(votes=Count("date"))
.order_by("date"),
"poll_statistics": get_demographic_breakdown(poll_id),
}
def get(self, request: Request, poll_id: int) -> Response:
time_series = (
PollVote.objects.filter(poll__id=poll_id)
.annotate(date=Trunc("created_date", "day"))
.values("date")
.annotate(votes=Count("date"))
.order_by("date")
)

statistics: VoteStatistics = {
"time_series": time_series,
"poll_statistics": get_demographic_breakdown(poll_id),
}
return Response(statistics)


class Posts(viewsets.ModelViewSet):
"""
Expand All @@ -270,7 +285,7 @@ class Posts(viewsets.ModelViewSet):
permission_classes = [PostOwnerPermission | IsSuperUser]
serializer_class = PostSerializer

def get_queryset(self):
def get_queryset(self) -> PostQuerySet:
return (
Post.objects.all()
if self.request.user.is_superuser
Expand All @@ -280,7 +295,7 @@ def get_queryset(self):
)

@action(detail=False, methods=["get"])
def browse(self, request):
def browse(self, request: Request) -> Response:
"""
Returns a list of all posts that are targeted at the current user
For admins, returns list of posts that they have not approved and have yet to expire
Expand Down Expand Up @@ -318,7 +333,7 @@ def browse(self, request):
)

@action(detail=False, methods=["get"], permission_classes=[IsSuperUser])
def review(self, request):
def review(self, request: Request) -> Response:
"""Returns a list of all Posts that admins still need to approve of"""
return Response(
PostSerializer(Post.objects.filter(status=Poll.STATUS_DRAFT), many=True).data
Expand Down

0 comments on commit fc30c55

Please sign in to comment.