diff --git a/tests/apiv2/test_milestones.py b/tests/apiv2/test_milestones.py index db91c5ca9..7bcae053d 100644 --- a/tests/apiv2/test_milestones.py +++ b/tests/apiv2/test_milestones.py @@ -1,3 +1,4 @@ +from tests import randgen from tests.util import APITestCase from tracker import models from tracker.api import messages @@ -18,13 +19,19 @@ def _format_milestone(self, milestone, with_event=True): 'short_description': milestone.short_description, 'start': milestone.start, 'name': milestone.name, + 'run': milestone.run_id, 'visible': milestone.visible, } def setUp(self): super().setUp() + self.run = randgen.generate_runs(self.rand, self.event, 1, ordered=True)[0] self.public_milestone = models.Milestone.objects.create( - event=self.event, name='Public Milestone', amount=500.0, visible=True + event=self.event, + name='Public Milestone', + amount=500.0, + visible=True, + run=self.run, ) self.hidden_milestone = models.Milestone.objects.create( event=self.event, name='Hidden Milestone', amount=1500.0, visible=False @@ -43,7 +50,7 @@ def test_serializer(self): ) def test_fetch(self): - with self.subTest('happy path'), self.saveSnapshot(): + with self.saveSnapshot(): with self.subTest('public'): serialized = MilestoneSerializer(self.public_milestone) data = self.get_detail(self.public_milestone) @@ -83,6 +90,7 @@ def test_create(self): data={ 'name': 'New Milestone 2', 'amount': 1250, + 'run': self.run.pk, }, user=self.add_user, kwargs={'event_pk': self.event.pk}, @@ -134,6 +142,15 @@ def test_create(self): }, status_code=403, ) + self.post_new( + data={ + 'name': 'Mismatched Event Milestone', + 'amount': 100, + 'event': self.blank_event.pk, + 'run': self.run.pk, + }, + status_code=400, + ) self.post_new(user=None, status_code=403) with self.subTest('user with locked permission'): diff --git a/tests/test_api.py b/tests/test_api.py index 27e7ab40a..aece95e0d 100755 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -843,6 +843,7 @@ def format_milestone(cls, milestone, request): description=milestone.description, short_description=milestone.short_description, public=str(milestone), + run=milestone.run_id, visible=milestone.visible, ), model='tracker.milestone', diff --git a/tests/util.py b/tests/util.py index 10a3bc950..df12b5057 100644 --- a/tests/util.py +++ b/tests/util.py @@ -549,19 +549,19 @@ def _compare_value(self, key, expected, found): def _compare_model( self, expected_model, found_model, partial, prefix='', *, missing_ok=None ): - missing_ok = missing_ok or [] + missing_ok = set(missing_ok or []) self.assertIsInstance(found_model, dict, 'found_model was not a dict') self.assertIsInstance(expected_model, dict, 'expected_model was not a dict') + found_keys = set(found_model.keys()) + expected_keys = set(expected_model.keys()) if partial: extra_keys = [] else: - extra_keys = set(found_model.keys()) - set(expected_model.keys()) - missing_keys = ( - set(expected_model.keys()) - set(found_model.keys()) - set(missing_ok) - ) + extra_keys = found_keys - expected_keys + missing_keys = expected_keys - found_keys - missing_ok unequal_keys = [ k - for k in expected_model.keys() + for k in expected_keys if k in found_model and not isinstance(found_model[k], (list, dict)) and not self._compare_value(k, expected_model[k], found_model[k]) @@ -574,10 +574,11 @@ def _compare_model( found_model[k], partial, prefix=k, - missing_ok=missing_ok, + missing_ok=missing_ok + | {'event'}, # always ok to be missing 'event' on nested objects ), ) - for k in expected_model.keys() + for k in expected_keys if k in found_model and isinstance(found_model[k], dict) ] nested_objects = [n for n in nested_objects if n[1]] @@ -585,7 +586,7 @@ def _compare_model( f'{prefix}.' if prefix else '' + f'{k}': self._compare_lists( expected_model[k], found_model[k], partial, prefix=k ) - for k in expected_model.keys() + for k in expected_keys if k in found_model and isinstance(found_model[k], list) } for k, v in nested_list_keys.items(): diff --git a/tracker/admin/donation.py b/tracker/admin/donation.py index 675c7ca9d..b059d5114 100644 --- a/tracker/admin/donation.py +++ b/tracker/admin/donation.py @@ -401,7 +401,7 @@ def merge_donors(self, request, queryset): @register(models.Milestone) class MilestoneAdmin(EventLockedMixin, EventReadOnlyMixin, CustomModelAdmin): - autocomplete_fields = ('event',) + autocomplete_fields = ('event', 'run') search_fields = ('name', 'description', 'short_description') list_filter = ('event',) list_display = ('name', 'event', 'start', 'amount', 'visible') diff --git a/tracker/api/serializers.py b/tracker/api/serializers.py index fee6024ea..6858cd8c5 100644 --- a/tracker/api/serializers.py +++ b/tracker/api/serializers.py @@ -12,6 +12,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import serializers from rest_framework.exceptions import ErrorDetail, ValidationError +from rest_framework.relations import PrimaryKeyRelatedField from rest_framework.serializers import ListSerializer, as_serializer_error from rest_framework.utils import model_meta from rest_framework.validators import UniqueTogetherValidator @@ -946,6 +947,7 @@ class MilestoneSerializer( ): type = ClassNameField() event = EventSerializer() + run = PrimaryKeyRelatedField(queryset=SpeedRun.objects.all(), required=False) class Meta: model = Milestone @@ -956,6 +958,7 @@ class Meta: 'start', 'amount', 'name', + 'run', 'visible', 'description', 'short_description', diff --git a/tracker/migrations/0049_add_milestone_run.py b/tracker/migrations/0049_add_milestone_run.py new file mode 100644 index 000000000..4c300681f --- /dev/null +++ b/tracker/migrations/0049_add_milestone_run.py @@ -0,0 +1,19 @@ +# Generated by Django 5.1.4 on 2025-01-09 14:38 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('tracker', '0048_remove_old_ad_permission'), + ] + + operations = [ + migrations.AddField( + model_name='milestone', + name='run', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to='tracker.speedrun'), + ), + ] diff --git a/tracker/models/donation.py b/tracker/models/donation.py index e5f4121d7..1f7bd2323 100644 --- a/tracker/models/donation.py +++ b/tracker/models/donation.py @@ -663,6 +663,13 @@ class Milestone(models.Model): validators=[positive, nonzero], ) name = models.CharField(max_length=64) + run = models.ForeignKey( + 'tracker.SpeedRun', + blank=True, + null=True, + default=None, + on_delete=models.SET_NULL, + ) visible = models.BooleanField(default=False) description = models.TextField(max_length=1024, blank=True) short_description = models.TextField( @@ -675,6 +682,8 @@ class Milestone(models.Model): def clean(self): if self.start >= self.amount: raise ValidationError({'start': 'start must be less than amount'}) + if self.run_id and self.run.event_id != self.event_id: + raise ValidationError({'run': 'Run does not belong to that event'}) def __str__(self): return f'{self.event.name} -- {self.name} -- {self.amount}'