diff --git a/measures/editors.py b/measures/editors.py new file mode 100644 index 000000000..0a55451f2 --- /dev/null +++ b/measures/editors.py @@ -0,0 +1,118 @@ +from typing import Dict +from typing import List +from typing import Type + +from django.db.transaction import atomic + +from workbaskets import models as workbasket_models +from measures import models as measure_models +from common.util import TaricDateRange +from common.validators import UpdateType +from common.models.utils import override_current_transaction +from measures.util import update_measure_components +from measures.util import update_measure_condition_components +from measures.util import update_measure_excluded_geographical_areas +from measures.util import update_measure_footnote_associations + + +class MeasuresEditor: + """Utility class used to edit measures from measures wizard accumulated + data.""" + + workbasket: Type["workbasket_models.WorkBasket"] + """The workbasket with which created measures will be associated.""" + + selected_measures: List + """ The measures in which the edits will apply to.""" + + data: Dict + """Validated, cleaned and accumulated data created by the Form instances of + `MeasureEditWizard`.""" + + def __init__( + self, + workbasket: Type["workbasket_models.WorkBasket"], + selected_measures: List, + data: Dict, + ): + self.workbasket = workbasket + self.selected_measures = selected_measures + self.data = data + + @atomic + def edit_measures(self) -> List["measure_models.Measure"]: + """ + Returns a list of the edited measures. + + `data` must be a dictionary + of the accumulated cleaned / validated data created from the + `MeasureEditWizard`. + """ + + with override_current_transaction( + transaction=self.workbasket.current_transaction, + ): + new_start_date = self.data.get("start_date", None) + new_end_date = self.data.get("end_date", False) + new_quota_order_number = self.data.get("order_number", None) + new_generating_regulation = self.data.get("generating_regulation", None) + new_duties = self.data.get("duties", None) + new_exclusions = [ + e["excluded_area"] + for e in self.data.get("formset-geographical_area_exclusions", []) + ] + + edited_measures = [] + + if self.selected_measures: + for measure in self.selected_measures: + new_measure = measure.new_version( + workbasket=self.workbasket, + update_type=UpdateType.UPDATE, + valid_between=TaricDateRange( + lower=( + new_start_date + if new_start_date + else measure.valid_between.lower + ), + upper=( + new_end_date + if new_end_date + else measure.valid_between.upper + ), + ), + order_number=( + new_quota_order_number + if new_quota_order_number + else measure.order_number + ), + generating_regulation=( + new_generating_regulation + if new_generating_regulation + else measure.generating_regulation + ), + ) + update_measure_components( + measure=new_measure, + duties=new_duties, + workbasket=self.workbasket, + ) + update_measure_condition_components( + measure=new_measure, + workbasket=self.workbasket, + ) + update_measure_excluded_geographical_areas( + edited="geographical_area_exclusions" + in self.data.get("fields_to_edit", []), + measure=new_measure, + exclusions=new_exclusions, + workbasket=self.workbasket, + ) + update_measure_footnote_associations( + measure=new_measure, + workbasket=self.workbasket, + ) + + edited_measures.append(new_measure.id) + + return edited_measures diff --git a/measures/forms/wizard.py b/measures/forms/wizard.py index 2e68915ec..d83142df3 100644 --- a/measures/forms/wizard.py +++ b/measures/forms/wizard.py @@ -781,7 +781,7 @@ def __init__(self, *args, **kwargs): ) -class MeasureStartDateForm(forms.Form): +class MeasureStartDateForm(forms.Form, SerializableFormMixin): start_date = DateInputFieldFixed( label="Start date", help_text="For example, 27 3 2008", @@ -819,8 +819,34 @@ def clean(self): return cleaned_data + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter( + pk__in=serialized_selected_measures_pks + ) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs + -class MeasureEndDateForm(forms.Form): +class MeasureEndDateForm(forms.Form, SerializableFormMixin): end_date = DateInputFieldFixed( label="End date", help_text="For example, 27 3 2008", @@ -861,8 +887,34 @@ def clean(self): return cleaned_data + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter( + pk__in=serialized_selected_measures_pks + ) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs + -class MeasureRegulationForm(forms.Form): +class MeasureRegulationForm(forms.Form, SerializableFormMixin): generating_regulation = AutoCompleteField( label="Regulation ID", help_text="Select the regulation which provides the legal basis for the measures.", @@ -888,8 +940,34 @@ def __init__(self, *args, **kwargs): ), ) + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter( + pk__in=serialized_selected_measures_pks + ) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs + -class MeasureDutiesForm(forms.Form): +class MeasureDutiesForm(forms.Form, SerializableFormMixin): duties = forms.CharField( label="Duties", help_text="Enter the duty that applies to the measures.", @@ -932,6 +1010,32 @@ def clean(self): return cleaned_data + @classmethod + def serializable_init_kwargs(cls, kwargs: Dict) -> Dict: + selected_measures = kwargs.get("selected_measures") + selected_measures_pks = [] + for measure in selected_measures: + selected_measures_pks.append(measure.id) + + serializable_kwargs = { + "selected_measures": selected_measures_pks, + } + + return serializable_kwargs + + @classmethod + def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict: + serialized_selected_measures_pks = form_kwargs.get("selected_measures") + deserialized_selected_measures = models.Measure.objects.filter( + pk__in=serialized_selected_measures_pks + ) + + kwargs = { + "selected_measures": deserialized_selected_measures, + } + + return kwargs + class MeasureGeographicalAreaExclusionsForm(forms.Form): excluded_area = forms.ModelChoiceField( @@ -965,7 +1069,7 @@ def __init__(self, *args, **kwargs): ) -class MeasureGeographicalAreaExclusionsFormSet(FormSet): +class MeasureGeographicalAreaExclusionsFormSet(FormSet, SerializableFormMixin): """Allows editing the geographical area exclusions of multiple measures in `MeasureEditWizard`.""" diff --git a/measures/migrations/0017_measuresbulkeditor.py b/measures/migrations/0017_measuresbulkeditor.py new file mode 100644 index 000000000..9ea8940c4 --- /dev/null +++ b/measures/migrations/0017_measuresbulkeditor.py @@ -0,0 +1,84 @@ +# Generated by Django 4.2.15 on 2024-09-02 16:06 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django_fsm +import measures.models.bulk_processing + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("workbaskets", "0008_datarow_dataupload"), + ("measures", "0016_measuresbulkcreator"), + ] + + operations = [ + migrations.CreateModel( + name="MeasuresBulkEditor", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "task_id", + models.CharField(blank=True, max_length=50, null=True, unique=True), + ), + ( + "processing_state", + django_fsm.FSMField( + choices=[ + ("AWAITING_PROCESSING", "Awaiting processing"), + ("CURRENTLY_PROCESSING", "Currently processing"), + ("SUCCESSFULLY_PROCESSED", "Successfully processed"), + ("FAILED_PROCESSING", "Failed processing"), + ("CANCELLED", "Cancelled"), + ], + db_index=True, + default="AWAITING_PROCESSING", + editable=False, + max_length=50, + protected=True, + ), + ), + ( + "successfully_processed_count", + models.PositiveIntegerField(default=0), + ), + ("form_data", models.JSONField()), + ("form_kwargs", models.JSONField()), + ("selected_measures", models.JSONField()), + ( + "user", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ( + "workbasket", + models.ForeignKey( + editable=False, + null=True, + on_delete=measures.models.bulk_processing.REVOKE_TASKS_AND_SET_NULL, + to="workbaskets.workbasket", + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/measures/models/__init__.py b/measures/models/__init__.py index a4c641802..f49a0d12e 100644 --- a/measures/models/__init__.py +++ b/measures/models/__init__.py @@ -1,5 +1,6 @@ from measures.models.bulk_processing import BulkProcessor from measures.models.bulk_processing import MeasuresBulkCreator +from measures.models.bulk_processing import MeasuresBulkEditor from measures.models.bulk_processing import ProcessingState from measures.models.tracked_models import AdditionalCodeTypeMeasureType from measures.models.tracked_models import DutyExpression @@ -23,6 +24,7 @@ # - Classes exported from bulk_processing.py. "BulkProcessor", "MeasuresBulkCreator", + "MeasuresBulkEditor", "ProcessingState", # - Classes exported from tracked_model.py. "AdditionalCodeTypeMeasureType", diff --git a/measures/models/bulk_processing.py b/measures/models/bulk_processing.py index aef1d8763..908e90d98 100644 --- a/measures/models/bulk_processing.py +++ b/measures/models/bulk_processing.py @@ -18,6 +18,7 @@ from common.models.mixins import TimestampedMixin from common.models.utils import override_current_transaction from measures.models.tracked_models import Measure +from measures.editors import MeasuresEditor logger = logging.getLogger(__name__) @@ -414,3 +415,168 @@ def _log_form_errors(self, form_class, form_or_formset) -> None: for form_errors in errors: for error_key, error_values in form_errors.items(): logger.error(f"{error_key}: {error_values}") + + +class MeasuresBulkEditorManager(models.Manager): + """Model Manager for MeasuresBulkEditor models.""" + + def create( + self, + form_data: Dict, + form_kwargs: Dict, + workbasket, + user, + selected_measures, + **kwargs, + ) -> "MeasuresBulkCreator": + """Create and save an instance of MeasuresBulkEditor.""" + + return super().create( + form_data=form_data, + form_kwargs=form_kwargs, + workbasket=workbasket, + user=user, + selected_measures=selected_measures, + **kwargs, + ) + + +class MeasuresBulkEditor(BulkProcessor): + """ + Model class used to bulk edit Measures instances from serialized form + data. + The stored form data is serialized and deserialized by Forms that subclass + SerializableFormMixin. + """ + + objects = MeasuresBulkEditorManager() + + form_data = models.JSONField() + """Dictionary of all Form.data, used to reconstruct bound Form instances as + if the form data had been sumbitted by the user within the measure wizard + process.""" + + form_kwargs = models.JSONField() + """Dictionary of all form init data, excluding a form's `data` param (which + is preserved via this class's `form_data` attribute).""" + + selected_measures = models.JSONField() + """List of all measures that have been selected for bulk editing.""" + + workbasket = models.ForeignKey( + "workbaskets.WorkBasket", + on_delete=REVOKE_TASKS_AND_SET_NULL, + null=True, + editable=False, + ) + """The workbasket with which created measures are associated.""" + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=SET_NULL, + null=True, + editable=False, + ) + """The user who submitted the task to create measures.""" + + def schedule_task(self) -> AsyncResult: + """Implementation of base class method.""" + + from measures.tasks import bulk_edit_measures + + async_result = bulk_edit_measures.apply_async( + kwargs={ + "measures_bulk_editor_pk": self.pk, + }, + countdown=1, + ) + self.task_id = async_result.id + self.save() + + logger.info( + f"Measure bulk edit scheduled on task with ID {async_result.id}" + f"using MeasuresBulkEditor.pk={self.pk}.", + ) + + return async_result + + @atomic + def edit_measures(self) -> Iterable[Measure]: + logger.info("INSIDE EDIT MEASURES - BULK PROCESSING") + + with override_current_transaction( + transaction=self.workbasket.current_transaction, + ): + cleaned_data = self.get_forms_cleaned_data() + deserialized_selected_measures = Measure.objects.filter( + pk__in=self.selected_measures + ) + + measures_editor = MeasuresEditor( + self.workbasket, deserialized_selected_measures, cleaned_data + ) + return measures_editor.edit_measures() + + def get_forms_cleaned_data(self) -> Dict: + """ + Returns a merged dictionary of all Form cleaned_data. + + If a Form's data contains a `FormSet`, the key will be prefixed with + "formset-" and contain a list of the formset cleaned_data dictionaries. + + If form validation errors are encountered when constructing cleaned + data, then this function raises Django's `ValidationError` exception. + """ + all_cleaned_data = {} + + from measures.views import MeasureEditWizard + + for form_key, form_class in MeasureEditWizard.data_form_list: + + if form_key not in self.form_data: + # Forms are conditionally included during step processing - see + # `MeasureEditWizard.show_step()` for details. + continue + + data = self.form_data[form_key] + kwargs = form_class.deserialize_init_kwargs(self.form_kwargs[form_key]) + + form = form_class(data=data, **kwargs) + + if not form.is_valid(): + self._log_form_errors(form_class=form_class, form_or_formset=form) + raise ValidationError( + f"{form_class.__name__} has {len(form.errors)} errors.", + ) + + if isinstance(form.cleaned_data, (tuple, list)): + all_cleaned_data[f"formset-{form_key}"] = form.cleaned_data + else: + all_cleaned_data.update(form.cleaned_data) + + return all_cleaned_data + + def _log_form_errors(self, form_class, form_or_formset) -> None: + """Output errors associated with a Form or Formset instance, handling + output for each instance type in a uniform manner.""" + + logger.error( + f"MeasuresBulkEditor.edit_measures() - " + f"{form_class.__name__} has {len(form_or_formset.errors)} errors.", + ) + + # Form.errors is a dictionary of errors, but FormSet.errors is a + # list of dictionaries of Form.errors. Access their errors in + # a uniform manner. + errors = [] + + if isinstance(form_or_formset, BaseFormSet): + errors = [ + {"formset_errors": form_or_formset.non_form_errors()}, + ] + form_or_formset.errors + else: + errors = [form_or_formset.errors] + + for form_errors in errors: + for error_key, error_values in form_errors.items(): + logger.error(f"{error_key}: {error_values}") diff --git a/measures/tasks.py b/measures/tasks.py index ed61c3f6d..755e64fc8 100644 --- a/measures/tasks.py +++ b/measures/tasks.py @@ -2,6 +2,7 @@ from common.celery import app from measures.models import MeasuresBulkCreator +from measures.models import MeasuresBulkEditor logger = logging.getLogger(__name__) @@ -43,3 +44,36 @@ def bulk_create_measures(measures_bulk_creator_pk: int) -> None: f"succeeded but created no measures in " f"WorkBasket({measures_bulk_creator.workbasket.pk}).", ) + + +@app.task +def bulk_edit_measures(measures_bulk_editor_pk: int) -> None: + """Bulk edit measures from serialized measures form data saved within an + instance of MeasuresBulkEditor.""" + + measures_bulk_editor = MeasuresBulkEditor.objects.get(pk=measures_bulk_editor_pk) + measures_bulk_editor.begin_processing() + measures_bulk_editor.save() + + try: + measures = measures_bulk_editor.edit_measures() + except Exception as e: + measures_bulk_editor.processing_failed() + measures_bulk_editor.save() + logger.error( + f"MeasuresBulkCreator({measures_bulk_editor.pk}) task failed " + f"attempting to edit measures in " + f"WorkBasket({measures_bulk_editor.workbasket.pk}).", + ) + raise e + + measures_bulk_editor.processing_succeeded() + measures_bulk_editor.successfully_processed_count = len(measures) + measures_bulk_editor.save() + + if measures: + logger.info( + f"MeasuresBulkEditor({measures_bulk_editor.pk}) task " + f"succeeded in editing {len(measures)} Measures in " + f"WorkBasket({measures_bulk_editor.workbasket.pk}).", + ) diff --git a/measures/tests/conftest.py b/measures/tests/conftest.py index 474d0bf16..df5d8e12c 100644 --- a/measures/tests/conftest.py +++ b/measures/tests/conftest.py @@ -325,11 +325,34 @@ def mock_request(rf, valid_user, valid_user_client): return request +@pytest.fixture() +def measure_edit_start_date_form_data(): + return { + "start_date_0": 1, + "start_date_1": 1, + "start_date_2": 2023, + } + + +@pytest.fixture() +def measure_edit_end_date_form_data(): + return { + "end_date_0": 2, + "end_date_1": 2, + "end_date_2": 2026, + } + + @pytest.fixture() def measure_regulation_id_form_data(): return {"generating_regulation": factories.RegulationFactory.create().pk} +@pytest.fixture() +def measure_edit_regulation_form_data(): + return {"generating_regulation": factories.RegulationFactory.create()} + + @pytest.fixture() def measure_details_form_data(date_ranges): return { @@ -487,6 +510,11 @@ def measure_geo_area_geo_group_exclusions_form_data(erga_omnes): } +@pytest.fixture() +def measure_edit_duties_form_data(): + return {"duties": "4%"} + + @pytest.fixture() def simple_measures_bulk_creator( user_empty_workbasket, @@ -503,9 +531,34 @@ def simple_measures_bulk_creator( @pytest.fixture() -def mocked_schedule_apply_async(): +def mocked_create_schedule_apply_async(): with patch( "measures.tasks.bulk_create_measures.apply_async", return_value=MagicMock(id=faker.Faker().uuid4()), ) as apply_async_mock: yield apply_async_mock + + +@pytest.fixture() +def simple_measures_bulk_editor( + user_empty_workbasket, + approved_transaction, +): + from measures.tests.factories import MeasuresBulkEditorFactory + + return MeasuresBulkEditorFactory.create( + form_data={}, + form_kwargs={}, + workbasket=user_empty_workbasket, + selected_measures=[], + user=None, + ) + + +@pytest.fixture() +def mocked_edit_schedule_apply_async(): + with patch( + "measures.tasks.bulk_edit_measures.apply_async", + return_value=MagicMock(id=faker.Faker().uuid4()), + ) as apply_async_mock: + yield apply_async_mock diff --git a/measures/tests/factories.py b/measures/tests/factories.py index 082d9b814..6b8301fad 100644 --- a/measures/tests/factories.py +++ b/measures/tests/factories.py @@ -106,3 +106,15 @@ class Meta: workbasket = factory.SubFactory(factories.WorkBasketFactory) form_data = {} form_kwargs = {} + + +class MeasuresBulkEditorFactory(factory.django.DjangoModelFactory): + class Meta: + model = "measures.MeasuresBulkEditor" + + user = factory.SubFactory(factories.UserFactory) + created_at = factory.Faker("date_object") + workbasket = factory.SubFactory(factories.WorkBasketFactory) + form_data = {} + form_kwargs = {} + selected_measures = [] diff --git a/measures/tests/test_bulk_processing.py b/measures/tests/test_bulk_processing.py index 189cca0ac..0b29330df 100644 --- a/measures/tests/test_bulk_processing.py +++ b/measures/tests/test_bulk_processing.py @@ -9,9 +9,12 @@ from common.tests import factories from common.util import TaricDateRange from common.validators import ApplicabilityCode +from measures import forms from measures.models import MeasuresBulkCreator +from measures.models import MeasuresBulkEditor from measures.models import ProcessingState from measures.tests.factories import MeasuresBulkCreatorFactory +from measures.tests.factories import MeasuresBulkEditorFactory from measures.validators import MeasureExplosionLevel pytestmark = pytest.mark.django_db @@ -19,14 +22,14 @@ def test_schedule_task_bulk_measures_create( simple_measures_bulk_creator, - mocked_schedule_apply_async, + mocked_create_schedule_apply_async, ): - """Test that calling MeasuresBulkCreator.shedule() correctly schedules a + """Test that calling MeasuresBulkCreator.schedule() correctly schedules a Celery task.""" simple_measures_bulk_creator.schedule_task() - mocked_schedule_apply_async.assert_called_once_with( + mocked_create_schedule_apply_async.assert_called_once_with( kwargs={ "measures_bulk_creator_pk": simple_measures_bulk_creator.pk, }, @@ -34,9 +37,26 @@ def test_schedule_task_bulk_measures_create( ) +def test_schedule_task_bulk_measures_edit( + simple_measures_bulk_editor, + mocked_edit_schedule_apply_async, +): + """Test that calling MeasuresBulkCreator.schedule() correctly schedules a + Celery task.""" + + simple_measures_bulk_editor.schedule_task() + + mocked_edit_schedule_apply_async.assert_called_once_with( + kwargs={ + "measures_bulk_editor_pk": simple_measures_bulk_editor.pk, + }, + countdown=ANY, + ) + + def test_REVOKE_TASKS_AND_SET_NULL( simple_measures_bulk_creator, - mocked_schedule_apply_async, + mocked_create_schedule_apply_async, ): """Test that deleting an object, referenced by a ForeignKey field that has `on_delete=BulkProcessor.REVOKE_TASKS_AND_SET_NULL`, correctly revokes any @@ -57,7 +77,7 @@ def test_REVOKE_TASKS_AND_SET_NULL( def test_cancel_task( simple_measures_bulk_creator, - mocked_schedule_apply_async, + mocked_create_schedule_apply_async, ): """Test BulkProcessor.cancel_task() behaviours correctly apply.""" @@ -167,6 +187,77 @@ def test_bulk_creator_get_forms_cleaned_data( } +# Run the form and get the form data from the sync done +@patch("measures.parsers.DutySentenceParser") +def test_bulk_editor_get_forms_cleaned_data( + mock_duty_sentence_parser, + user_empty_workbasket, + duty_sentence_parser, +): + + mock_duty_sentence_parser.return_value = duty_sentence_parser + + geo_area1 = factories.GeographicalAreaFactory.create() + geo_area2 = factories.GeographicalAreaFactory.create() + measure_1 = factories.MeasureFactory.create() + measure_2 = factories.MeasureFactory.create() + measure_3 = factories.MeasureFactory.create() + regulation = factories.RegulationFactory() + order_number = factories.QuotaOrderNumberFactory.create() + + selected_measures = [measure_1.pk, measure_2.pk, measure_3.pk] + + form_data = { + "start_date": { + "start_date_0": 1, + "start_date_1": 1, + "start_date_2": 2023, + }, + "end_date": { + "end_date_0": 2, + "end_date_1": 2, + "end_date_2": 2026, + }, + "quota_order_number": {"order_number": order_number.pk}, + "regulation": {"generating_regulation": regulation.pk}, + "duties": {"duties": "4%"}, + "geographical_area_exclusions": { + "form-0-excluded_area": geo_area1.pk, + "form-1-excluded_area": geo_area2.pk, + }, + } + + form_kwargs = { + "start_date": {"selected_measures": selected_measures}, + "end_date": {"selected_measures": selected_measures}, + "quota_order_number": {}, + "regulation": {"selected_measures": selected_measures}, + "duties": {"selected_measures": selected_measures}, + "geographical_area_exclusions": {}, + } + + mock_bulk_editor = MeasuresBulkEditorFactory.create( + form_data=form_data, + form_kwargs=form_kwargs, + workbasket=user_empty_workbasket, + selected_measures=selected_measures, + user=None, + ) + with override_current_transaction(user_empty_workbasket.current_transaction): + data = mock_bulk_editor.get_forms_cleaned_data() + assert data == { + "start_date": datetime.date(2023, 1, 1), + "end_date": datetime.date(2026, 2, 2), + "generating_regulation": regulation, + "order_number": order_number, + "duties": "4%", + "formset-geographical_area_exclusions": [ + {"excluded_area": geo_area1, "DELETE": False}, + {"excluded_area": geo_area2, "DELETE": False}, + ], + } + + @patch("measures.parsers.DutySentenceParser") @patch("measures.forms.wizard.LarkDutySentenceParser") def test_bulk_creator_get_forms_cleaned_data_errors( @@ -225,3 +316,112 @@ def test_bulk_creator_get_forms_cleaned_data_errors( with override_current_transaction(user_empty_workbasket.current_transaction): with pytest.raises(ValidationError): mock_bulk_creator.get_forms_cleaned_data() + + +@patch("measures.parsers.DutySentenceParser") +def test_bulk_editor_get_forms_cleaned_data_errors( + mock_duty_sentence_parser, + user_empty_workbasket, + duty_sentence_parser, +): + mock_duty_sentence_parser.return_value = duty_sentence_parser + + measure_1 = factories.MeasureFactory.create() + measure_2 = factories.MeasureFactory.create() + measure_3 = factories.MeasureFactory.create() + + selected_measures = [measure_1.pk, measure_2.pk, measure_3.pk] + + form_data = { + "start_date": { + "start_date_0": "", + "start_date_1": "", + "start_date_2": "", + }, + "end_date": { + "end_date_0": "", + "end_date_1": "", + "end_date_2": "", + }, + "quota_order_number": {"order_number": ""}, + "regulation": {"generating_regulation": ""}, + "duties": {"duties": ""}, + "geographical_area_exclusions": { + "form-0-excluded_area": "", + "form-1-excluded_area": "", + }, + } + + form_kwargs = { + "start_date": {"selected_measures": selected_measures}, + "end_date": {"selected_measures": selected_measures}, + "quota_order_number": {}, + "regulation": {"selected_measures": selected_measures}, + "duties": {"selected_measures": selected_measures}, + "geographical_area_exclusions": {}, + } + + mock_bulk_editor = MeasuresBulkEditorFactory.create( + form_data=form_data, + form_kwargs=form_kwargs, + workbasket=user_empty_workbasket, + selected_measures=selected_measures, + user=None, + ) + + with override_current_transaction(user_empty_workbasket.current_transaction): + with pytest.raises(ValidationError): + mock_bulk_editor.get_forms_cleaned_data() + + +@pytest.mark.parametrize( + "form_class, form_data, expected_error", + [ + ( + forms.MeasureStartDateForm, + { + "start_date": { + "start_date_0": "", + "start_date_1": "", + "start_date_2": "", + } + }, + "Enter the day, month and year", + ), + ( + forms.MeasureRegulationForm, + {"regulation": {"generating_regulation": ""}}, + "This field is required", + ), + ], + ids=[ + "measure_edit_start_date_form", + "measure_edit_regulation_form", + ], +) +def test_bulk_editor_log_form_errors_displays_detailed_error( + form_class, + form_data, + expected_error, + caplog, +): + + # Create measures for the form kwargs + measure_1 = factories.MeasureFactory.create() + measure_2 = factories.MeasureFactory.create() + measure_3 = factories.MeasureFactory.create() + + form_kwargs = {"selected_measures": [measure_1, measure_2, measure_3]} + form = form_class(data=form_data, **form_kwargs) + + mock_bulk_editor = MeasuresBulkEditorFactory.create() + + import logging + + # Ensure logging propagation is enabled else log messages won't + # reach this module. + logger = logging.getLogger("measures") + logger.propagate = True + + mock_bulk_editor._log_form_errors(form_class=form_class, form_or_formset=form) + assert expected_error in caplog.text diff --git a/measures/tests/test_forms.py b/measures/tests/test_forms.py index ac6123392..294fe9ea5 100644 --- a/measures/tests/test_forms.py +++ b/measures/tests/test_forms.py @@ -1923,3 +1923,118 @@ def test_measure_forms_geo_area_serialize_deserialize(form_data, request): assert deserialized_form.is_valid() assert type(deserialized_form) == forms.MeasureGeographicalAreaForm assert deserialized_form.data == form.data + + +@pytest.mark.parametrize( + "form_class, form_data_fixture, has_form_kwargs", + [ + ( + forms.MeasureStartDateForm, + "measure_edit_start_date_form_data", + True, + ), + ( + forms.MeasureEndDateForm, + "measure_edit_end_date_form_data", + True, + ), + ( + forms.MeasureQuotaOrderNumberForm, + "measure_quota_order_number_form_data", + False, + ), + ( + forms.MeasureRegulationForm, + "measure_edit_regulation_form_data", + True, + ), + ( + forms.MeasureDutiesForm, + "measure_edit_duties_form_data", + True, + ), + ], + ids=[ + "measure_edit_start_date_form", + "measure_edit_end_date_form", + "measure_edit_quota_order_number_form", + "measure_edit_regulation_form", + "measure_edit_duties_form", + ], +) +def test_simple_measure_edit_forms_serialize_deserialize( + form_class, + form_data_fixture, + has_form_kwargs, + date_ranges, + request, + duty_sentence_parser, +): + """Test that the EditMeasure simple forms that use the + SerializableFormMixin behave correctly and as expected.""" + + # Create some measures to apply this data to, for the kwargs + quota_order_number = factories.QuotaOrderNumberFactory() + regulation = factories.RegulationFactory.create() + selected_measures = factories.MeasureFactory.create_batch( + 4, + valid_between=date_ranges.normal, + order_number=quota_order_number, + generating_regulation=regulation, + ) + + # Check the forms are valid on data submission + form_data = request.getfixturevalue(form_data_fixture) + form_kwarg_data = {} + + if has_form_kwargs: + form_kwarg_data = { + "selected_measures": selected_measures, + } + + form = form_class(form_data, **form_kwarg_data) + assert form.is_valid() + + # Create the serialized data + serialized_data = form.serializable_data() + serialized_data_kwargs = {} + + if has_form_kwargs: + serialized_data_kwargs = form.serializable_init_kwargs(form_kwarg_data) + + # Deserialize the kwargs + deserialized_form_kwargs = form.deserialize_init_kwargs( + serialized_data_kwargs, + ) + + # Make a form from serialized data.Check the form is the right type, valid, and the data that went in is the same that comes out + deserialized_form = form_class( + data=serialized_data, + **deserialized_form_kwargs, + ) + + # Check the form is the right type, valid, and the data that went in is the same that comes out + assert type(deserialized_form) == form_class + assert deserialized_form.is_valid() + assert deserialized_form.data == form_data + + +def test_measure_edit_forms_geo_area_exclusions_serialize_deserialize(): + geo_area1 = factories.GeographicalAreaFactory.create() + geo_area2 = factories.GeographicalAreaFactory.create() + + form_data = {"form-0-excluded_area": geo_area1, "form-1-excluded_area": geo_area2} + with override_current_transaction(Transaction.objects.last()): + form = forms.MeasureGeographicalAreaExclusionsFormSet( + form_data, + ) + assert form.is_valid() + + serializable_form_data = form.serializable_data() + + deserialized_form = forms.MeasureGeographicalAreaExclusionsFormSet( + data=serializable_form_data, + ) + assert deserialized_form.is_valid() + assert type(deserialized_form) == forms.MeasureGeographicalAreaExclusionsFormSet + assert deserialized_form.data == form.data diff --git a/measures/tests/test_views.py b/measures/tests/test_views.py index 28dfda8cb..f6faf38b6 100644 --- a/measures/tests/test_views.py +++ b/measures/tests/test_views.py @@ -58,11 +58,11 @@ @pytest.fixture() def mocked_diff_components(): - """Mocks `diff_components()` inside `update_measure_components()` in + """Mocks `diff_components()` inside `update_measure_components()` that is called in `MeasureEditWizard` to prevent parsing errors where test measures lack a duty sentence.""" with patch( - "measures.views.MeasureEditWizard.update_measure_components", + "measures.editors.update_measure_components", ) as update_measure_components: yield update_measure_components @@ -1834,6 +1834,7 @@ def test_measuretype_api_list_view(valid_user_client): ) +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_start_and_end_date_edit_functionality( valid_user_client, user_workbasket, @@ -1966,6 +1967,7 @@ def test_multiple_measure_start_and_end_date_edit_functionality( ), ], ) +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_single_form_functionality( step, form_data, @@ -2039,6 +2041,7 @@ def test_multiple_measure_edit_single_form_functionality( assert reduce(getattr, updated_attribute.split("."), measure) == expected_data +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_only_regulation( valid_user_client, user_workbasket, @@ -2264,6 +2267,7 @@ def test_measure_list_selected_measures_list(valid_user_client): assert not measure_ids_in_table.difference(selected_measures_ids) +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_only_quota_order_number( valid_user_client, user_workbasket, @@ -2335,6 +2339,7 @@ def test_multiple_measure_edit_only_quota_order_number( assert measure.order_number == quota_order_number +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_only_duties( valid_user_client, user_workbasket, @@ -2406,6 +2411,7 @@ def test_multiple_measure_edit_only_duties( assert measure.duty_sentence == duties +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_preserves_footnote_associations( valid_user_client, user_workbasket, @@ -2482,6 +2488,7 @@ def test_multiple_measure_edit_preserves_footnote_associations( assert footnote in expected_footnotes +@override_settings(MEASURES_ASYNC_EDIT=False) def test_multiple_measure_edit_geographical_area_exclusions( valid_user_client, user_workbasket, diff --git a/measures/util.py b/measures/util.py index bf00efd6a..80849bde4 100644 --- a/measures/util.py +++ b/measures/util.py @@ -1,13 +1,21 @@ import decimal from datetime import date from math import floor -from typing import Type from common.models import TrackedModel from common.models.transactions import Transaction from common.validators import UpdateType -from measures.models import MeasureComponent -from workbaskets.models import WorkBasket + +from geo_areas.models import GeographicalArea +from geo_areas.utils import get_all_members_of_geo_groups +from measures import models as measure_models +from typing import List +from typing import Type +from workbaskets import models as workbasket_models + +import logging + +logger = logging.getLogger(__name__) def convert_eur_to_gbp(amount: str, eur_gbp_conversion_rate: float) -> str: @@ -29,9 +37,9 @@ def diff_components( instance, duty_sentence: str, start_date: date, - workbasket: WorkBasket, + workbasket: "workbasket_models.WorkBasket", transaction: Type[Transaction], - component_output: Type[TrackedModel] = MeasureComponent, + component_output: type = None, reverse_attribute: str = "component_measure", ): """ @@ -49,6 +57,11 @@ def diff_components( """ from measures.parsers import DutySentenceParser + # We add in the component output type here as otherwise we run into circular import issues. + component_output = ( + measure_models.MeasureComponent if not component_output else component_output + ) + parser = DutySentenceParser.create( start_date, component_output=component_output, @@ -91,3 +104,103 @@ def diff_components( update_type=UpdateType.DELETE, transaction=workbasket.new_transaction(), ) + + +def update_measure_components( + duties: str, + workbasket: "workbasket_models.WorkBasket", + measure: "measure_models.Measure", +): + """Updates the measure components associated to the measure.""" + diff_components( + instance=measure, + duty_sentence=duties if duties else measure.duty_sentence, + start_date=measure.valid_between.lower, + workbasket=workbasket, + transaction=workbasket.current_transaction, + ) + + +def update_measure_condition_components( + measure: "measure_models.Measure", + workbasket: "workbasket_models.WorkBasket", +): + """Updates the measure condition components associated to the + measure.""" + conditions = measure.conditions.current() + for condition in conditions: + condition.new_version( + dependent_measure=measure, + workbasket=workbasket, + ) + + +def update_measure_excluded_geographical_areas( + edited: bool, + measure: "measure_models.Measure", + exclusions: List[GeographicalArea], + workbasket: "workbasket_models.WorkBasket", +): + """Updates the excluded geographical areas associated to the measure.""" + existing_exclusions = measure.exclusions.current() + + # Update any exclusions to new measure version + if not edited: + for exclusion in existing_exclusions: + exclusion.new_version( + modified_measure=measure, + workbasket=workbasket, + ) + return + + new_excluded_areas = get_all_members_of_geo_groups( + validity=measure.valid_between, + geo_areas=exclusions, + ) + + for geo_area in new_excluded_areas: + existing_exclusion = existing_exclusions.filter( + excluded_geographical_area=geo_area, + ).first() + if existing_exclusion: + existing_exclusion.new_version( + modified_measure=measure, + workbasket=workbasket, + ) + else: + measure_models.MeasureExcludedGeographicalArea.objects.create( + modified_measure=measure, + excluded_geographical_area=geo_area, + update_type=UpdateType.CREATE, + transaction=workbasket.new_transaction(), + ) + + removed_excluded_areas = { + e.excluded_geographical_area for e in existing_exclusions + }.difference(set(exclusions)) + + exclusions_to_remove = [ + existing_exclusions.get(excluded_geographical_area__id=geo_area.id) + for geo_area in removed_excluded_areas + ] + + for exclusion in exclusions_to_remove: + exclusion.new_version( + update_type=UpdateType.DELETE, + modified_measure=measure, + workbasket=workbasket, + ) + + +def update_measure_footnote_associations(measure, workbasket): + """Updates the footnotes associated to the measure.""" + footnote_associations = ( + measure_models.FootnoteAssociationMeasure.objects.current().filter( + footnoted_measure__sid=measure.sid, + ) + ) + for fa in footnote_associations: + fa.new_version( + footnoted_measure=measure, + workbasket=workbasket, + ) diff --git a/measures/views/mixins.py b/measures/views/mixins.py index bab973754..feb0d21ac 100644 --- a/measures/views/mixins.py +++ b/measures/views/mixins.py @@ -1,4 +1,5 @@ from typing import Type +from typing import Dict from common.models import TrackedModel from measures import models @@ -50,3 +51,73 @@ def get_queryset(self): """Get the queryset for measures that are candidates for editing/deletion.""" return models.Measure.objects.filter(pk__in=self.measure_selections) + + +class MeasureSerializableWizardMixin: + """A Mixin for the wizard forms that utilise asynchronous bulk processing. This mixin provides the functionality to go through each form + and serialize the data ready for storing in the database.""" + + def get_data_form_list(self) -> dict: + """ + Returns a form list based on form_list, conditionally including only + those items as per condition_list and also appearing in data_form_list. + The list is generated dynamically because conditions in condition_list + may be dynamic. + Essentially, version of `WizardView.get_form_list()` filtering in only + those list items appearing in `data_form_list`. + """ + data_form_keys = [key for key, form in self.data_form_list] + return { + form_key: form_class + for form_key, form_class in self.get_form_list().items() + if form_key in data_form_keys + } + + def all_serializable_form_data(self) -> Dict: + """ + Returns serializable data for all wizard steps. + This is a re-implementation of + MeasureCreateWizard.get_all_cleaned_data(), but using self.data after + is_valid() has been successfully run. + """ + + all_data = {} + + for form_key in self.get_data_form_list().keys(): + all_data[form_key] = self.serializable_form_data_for_step(form_key) + + return all_data + + def serializable_form_data_for_step(self, step) -> Dict: + """ + Returns serializable data for a wizard step. + This is a re-implementation of WizardView.get_cleaned_data_for_step(), + returning the serializable version of data in place of the form's + regular cleaned_data. + """ + + form_obj = self.get_form( + step=step, + data=self.storage.get_step_data(step), + files=self.storage.get_step_files(step), + ) + + return form_obj.serializable_data(remove_key_prefix=step) + + def all_serializable_form_kwargs(self) -> Dict: + """Returns serializable kwargs for all wizard steps.""" + + all_kwargs = {} + + for form_key in self.get_data_form_list().keys(): + all_kwargs[form_key] = self.serializable_form_kwargs_for_step(form_key) + + return all_kwargs + + def serializable_form_kwargs_for_step(self, step) -> Dict: + """Returns serializable kwargs for a wizard step.""" + + form_kwargs = self.get_form_kwargs(step) + form_class = self.form_list[step] + + return form_class.serializable_init_kwargs(form_kwargs) diff --git a/measures/views/wizard.py b/measures/views/wizard.py index af82e1af0..b69b6d1b9 100644 --- a/measures/views/wizard.py +++ b/measures/views/wizard.py @@ -13,12 +13,9 @@ from django.views.generic import TemplateView from formtools.wizard.views import NamedUrlSessionWizardView -from common.util import TaricDateRange -from common.validators import UpdateType from geo_areas import constants from geo_areas.models import GeographicalArea from geo_areas.models import GeographicalMembership -from geo_areas.utils import get_all_members_of_geo_groups from geo_areas.validators import AreaCode from measures import forms from measures import models @@ -27,7 +24,8 @@ from measures.constants import START from measures.constants import MeasureEditSteps from measures.creators import MeasuresCreator -from measures.util import diff_components +from measures.editors import MeasuresEditor +from measures.views.mixins import MeasureSerializableWizardMixin from workbaskets.models import WorkBasket from workbaskets.views.decorators import require_current_workbasket @@ -41,6 +39,7 @@ class MeasureEditWizard( PermissionRequiredMixin, MeasureSelectionQuerysetMixin, NamedUrlSessionWizardView, + MeasureSerializableWizardMixin, ): """ Multipart form wizard for editing multiple measures. @@ -51,8 +50,7 @@ class MeasureEditWizard( storage_name = "measures.wizard.MeasureEditSessionStorage" permission_required = ["common.change_trackedmodel"] - form_list = [ - (START, forms.MeasuresEditFieldsForm), + data_form_list = [ (MeasureEditSteps.START_DATE, forms.MeasureStartDateForm), (MeasureEditSteps.END_DATE, forms.MeasureEndDateForm), (MeasureEditSteps.QUOTA_ORDER_NUMBER, forms.MeasureQuotaOrderNumberForm), @@ -63,6 +61,14 @@ class MeasureEditWizard( forms.MeasureGeographicalAreaExclusionsFormSet, ), ] + """Forms in this wizard's steps that collect user data.""" + + form_list = [ + (START, forms.MeasuresEditFieldsForm), + *data_form_list, + ] + """All Forms in this wizard's steps, including both those that collect user + data and those that don't.""" templates = { START: "measures/edit-multiple-start.jinja", @@ -100,6 +106,10 @@ def get_template_names(self): "measures/edit-wizard-step.jinja", ) + @property + def workbasket(self) -> WorkBasket: + return WorkBasket.current(self.request) + def get_context_data(self, form, **kwargs): context = super().get_context_data(form=form, **kwargs) context["step_metadata"] = self.step_metadata @@ -131,179 +141,76 @@ def get_form_kwargs(self, step): return kwargs - def update_measure_components( - self, - measure: models.Measure, - duties: str, - workbasket: WorkBasket, - ): - """Updates the measure components associated to the measure.""" - diff_components( - instance=measure, - duty_sentence=duties if duties else measure.duty_sentence, - start_date=measure.valid_between.lower, - workbasket=workbasket, - transaction=workbasket.current_transaction, - ) + def done(self, form_list, **kwargs): + if settings.MEASURES_ASYNC_EDIT: + return self.async_done(form_list, **kwargs) + else: + return self.sync_done(form_list, **kwargs) - def update_measure_condition_components( - self, - measure: models.Measure, - workbasket: WorkBasket, - ): - """Updates the measure condition components associated to the - measure.""" - conditions = measure.conditions.current() - for condition in conditions: - condition.new_version( - dependent_measure=measure, - workbasket=workbasket, - ) + def async_done(self, form_list, **kwargs): + logger.info("Editing measures asynchronously.") + serializable_data = self.all_serializable_form_data() + serializable_form_kwargs = self.all_serializable_form_kwargs() - def update_measure_excluded_geographical_areas( - self, - edited: bool, - measure: models.Measure, - exclusions: List[GeographicalArea], - workbasket: WorkBasket, - ): - """Updates the excluded geographical areas associated to the measure.""" - existing_exclusions = measure.exclusions.current() - - # Update any exclusions to new measure version - if not edited: - for exclusion in existing_exclusions: - exclusion.new_version( - modified_measure=measure, - workbasket=workbasket, - ) - return + db_selected_measures = [] + for measure in self.get_queryset(): + db_selected_measures.append(measure.id) - new_excluded_areas = get_all_members_of_geo_groups( - validity=measure.valid_between, - geo_areas=exclusions, + measures_bulk_editor = models.MeasuresBulkEditor.objects.create( + form_data=serializable_data, + form_kwargs=serializable_form_kwargs, + workbasket=self.workbasket, + user=self.request.user, + selected_measures=db_selected_measures, ) + self.session_store.clear() + measures_bulk_editor.schedule_task() - for geo_area in new_excluded_areas: - existing_exclusion = existing_exclusions.filter( - excluded_geographical_area=geo_area, - ).first() - if existing_exclusion: - existing_exclusion.new_version( - modified_measure=measure, - workbasket=workbasket, - ) - else: - models.MeasureExcludedGeographicalArea.objects.create( - modified_measure=measure, - excluded_geographical_area=geo_area, - update_type=UpdateType.CREATE, - transaction=workbasket.new_transaction(), - ) + return redirect( + reverse( + "workbaskets:workbasket-ui-review-measures", + kwargs={"pk": self.workbasket.pk}, + ), + ) - removed_excluded_areas = { - e.excluded_geographical_area for e in existing_exclusions - }.difference(set(exclusions)) + def edit_measures(self, selected_measures, cleaned_data): + """Synchronously edit measures within the context of the view / web + worker using accumulated data, `cleaned_data`, from all the necessary + wizard forms.""" - exclusions_to_remove = [ - existing_exclusions.get(excluded_geographical_area__id=geo_area.id) - for geo_area in removed_excluded_areas - ] + measures_editor = MeasuresEditor( + self.workbasket, selected_measures, cleaned_data + ) + return measures_editor.edit_measures() - for exclusion in exclusions_to_remove: - exclusion.new_version( - update_type=UpdateType.DELETE, - modified_measure=measure, - workbasket=workbasket, - ) + def sync_done(self, form_list, **kwargs): + """ + Handles this wizard's done step to edit measures within the context of + the web worker process. - def update_measure_footnote_associations(self, measure, workbasket): - """Updates the footnotes associated to the measure.""" - footnote_associations = ( - models.FootnoteAssociationMeasure.objects.current().filter( - footnoted_measure__sid=measure.sid, - ) - ) - for fa in footnote_associations: - fa.new_version( - footnoted_measure=measure, - workbasket=workbasket, - ) + Because bulk editing measures can be computationally expensive, this can + take an excessive amount of time within the context of HTTP request + processing. + """ + logger.info("Editing measures synchronously.") - def done(self, form_list, **kwargs): cleaned_data = self.get_all_cleaned_data() selected_measures = self.get_queryset() - workbasket = WorkBasket.current(self.request) - new_start_date = cleaned_data.get("start_date", None) - new_end_date = cleaned_data.get("end_date", False) - new_quota_order_number = cleaned_data.get("order_number", None) - new_generating_regulation = cleaned_data.get("generating_regulation", None) - new_duties = cleaned_data.get("duties", None) - new_exclusions = [ - e["excluded_area"] - for e in cleaned_data.get("formset-geographical_area_exclusions", []) - ] - for measure in selected_measures: - new_measure = measure.new_version( - workbasket=workbasket, - update_type=UpdateType.UPDATE, - valid_between=TaricDateRange( - lower=( - new_start_date - if new_start_date - else measure.valid_between.lower - ), - upper=( - new_end_date - if new_end_date is not False - else measure.valid_between.upper - ), - ), - order_number=( - new_quota_order_number - if new_quota_order_number - else measure.order_number - ), - generating_regulation=( - new_generating_regulation - if new_generating_regulation - else measure.generating_regulation - ), - ) - self.update_measure_components( - measure=new_measure, - duties=new_duties, - workbasket=workbasket, - ) - self.update_measure_condition_components( - measure=new_measure, - workbasket=workbasket, - ) - self.update_measure_excluded_geographical_areas( - edited="geographical_area_exclusions" - in cleaned_data.get("fields_to_edit", []), - measure=new_measure, - exclusions=new_exclusions, - workbasket=workbasket, - ) - self.update_measure_footnote_associations( - measure=new_measure, - workbasket=workbasket, - ) + + self.edit_measures(selected_measures, cleaned_data) self.session_store.clear() return redirect( reverse( "workbaskets:workbasket-ui-review-measures", - kwargs={"pk": workbasket.pk}, + kwargs={"pk": self.workbasket.pk}, ), ) @method_decorator(require_current_workbasket, name="dispatch") class MeasureCreateWizard( - PermissionRequiredMixin, - NamedUrlSessionWizardView, + PermissionRequiredMixin, NamedUrlSessionWizardView, MeasureSerializableWizardMixin ): """ Multipart form wizard for creating a single measure. @@ -426,24 +333,6 @@ class MeasureCreateWizard( boolean or boolean values that indicate whether a wizard step should be shown.""" - def get_data_form_list(self) -> dict: - """ - Returns a form list based on form_list, conditionally including only - those items as per condition_list and also appearing in data_form_list. - - The list is generated dynamically because conditions in condition_list - may be dynamic. - - Essentially, version of `WizardView.get_form_list()` filtering in only - those list items appearing in `data_form_list`. - """ - data_form_keys = [key for key, form in self.data_form_list] - return { - form_key: form_class - for form_key, form_class in self.get_form_list().items() - if form_key in data_form_keys - } - @property def workbasket(self) -> WorkBasket: return WorkBasket.current(self.request) @@ -507,57 +396,6 @@ def async_done(self, form_list, **kwargs): expected_measures_count=measures_bulk_creator.expected_measures_count, ) - def all_serializable_form_data(self) -> Dict: - """ - Returns serializable data for all wizard steps. - - This is a re-implementation of - MeasureCreateWizard.get_all_cleaned_data(), but using self.data after - is_valid() has been successfully run. - """ - - all_data = {} - - for form_key in self.get_data_form_list().keys(): - all_data[form_key] = self.serializable_form_data_for_step(form_key) - - return all_data - - def serializable_form_data_for_step(self, step) -> Dict: - """ - Returns serializable data for a wizard step. - - This is a re-implementation of WizardView.get_cleaned_data_for_step(), - returning the serializable version of data in place of the form's - regular cleaned_data. - """ - - form_obj = self.get_form( - step=step, - data=self.storage.get_step_data(step), - files=self.storage.get_step_files(step), - ) - - return form_obj.serializable_data(remove_key_prefix=step) - - def all_serializable_form_kwargs(self) -> Dict: - """Returns serializable kwargs for all wizard steps.""" - - all_kwargs = {} - - for form_key in self.get_data_form_list().keys(): - all_kwargs[form_key] = self.serializable_form_kwargs_for_step(form_key) - - return all_kwargs - - def serializable_form_kwargs_for_step(self, step) -> Dict: - """Returns serializable kwargs for a wizard step.""" - - form_kwargs = self.get_form_kwargs(step) - form_class = self.form_list[step] - - return form_class.serializable_init_kwargs(form_kwargs) - def get_all_cleaned_data(self): """ Returns a merged dictionary of all step cleaned_data. If a step contains diff --git a/settings/common.py b/settings/common.py index da592d65a..a4c235f55 100644 --- a/settings/common.py +++ b/settings/common.py @@ -667,6 +667,9 @@ "measures.tasks.bulk_create_measures": { "queue": "bulk-create", }, + "measures.tasks.bulk_edit_measures": { + "queue": "bulk-create", + }, } SQLITE_EXCLUDED_APPS = [ @@ -923,3 +926,4 @@ # Asynchronous / background (bulk) object creation and editing config. MEASURES_ASYNC_CREATION = is_truthy(os.environ.get("MEASURES_ASYNC_CREATION", "true")) +MEASURES_ASYNC_EDIT = is_truthy(os.environ.get("MEASURES_ASYNC_EDIT", "true"))