diff --git a/formtools/wizard/forms.py b/formtools/wizard/forms.py index 11cf4222..70d219f5 100644 --- a/formtools/wizard/forms.py +++ b/formtools/wizard/forms.py @@ -1,9 +1,20 @@ from django import forms +from django.core.exceptions import ValidationError class ManagementForm(forms.Form): """ ``ManagementForm`` is used to keep track of the current wizard step. """ + def __init__(self, steps, **kwargs): + self.steps = steps + super().__init__(**kwargs) + template_name = "django/forms/p.html" # Remove when Django 5.0 is minimal version. current_step = forms.CharField(widget=forms.HiddenInput) + + def clean_current_step(self): + data = self.cleaned_data['current_step'] + if data not in self.steps: + raise ValidationError("Invalid step name.") + return data diff --git a/formtools/wizard/views.py b/formtools/wizard/views.py index 596d6226..fe37a741 100644 --- a/formtools/wizard/views.py +++ b/formtools/wizard/views.py @@ -45,12 +45,12 @@ def __repr__(self): @property def all(self): - "Returns the names of all steps/forms." + """Returns the names of all steps/forms.""" return list(self._wizard.get_form_list()) @property def count(self): - "Returns the total number of steps/forms in this the wizard." + """Returns the total number of steps/forms in this the wizard.""" return len(self.all) @property @@ -63,27 +63,27 @@ def current(self): @property def first(self): - "Returns the name of the first step." + """Returns the name of the first step.""" return self.all[0] @property def last(self): - "Returns the name of the last step." + """Returns the name of the last step.""" return self.all[-1] @property def next(self): - "Returns the next step." + """Returns the next step.""" return self._wizard.get_next_step() @property def prev(self): - "Returns the previous step." + """Returns the previous step.""" return self._wizard.get_prev_step() @property def index(self): - "Returns the index for the current step." + """Returns the index for the current step.""" return self._wizard.get_step_index() @property @@ -277,7 +277,7 @@ def post(self, *args, **kwargs): return self.render_goto_step(wizard_goto_step) # Check if form was refreshed - management_form = ManagementForm(self.request.POST, prefix=self.prefix) + management_form = ManagementForm(steps=self.steps.all, data=self.request.POST, prefix=self.prefix) if not management_form.is_valid(): raise SuspiciousOperation(_('ManagementForm data is missing or has been tampered.')) @@ -576,7 +576,7 @@ def get_context_data(self, form, **kwargs): context['wizard'] = { 'form': form, 'steps': self.steps, - 'management_form': ManagementForm(prefix=self.prefix, initial={ + 'management_form': ManagementForm(steps=self.steps.all, prefix=self.prefix, initial={ 'current_step': self.steps.current, }), } diff --git a/tests/wizard/wizardtests/tests.py b/tests/wizard/wizardtests/tests.py index 96a6c06d..49527080 100644 --- a/tests/wizard/wizardtests/tests.py +++ b/tests/wizard/wizardtests/tests.py @@ -73,6 +73,18 @@ def test_form_post_mgmt_data_missing(self): # view should return HTTP 400 Bad Request self.assertEqual(response.status_code, 400) + def test_invalid_step_data(self): + wizard_step_data = self.wizard_step_data[0].copy() + + # Replace the current step with invalid data + for key in list(wizard_step_data.keys()): + if "current_step" in key: + wizard_step_data[key] = "not-a-valid-step" + + response = self.client.post(self.wizard_url, wizard_step_data) + # view should return HTTP 400 Bad Request + self.assertEqual(response.status_code, 400) + def test_form_post_success(self): response = self.client.post(self.wizard_url, self.wizard_step_data[0]) wizard = response.context['wizard']