diff --git a/autograder/core/tests/test_utils.py b/autograder/core/tests/test_utils.py index 94ea35d4..d8da6925 100644 --- a/autograder/core/tests/test_utils.py +++ b/autograder/core/tests/test_utils.py @@ -1,3 +1,4 @@ +import copy import datetime import os import tempfile @@ -14,6 +15,52 @@ from autograder.utils.testing import UnitTestBase +class DateTimesAreEqualTest(UnitTestBase): + def test_equal_datetimes(self): + now = datetime.datetime.now() + now2 = copy.deepcopy(now) + self.assertTrue(core_ut.datetimes_are_equal(now, now2)) + self.assertTrue(core_ut.datetimes_are_equal(now2, now)) + + def test_unequal_datetimes(self): + now = datetime.datetime.now() + later = now + datetime.timedelta(minutes=5) + self.assertFalse(core_ut.datetimes_are_equal(now, later)) + self.assertFalse(core_ut.datetimes_are_equal(later, now)) + + def test_equal_strings(self): + time1 = '2024-09-12 17:52:05.538324+00:00' + time2 = '2024-09-12 17:52:05.538324Z' + self.assertTrue(core_ut.datetimes_are_equal(time1, time2)) + self.assertTrue(core_ut.datetimes_are_equal(time2, time1)) + self.assertTrue(core_ut.datetimes_are_equal(time1, time1)) + + def test_unequal_strings(self): + time1 = '2025-09-12 17:52:05.538324+00:00' + time2 = '2024-09-12 17:52:05.538324Z' + self.assertFalse(core_ut.datetimes_are_equal(time1, time2)) + self.assertFalse(core_ut.datetimes_are_equal(time2, time1)) + + def test_none(self): + time1 = None + time2 = datetime.datetime.now() + self.assertTrue(core_ut.datetimes_are_equal(time1, time1)) + self.assertFalse(core_ut.datetimes_are_equal(time1, time2)) + self.assertFalse(core_ut.datetimes_are_equal(time2, time1)) + + def test_mixed_equal(self): + time1 = datetime.datetime.now() + time2 = time1.isoformat() + self.assertTrue(core_ut.datetimes_are_equal(time1, time2)) + self.assertTrue(core_ut.datetimes_are_equal(time2, time1)) + + def test_mixed_unequal(self): + now = datetime.datetime.now() + later = (now + datetime.timedelta(minutes=5)).isoformat() + self.assertFalse(core_ut.datetimes_are_equal(now, later)) + self.assertFalse(core_ut.datetimes_are_equal(later, now)) + + class DiffTestCase(SimpleTestCase): def setUp(self): super().setUp() diff --git a/autograder/core/utils.py b/autograder/core/utils.py index 052d35e9..2d1f42e4 100644 --- a/autograder/core/utils.py +++ b/autograder/core/utils.py @@ -12,6 +12,7 @@ from django.conf import settings from django.core import exceptions from django.utils import timezone +from django.utils.dateparse import parse_datetime from . import constants as const @@ -22,6 +23,37 @@ from .models.submission import Submission +def datetimes_are_equal(time1: datetime.datetime | str | None, + time2: datetime.datetime | str | None) -> bool: + """ + Returns true if time1 is equal to time2. Each argument may be either a `datetime.datetime` + object, a valid ISO formatted string, or `None`. + + :raises: ValueError when either argument is an invalid datetime string. + """ + # need to do these checks because parse_datetime returns None when passed + # an invalid string. + if time1 is None and time2 is not None: + return False + elif time1 is not None and time2 is None: + return False + elif time1 is None and time2 is None: + return True + + if isinstance(time1, str): + time1_parsed = parse_datetime(time1) + if time1_parsed is None: + raise ValueError(f"{time1} is not a valid time") + time1 = time1_parsed + if isinstance(time2, str): + time2_parsed = parse_datetime(time2) + if time2_parsed is None: + raise ValueError(f"{time2} is not a valid time") + time2 = time2_parsed + + return time1 == time2 + + class InvalidSoftDeadlineError(Exception): """ Raised when a soft deadline or due date is an invalid value (not None or diff --git a/autograder/rest_api/tests/test_views/test_group_views/test_group_views.py b/autograder/rest_api/tests/test_views/test_group_views/test_group_views.py index f40fb8f2..caa19d1b 100644 --- a/autograder/rest_api/tests/test_views/test_group_views/test_group_views.py +++ b/autograder/rest_api/tests/test_views/test_group_views/test_group_views.py @@ -6,11 +6,12 @@ from django.core.files.uploadedfile import SimpleUploadedFile from django.urls import reverse from django.utils import timezone -from rest_framework import status +from rest_framework import exceptions, status from rest_framework.test import APIClient import autograder.core.models as ag_models import autograder.utils.testing.model_obj_builders as obj_build +from autograder.rest_api.views.group_views import clean_extended_due_dates from autograder.rest_api.serialize_user import serialize_user from autograder.rest_api.tests.test_views.ag_view_test_base import AGViewTestBase from autograder.utils.testing import UnitTestBase @@ -491,7 +492,11 @@ def test_admin_update_group_deprecated_extended_due_date(self): self.do_patch_object_test( group, self.client, self.admin, self.group_url(group), - {'extended_due_date': self.new_due_date}, + { + 'extended_due_date': self.new_due_date, + 'soft_extended_due_date': group.soft_extended_due_date, + 'hard_extended_due_date': group.hard_extended_due_date + }, expected_response_overrides={ 'extended_due_date': self.new_due_date.replace(second=0, microsecond=0), 'soft_extended_due_date': self.new_due_date.replace(second=0, microsecond=0), @@ -521,7 +526,10 @@ def test_admin_update_soft_extended_due_date(self): valid_soft_extended_due_date = self.new_due_date - datetime.timedelta(days=1) self.do_patch_object_test( group, self.client, self.admin, self.group_url(group), - {'soft_extended_due_date': valid_soft_extended_due_date}, + { + 'soft_extended_due_date': valid_soft_extended_due_date, + 'extended_due_date': group.extended_due_date + }, expected_response_overrides={ 'soft_extended_due_date': valid_soft_extended_due_date.replace( second=0, microsecond=0), @@ -551,7 +559,10 @@ def test_admin_update_hard_extended_due_date(self): valid_hard_extended_due_date = self.new_due_date + datetime.timedelta(days=1) self.do_patch_object_test( group, self.client, self.admin, self.group_url(group), - {'hard_extended_due_date': valid_hard_extended_due_date}, + { + 'hard_extended_due_date': valid_hard_extended_due_date, + 'extended_due_date': group.extended_due_date + }, expected_response_overrides={ 'hard_extended_due_date': valid_hard_extended_due_date.replace( second=0, microsecond=0), @@ -793,3 +804,66 @@ def test_non_admin_delete_group_permission_denied(self) -> None: self.group = ag_models.Group.objects.get(pk=self.group.pk) self.assertEqual(original_member_names, self.group.member_names) + + +class CleanExtendedDueDatesTests(UnitTestBase): + def setUp(self): + super().setUp() + self.group = obj_build.make_group() + self.new_time = '2020-09-12 17:52:05.538324Z' + + def test_mixed_use_deprecated_extended_due_date(self) -> None: + update_data = { + 'extended_due_date': self.new_time, + 'soft_extended_due_date': self.new_time, + 'foo': 'bar' + } + with self.assertRaises(exceptions.ValidationError) as cm: + _ = clean_extended_due_dates(update_data, self.group) + self.assertIn('extended_due_date', cm.exception.detail) + + update_data = {'extended_due_date': self.new_time, 'hard_extended_due_date': self.new_time} + with self.assertRaises(exceptions.ValidationError) as cm: + _ = clean_extended_due_dates(update_data, self.group) + self.assertIn('extended_due_date', cm.exception.detail) + + def test_updated_deprecated_extended_due_date(self) -> None: + update_data = { + 'extended_due_date': self.new_time, + 'soft_extended_due_date': self.group.soft_extended_due_date, + 'hard_extended_due_date': self.group.hard_extended_due_date, + 'foo': 'bar' + } + new_update_data = clean_extended_due_dates(update_data, self.group) + expected_new_update_data = {'extended_due_date': self.new_time, 'foo': 'bar'} + self.assertEqual(new_update_data, expected_new_update_data) + + def test_updated_soft_extended_due_date(self) -> None: + update_data = { + 'extended_due_date': self.group.extended_due_date, + 'soft_extended_due_date': self.new_time, + 'hard_extended_due_date': self.group.hard_extended_due_date, + 'foo': 'bar' + } + new_update_data = clean_extended_due_dates(update_data, self.group) + expected_new_update_data = { + 'soft_extended_due_date': self.new_time, + 'hard_extended_due_date': self.group.hard_extended_due_date, + 'foo': 'bar' + } + self.assertEqual(new_update_data, expected_new_update_data) + + def test_updated_hard_extended_due_date(self) -> None: + update_data = { + 'extended_due_date': self.group.extended_due_date, + 'soft_extended_due_date': self.group.soft_extended_due_date, + 'hard_extended_due_date': self.new_time, + 'foo': 'bar' + } + new_update_data = clean_extended_due_dates(update_data, self.group) + expected_new_update_data = { + 'soft_extended_due_date': self.group.soft_extended_due_date, + 'hard_extended_due_date': self.new_time, + 'foo': 'bar' + } + self.assertEqual(new_update_data, expected_new_update_data) diff --git a/autograder/rest_api/views/group_views.py b/autograder/rest_api/views/group_views.py index 8629c018..0bfdd2dc 100644 --- a/autograder/rest_api/views/group_views.py +++ b/autograder/rest_api/views/group_views.py @@ -1,3 +1,5 @@ +import copy +from datetime import datetime import itertools import os import shutil @@ -7,9 +9,10 @@ from django.db import transaction from django.db.models import Prefetch from django.utils import timezone +from django.utils.dateparse import parse_datetime from django.utils.decorators import method_decorator from drf_composable_permissions.p import P -from rest_framework import decorators, mixins, permissions, response, status +from rest_framework import decorators, exceptions, mixins, permissions, response, status import autograder.core.models as ag_models import autograder.core.utils as core_ut @@ -214,7 +217,7 @@ def get(self, *args, **kwargs): def patch(self, request, *args, **kwargs): group = self.get_object() - update_data = dict(request.data) + update_data = clean_extended_due_dates(dict(request.data), group) if 'member_names' in update_data: users = [ User.objects.get_or_create( @@ -382,3 +385,65 @@ def _get_merged_extended_due_date(self, group1, group2): return group1.extended_due_date return max(group1.extended_due_date, group2.extended_due_date) + + +def clean_extended_due_dates(update_data: dict, old_group: ag_models.Group) -> dict: + """ + Return a new dict without 'soft_extended_due_date' and without 'hard_extended_due_date' + if `update_data` contains a changed value for 'extended_due_date'. If it contains just + a changed value to either 'soft_extended_due_date' or 'hard_extended_due_date', return + a new dict without 'extended_due_date'. + + :raises: exceptions.ValidationError if `update_data` contains new values for + both deprecated 'extended_due_date' and one of 'soft_extended_due_date' or + 'hard_extended_due_date', or if any datetime strings in `update_data` are invalid. + """ + legacy_changed = current_changed = False + try: + if ('extended_due_date' in update_data and not core_ut.datetimes_are_equal( + update_data['extended_due_date'], old_group.extended_due_date)): + legacy_changed = True + except(ValueError) as e: + raise exceptions.ValidationError({ + 'extended_due_date': str(e) + }) + + try: + if ('soft_extended_due_date' in update_data and not core_ut.datetimes_are_equal( + update_data['soft_extended_due_date'], old_group.soft_extended_due_date)): + current_changed = True + except(ValueError) as e: + raise exceptions.ValidationError({ + 'soft_extended_due_date': str(e) + }) + + try: + if ('hard_extended_due_date' in update_data and not core_ut.datetimes_are_equal( + update_data['hard_extended_due_date'], old_group.hard_extended_due_date)): + current_changed = True + except(ValueError) as e: + raise exceptions.ValidationError({ + 'hard_extended_due_date': str(e) + }) + + if legacy_changed and current_changed: + raise exceptions.ValidationError({ + 'extended_due_date': ( + 'Extended due date is deprecated and may not be used along with soft and' + + ' hard extended due dates') + }) + + elif legacy_changed: + return _deepcopy_dict_without_fields(update_data, 'soft_extended_due_date', + 'hard_extended_due_date') + elif current_changed: + return _deepcopy_dict_without_fields(update_data, 'extended_due_date') + else: + return _deepcopy_dict_without_fields(update_data) + + +def _deepcopy_dict_without_fields(dict_to_copy: dict, *keys: str) -> dict: + return { + key: copy.deepcopy(dict_to_copy[key]) for key in dict_to_copy + if key not in keys + }