From 9e1edfd3919e64f1d5c71a0ed85afd6235dc20e8 Mon Sep 17 00:00:00 2001 From: Tomos Williams Date: Wed, 28 Aug 2024 16:11:47 +0100 Subject: [PATCH 1/6] updated serializer and tests --- api/core/tests/test_validators.py | 20 +++++ api/core/validators.py | 22 ++++- api/goods/serializers.py | 12 +-- api/goods/tests/test_serializers.py | 111 ++++++++++++++++++++++++++ api/parties/serializers.py | 6 +- api/parties/tests/test_serializers.py | 104 ++++++++++++++++++++---- 6 files changed, 248 insertions(+), 27 deletions(-) create mode 100644 api/core/tests/test_validators.py diff --git a/api/core/tests/test_validators.py b/api/core/tests/test_validators.py new file mode 100644 index 0000000000..28a1e051b0 --- /dev/null +++ b/api/core/tests/test_validators.py @@ -0,0 +1,20 @@ +import pytest +from django.core.exceptions import ValidationError +from api.core.validators import EdifactStringValidator + + +@pytest.mark.parametrize( + "value", + ((""), ("random value"), ("random-value"), ("random!value"), ("random-!.<>/%&*;+'(),.value")), +) +def test_edifactstringvalidator_valid(value): + validator = EdifactStringValidator() + result = validator(value) + assert result == None + + +@pytest.mark.parametrize("value", (("\r\n"), ("random_value"), ("random$value"), ("random@value"))) +def test_edifactstringvalidator_invalid(value): + validator = EdifactStringValidator() + with pytest.raises(ValidationError): + results = validator(value) diff --git a/api/core/validators.py b/api/core/validators.py index 3be8a04dac..9b4b159cc4 100644 --- a/api/core/validators.py +++ b/api/core/validators.py @@ -1,6 +1,6 @@ from django.utils.deconstruct import deconstructible from rest_framework.exceptions import ValidationError - +import re from api.staticdata.control_list_entries.models import ControlListEntry @@ -22,3 +22,23 @@ def __call__(self, value): ControlListEntry.objects.get(rating=value) except ControlListEntry.DoesNotExist: raise ValidationError(self.message, code=self.code) + + +class EdifactStringValidator: + message = "Undefined Error" + regex_string = r"^[a-zA-Z0-9 .,\-\)\(\/'+:=\?\!\"%&\*;\<\>]+$" + + def __call__(self, value): + match_regex = re.compile(self.regex_string) + is_value_valid = bool(match_regex.match(value)) + if not is_value_valid: + raise ValidationError(self.message) + + +class GoodNameValidator(EdifactStringValidator): + message = "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes" + + +class PartyAddressValidator(EdifactStringValidator): + regex_string = re.compile(r"^[a-zA-Z0-9 .,\-\)\(\/'+:=\?\!\"%&\*;\<\>\r\n]+$") + message = "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes" diff --git a/api/goods/serializers.py b/api/goods/serializers.py index 679a74d677..122c66e779 100644 --- a/api/goods/serializers.py +++ b/api/goods/serializers.py @@ -1,8 +1,8 @@ from rest_framework import serializers from rest_framework.relations import PrimaryKeyRelatedField - from api.core.helpers import str_to_bool from api.core.serializers import KeyValueChoiceField, ControlListEntryField, GoodControlReviewSerializer +from api.core.validators import GoodNameValidator from api.documents.libraries.process_document import process_document from api.goods.enums import ( FirearmCategory, @@ -308,7 +308,7 @@ def update(self, instance, validated_data): class GoodListSerializer(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() control_list_entries = ControlListEntrySerializer(many=True, allow_null=True) part_number = serializers.CharField() @@ -357,7 +357,7 @@ class GoodCreateSerializer(serializers.ModelSerializer): Because of this, each 'get' override must check the instance type before creating queries """ - name = serializers.CharField(error_messages={"blank": "Enter a product name"}) + name = serializers.CharField(error_messages={"blank": "Enter a product name"}, validators=[GoodNameValidator()]) description = serializers.CharField(max_length=280, allow_blank=True, required=False) is_good_controlled = KeyValueChoiceField(choices=GoodControlled.choices, allow_null=True) control_list_entries = ControlListEntryField(required=False, many=True, allow_null=True, allow_empty=True) @@ -690,7 +690,7 @@ def create(self, validated_data): class GoodDocumentViewSerializer(serializers.Serializer): id = serializers.UUIDField() created_at = serializers.DateTimeField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() user = ExporterUserSimpleSerializer() s3_key = serializers.SerializerMethodField() @@ -788,7 +788,7 @@ class Meta: class GoodSerializerInternal(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() part_number = serializers.CharField() no_part_number_comments = serializers.CharField() @@ -871,7 +871,7 @@ def get_user(self, instance): class GoodSerializerExporter(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() control_list_entries = ControlListEntryField(many=True) part_number = serializers.CharField() diff --git a/api/goods/tests/test_serializers.py b/api/goods/tests/test_serializers.py index 01f7026e3c..6f17034ab9 100644 --- a/api/goods/tests/test_serializers.py +++ b/api/goods/tests/test_serializers.py @@ -1,4 +1,5 @@ from datetime import datetime +from parameterized import parameterized from django.urls import reverse from django.utils import timezone @@ -150,6 +151,61 @@ def test_report_summary_present(self): self.assertEqual(actual_subject["id"], str(self.good.report_summary_subject.id)) self.assertEqual(actual_subject["name"], self.good.report_summary_subject.name) + @parameterized.expand( + [ + "random good", + "good-name", + "good!name", + "good-!.<>/%&*;+'(),.name", + ] + ) + def test_validate_good_internal_name_valid(self, name): + serializer = GoodSerializerInternal( + data={"name": name}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("", "This field may not be blank."), + ("\r\n", "This field may not be blank."), + ( + "good\rname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\r\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good_name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good$name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good@name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_good_internal_name_invalid(self, name, error_message): + serializer = GoodSerializerInternal(data={"name": name}, partial=True) + self.assertFalse(serializer.is_valid()) + name_errors = serializer.errors["name"] + self.assertEqual(len(name_errors), 1) + self.assertEqual( + str(name_errors[0]), + error_message, + ) + class GoodSerializerExporterFullDetailTests(DataTestClient): @@ -182,3 +238,58 @@ def test_exporter_has_archive_history(self): archive_history["actioned_on"], datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.get_current_timezone()), ) + + @parameterized.expand( + [ + "random good", + "good-name", + "good!name", + "good-!.<>/%&*;+'(),.name", + ] + ) + def test_validate_good_exporter_name_valid(self, address): + serializer = GoodSerializerInternal( + data={"address": address}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("", "This field may not be blank."), + ("\r\n", "This field may not be blank."), + ( + "good\rname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\r\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good_name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good$name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good@name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_good_exporter_name_invalid(self, name, error_message): + serializer = GoodSerializerExporterFullDetail(data={"name": name}, partial=True) + self.assertFalse(serializer.is_valid()) + name_errors = serializer.errors["name"] + self.assertEqual(len(name_errors), 1) + self.assertEqual( + str(name_errors[0]), + error_message, + ) diff --git a/api/parties/serializers.py b/api/parties/serializers.py index 68dcb98201..a4eb7edb43 100644 --- a/api/parties/serializers.py +++ b/api/parties/serializers.py @@ -3,6 +3,7 @@ from api.cases.enums import CaseTypeSubTypeEnum from api.core.serializers import KeyValueChoiceField, CountrySerializerField +from api.core.validators import PartyAddressValidator from api.documents.libraries.process_document import process_document from api.flags.serializers import FlagSerializer from api.goods.enums import PvGrading @@ -15,7 +16,7 @@ class PartySerializer(serializers.ModelSerializer): name = serializers.CharField(error_messages=PartyErrors.NAME) - address = serializers.CharField(error_messages=PartyErrors.ADDRESS) + address = serializers.CharField(error_messages=PartyErrors.ADDRESS, validators=[PartyAddressValidator()]) country = CountrySerializerField() website = serializers.CharField(required=False, allow_blank=True) signatory_name_euu = serializers.CharField(allow_blank=True) @@ -116,8 +117,7 @@ def validate(self, data): return validated_data - @staticmethod - def validate_website(value): + def validate_website(self, value): """ Custom validation for URL that makes use of django URLValidator but makes the passing of http:// or https:// optional by prepending diff --git a/api/parties/tests/test_serializers.py b/api/parties/tests/test_serializers.py index 509c439f04..903f293ce8 100644 --- a/api/parties/tests/test_serializers.py +++ b/api/parties/tests/test_serializers.py @@ -1,24 +1,94 @@ -import pytest from parameterized import parameterized from api.parties.serializers import PartySerializer -from django.core.exceptions import ValidationError +from test_helpers.clients import DataTestClient -@parameterized.expand( - [ - ("http://workingexample.com", "http://workingexample.com"), - ("http://www.workingexample.com", "http://www.workingexample.com"), - ("http://WWW.workingexample.com", "http://WWW.workingexample.com"), - ("http://workingexample.com", "http://workingexample.com"), - ("workingexample.com", "https://workingexample.com"), - ("HTTPS://workingexample.com", "HTTPS://workingexample.com"), - ] -) -def test_party_serializer_validate_website_valid(url_input, url_output): - assert url_output == PartySerializer.validate_website(url_input) +class TestPartySerializer(DataTestClient): + @parameterized.expand( + [ + ("http://workingexample.com", "http://workingexample.com"), + ("http://www.workingexample.com", "http://www.workingexample.com"), + ("http://WWW.workingexample.com", "http://WWW.workingexample.com"), + ("http://workingexample.com", "http://workingexample.com"), + ("workingexample.com", "https://workingexample.com"), + ("HTTPS://workingexample.com", "HTTPS://workingexample.com"), + ] + ) + def test_party_serializer_validate_website_valid(self, url_input, url_output): + serializer = PartySerializer( + data={"website": url_input}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + self.assertEqual( + serializer.validated_data["website"], + url_output, + ) + def test_party_serializer_validate_website_invalid(self): + serializer = PartySerializer( + data={"website": "invalid@ur&l-i.am"}, + partial=True, + ) + self.assertFalse(serializer.is_valid()) + website_errors = serializer.errors["website"] + self.assertEqual(len(website_errors), 1) + self.assertEqual( + str(website_errors[0]), + "Enter a valid URL.", + ) -def test_party_serializer_validate_website_invalid(): - with pytest.raises(ValidationError): - PartySerializer.validate_website("invalid@ur&l-i.am") + @parameterized.expand( + [ + "random party", + "party-address", + "party!address", + "party-!.<>/%&*;+'(),.address", + "party\r\naddress", + ] + ) + def test_validate_party_address_valid(self, address): + serializer = PartySerializer( + data={"address": address}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("\r\n", "Enter an address"), + ( + "party\address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party-\waddress", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party_address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party$address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party@address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_party_address_invalid(self, address, error_message): + serializer = PartySerializer( + data={"address": address}, + partial=True, + ) + self.assertFalse(serializer.is_valid()) + address_errors = serializer.errors["address"] + self.assertEqual(len(address_errors), 1) + self.assertEqual( + str(address_errors[0]), + error_message, + ) From 8421f1f60297079a66002ef88dafffcb3acd9e96 Mon Sep 17 00:00:00 2001 From: Tomos Williams Date: Wed, 11 Sep 2024 15:00:51 +0100 Subject: [PATCH 2/6] Update api/core/tests/test_validators.py Co-authored-by: code-review-doctor[bot] <72320148+code-review-doctor[bot]@users.noreply.github.com> --- api/core/tests/test_validators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/tests/test_validators.py b/api/core/tests/test_validators.py index 28a1e051b0..2bdb96b3d0 100644 --- a/api/core/tests/test_validators.py +++ b/api/core/tests/test_validators.py @@ -10,7 +10,7 @@ def test_edifactstringvalidator_valid(value): validator = EdifactStringValidator() result = validator(value) - assert result == None + assert result is None @pytest.mark.parametrize("value", (("\r\n"), ("random_value"), ("random$value"), ("random@value"))) From 3a6eb67f35681fd2c1a45bb748714573aad70375 Mon Sep 17 00:00:00 2001 From: Tomos Williams Date: Wed, 11 Sep 2024 15:29:16 +0100 Subject: [PATCH 3/6] expect correct exception --- api/core/tests/test_validators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/tests/test_validators.py b/api/core/tests/test_validators.py index 2bdb96b3d0..aea29f02d0 100644 --- a/api/core/tests/test_validators.py +++ b/api/core/tests/test_validators.py @@ -1,11 +1,11 @@ import pytest -from django.core.exceptions import ValidationError +from rest_framework.exceptions import ValidationError from api.core.validators import EdifactStringValidator @pytest.mark.parametrize( "value", - ((""), ("random value"), ("random-value"), ("random!value"), ("random-!.<>/%&*;+'(),.value")), + (("random value"), ("random-value"), ("random!value"), ("random-!.<>/%&*;+'(),.value")), ) def test_edifactstringvalidator_valid(value): validator = EdifactStringValidator() From 9f60e3a897c5fd43c6c2d93aecb8f6007220d6a6 Mon Sep 17 00:00:00 2001 From: Arun Siluvery Date: Wed, 11 Sep 2024 22:42:34 +0100 Subject: [PATCH 4/6] Update seedinternalusers command to accept team and default queue The users created using this command are by default belong to Admin team and queue is set to 'All cases'. This is a problem in E2E tests as when we sign in as different users their team and queue are not correct. Because of this we are having to switch team, queue as part of the test and as these persist between tests it can cause test failures. Better approach is to seed certain users with expected roles, teams, queues etc. --- .../management/commands/seedinternalusers.py | 8 ++- .../commands/tests/test_commands.py | 54 ++++++++++++++++++- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/api/staticdata/management/commands/seedinternalusers.py b/api/staticdata/management/commands/seedinternalusers.py index c59e21ee1c..b54e6b2702 100644 --- a/api/staticdata/management/commands/seedinternalusers.py +++ b/api/staticdata/management/commands/seedinternalusers.py @@ -3,6 +3,7 @@ from api.core.constants import Teams, Roles from api.conf.settings import env +from api.queues.constants import ALL_CASES_QUEUE_ID from api.staticdata.management.SeedCommand import SeedCommand from api.users.enums import UserType from api.users.models import Role, GovUser, BaseUser @@ -23,6 +24,8 @@ def operation(self, *args, **options): for admin_user in admin_users: email = admin_user["email"] + team_id = admin_user.get("team_id", Teams.ADMIN_TEAM_ID) + default_queue = admin_user.get("default_queue", ALL_CASES_QUEUE_ID) role = Role.objects.get( name=admin_user.get("role", Roles.INTERNAL_SUPER_USER_ROLE_NAME), type=UserType.INTERNAL @@ -31,14 +34,15 @@ def operation(self, *args, **options): email__iexact=email, defaults={"email": email}, type=UserType.INTERNAL ) admin_user, created = GovUser.objects.get_or_create( - baseuser_ptr=base_user, defaults={"team_id": Teams.ADMIN_TEAM_ID, "role": role} + baseuser_ptr=base_user, + defaults={"team_id": team_id, "role": role, "default_queue": default_queue}, ) if created or admin_user.role != role: admin_user.role = role admin_user.save() - admin_data = dict(email=email, team=Teams.ADMIN_TEAM_NAME, role=role.name) + admin_data = dict(email=email, team=admin_user.team.name, role=role.name) self.print_created_or_updated(GovUser, admin_data, is_created=created) if not GovUser.objects.count() >= 1: diff --git a/api/staticdata/management/commands/tests/test_commands.py b/api/staticdata/management/commands/tests/test_commands.py index ca85c4fff8..d6d861b289 100644 --- a/api/staticdata/management/commands/tests/test_commands.py +++ b/api/staticdata/management/commands/tests/test_commands.py @@ -1,15 +1,18 @@ +import json import os import pytest +from parameterized import parameterized from tempfile import NamedTemporaryFile from api.cases.enums import CaseTypeEnum from api.cases.models import CaseType -from api.core.constants import GovPermissions, ExporterPermissions +from api.core.constants import GovPermissions, ExporterPermissions, Teams from api.conf.settings import BASE_DIR from api.letter_templates.models import LetterTemplate from api.staticdata.countries.models import Country from api.cases.enums import AdviceType +from api.queues.constants import ALL_CASES_QUEUE_ID from api.staticdata.decisions.models import Decision from api.staticdata.denial_reasons.models import DenialReason from api.staticdata.letter_layouts.models import LetterLayout @@ -23,12 +26,22 @@ seedlettertemplates, seedrolepermissions, seedfinaldecisions, + seedinternalusers, ) from api.staticdata.statuses.models import CaseStatus, CaseStatusCaseType -from api.users.models import Permission +from api.teams.enums import TeamIdEnum +from api.users.enums import UserType +from api.users.models import GovUser, Permission +from api.users.tests.factories import RoleFactory class SeedingTests(SeedCommandTest): + def setUp(self) -> None: + super().setUp() + role_names = ["Super User", "Case officer", "Case adviser", "Manager", "Senior Manager"] + for name in role_names: + RoleFactory(name=name, type=UserType.INTERNAL) + @pytest.mark.seeding def test_seed_case_types(self): self.seed_command(seedcasetypes.Command) @@ -111,3 +124,40 @@ def test_seed_letter_templates(self): # running again with existing templates does nothing self.seed_command(seedlettertemplates.Command) self.assertEqual(LetterTemplate.objects.count(), 2) + + @pytest.mark.seeding + @parameterized.expand( + [ + ([{"email": "admin@example.co.uk", "role": "Super User"}],), + ([{"email": "manager@example.co.uk", "role": "Manager", "first_name": "LU", "last_name": "Manager"}],), + ( + [ + { + "email": "senior_manager@example.co.uk", + "role": "Senior Manager", + "team_id": TeamIdEnum.LICENSING_UNIT, + } + ], + ), + ( + [ + { + "email": "case_officer@example.co.uk", + "role": "Case officer", + "team_id": TeamIdEnum.LICENSING_UNIT, + "default_queue": "00000000-0000-0000-0000-000000000004", + } + ], + ), + ] + ) + def test_seed_internal_users(self, data): + os.environ["INTERNAL_USERS"] = json.dumps(data) + + self.seed_command(seedinternalusers.Command) + + for item in data: + user = GovUser.objects.get(baseuser_ptr__email=item["email"]) + self.assertEqual(user.role.name, item["role"]) + self.assertEqual(str(user.team_id), item.get("team_id", Teams.ADMIN_TEAM_ID)) + self.assertEqual(str(user.default_queue), item.get("default_queue", ALL_CASES_QUEUE_ID)) From 0ce91c0600ff25d7b024f1db06ca108f19f932c4 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Fri, 6 Sep 2024 14:48:22 +0100 Subject: [PATCH 5/6] Fix a race condition when multiple amendments are made This stops a race condition where an amendment request hits our API multiple times simultaneously --- .circleci/config.yml | 45 ++++++++++++++++++- api/applications/exceptions.py | 2 + api/applications/models.py | 29 ++++++++++++ api/applications/tests/test_models.py | 63 +++++++++++++++++++++++++-- pytest.ini | 1 + 5 files changed, 134 insertions(+), 6 deletions(-) create mode 100644 api/applications/exceptions.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 8a29ab99d2..67c9f96a11 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -134,7 +134,7 @@ jobs: - run: name: Run tests command: | - pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser" + pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser and not requires_transactions" - upload_code_coverage: alias: tests @@ -154,7 +154,7 @@ jobs: - run: name: Run tests on Postgres 13 command: | - pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser" + pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser and not requires_transactions" - upload_code_coverage: alias: tests_dbt_platform @@ -232,6 +232,44 @@ jobs: - upload_code_coverage: alias: anonymised_db_dump_tests_dbt_platform + requires_transactions_tests: + docker: + - <<: *image_python + - <<: *image_postgres + - <<: *image_opensearch_v1 + - <<: *image_redis + working_directory: ~/lite-api + environment: + <<: *common_env_vars + LITE_API_ENABLE_ES: True + steps: + - setup + - run: + name: Run requiring transactions tests + command: | + pipenv run pytest --cov=. --cov-report xml --cov-config=.coveragerc -k requires_transactions + - upload_code_coverage: + alias: requires_transactions_tests + + requires_transactions_tests_dbt_platform: + docker: + - <<: *image_python + - <<: *image_postgres13 + - <<: *image_opensearch + - <<: *image_redis + working_directory: ~/lite-api + environment: + <<: *common_env_vars + LITE_API_ENABLE_ES: True + steps: + - setup + - run: + name: Run requiring transactions tests on Postgres 13 + command: | + pipenv run pytest --cov=. --cov-report xml --cov-config=.coveragerc -k requires_transactions + - upload_code_coverage: + alias: requires_transactions_tests_dbt_platform + migration_tests: docker: - <<: *image_python @@ -613,7 +651,10 @@ workflows: - open_search_tests - migration_tests - lite_routing_tests + - requires_transactions_tests - check-lite-routing-sha - e2e_tests - anonymised_db_dump_tests - anonymised_db_dump_tests_dbt_platform + - requires_transactions_tests + - requires_transactions_tests_dbt_platform diff --git a/api/applications/exceptions.py b/api/applications/exceptions.py new file mode 100644 index 0000000000..218443f118 --- /dev/null +++ b/api/applications/exceptions.py @@ -0,0 +1,2 @@ +class AmendmentError(Exception): + pass diff --git a/api/applications/models.py b/api/applications/models.py index cf86da3a11..484a4e32ac 100644 --- a/api/applications/models.py +++ b/api/applications/models.py @@ -1,3 +1,4 @@ +import logging import uuid from django.contrib.postgres.fields import ArrayField @@ -14,6 +15,7 @@ NSGListType, ) from api.appeals.models import Appeal +from api.applications.exceptions import AmendmentError from api.applications.managers import BaseApplicationManager from api.applications.libraries.application_helpers import create_submitted_audit from api.audit_trail.models import AuditType @@ -49,6 +51,9 @@ from lite_routing.routing_rules_internal.enums import QueuesEnum +logger = logging.getLogger(__name__) + + class ApplicationException(APIException): def __init__(self, data): super().__init__(data) @@ -375,6 +380,30 @@ def clone(self, exclusions=None, **overrides): @transaction.atomic def create_amendment(self, user): + # It's possible that we've arrived here multiple times simultaneously because multiple requests to create an + # amendment have been made at the same time. + # The views that call this may have already checked the status but it's possible that when they checked it was + # before the status has actually been updated on the application so we've got a potential race condition going. + # What we want to do here is lock the row to make sure that only one request can be trying to create an + # amendment at any one time and if we have multiple requests clashing we'll re-check the status once the lock + # has been removed from any other racing requests. + # To be helpful to the caller we'll return the amendment that did happen even if there are multiple requests. + original = StandardApplication.objects.select_for_update().get(pk=self.pk) + if not CaseStatusEnum.can_invoke_major_edit(original.status.status): + logger.warning( + "Attempted to create an amendment from an already amended application %s with status %s", + original.pk, + original.status, + ) + if original.superseded_by: + logger.info( + "Found an amendment already: %s for the application: %s", + original.superseded_by.pk, + original.pk, + ) + return StandardApplication.objects.get(pk=original.superseded_by.pk) + raise AmendmentError(f"Failed to create an amendment from {original.pk}") + amendment_application = self.clone(amendment_of=self) CaseQueue.objects.filter(case=self.case_ptr).delete() audit_trail_service.create( diff --git a/api/applications/tests/test_models.py b/api/applications/tests/test_models.py index 6d9a75bd5e..ba797065cd 100644 --- a/api/applications/tests/test_models.py +++ b/api/applications/tests/test_models.py @@ -1,9 +1,16 @@ +import concurrent.futures +import pytest + from django.forms import model_to_dict +from django.test import TransactionTestCase from django.utils import timezone +from parameterized import parameterized + from test_helpers.clients import DataTestClient from api.appeals.tests.factories import AppealFactory +from api.applications.exceptions import AmendmentError from api.audit_trail.models import Audit from api.cases.models import CaseType, Queue from api.flags.models import Flag @@ -32,8 +39,12 @@ from api.staticdata.control_list_entries.models import ControlListEntry from api.staticdata.report_summaries.models import ReportSummary, ReportSummaryPrefix, ReportSummarySubject from api.staticdata.statuses.models import CaseStatus, CaseSubStatus -from api.staticdata.statuses.enums import CaseStatusEnum +from api.staticdata.statuses.enums import ( + CaseStatusEnum, +) +from api.users.enums import SystemUser from api.users.models import ExporterUser +from api.users.tests.factories import BaseUserFactory class TestBaseApplication(DataTestClient): @@ -79,9 +90,10 @@ def test_on_submit_amendment_application(self): class TestStandardApplication(DataTestClient): - def test_create_amendment(self): + @parameterized.expand(CaseStatusEnum.can_invoke_major_edit_statuses()) + def test_create_amendment(self, major_editable_status): original_application = StandardApplicationFactory( - status=CaseStatus.objects.get(status="ogd_advice"), + status=CaseStatus.objects.get(status=major_editable_status), ) original_application.queues.add(Queue.objects.first()) original_application.save() @@ -109,10 +121,22 @@ def test_create_amendment(self): assert amendment_audit_entry.actor == exporter_user status_change_audit_entry = audit_entries[0] assert status_change_audit_entry.payload == { - "status": {"new": "superseded_by_exporter_edit", "old": "ogd_advice"} + "status": {"new": "superseded_by_exporter_edit", "old": major_editable_status} } assert status_change_audit_entry.verb == "updated_status" + @parameterized.expand(CaseStatusEnum.can_not_invoke_major_edit_statuses()) + def test_create_amendment_failure(self, non_major_editable_status): + original_application = StandardApplicationFactory( + status=CaseStatus.objects.get(status=non_major_editable_status), + ) + original_application.queues.add(Queue.objects.first()) + original_application.save() + + exporter_user = ExporterUser.objects.first() + with self.assertRaises(AmendmentError): + amendment_application = original_application.create_amendment(exporter_user) + def test_clone(self): original_application = StandardApplicationFactory( activity="Trade", @@ -574,3 +598,34 @@ def test_clone_with_party_override(self): cloned by default or not and adjust PartyOnApplication.clone_* attributes accordingly. """ + + +@pytest.mark.requires_transactions +class TestStandardApplicationRaceConditions(TransactionTestCase): + def test_create_amendment_race_condition_success(self): + BaseUserFactory(id=SystemUser.id) + + original_application = StandardApplicationFactory() + + original_application = StandardApplication.objects.get(pk=original_application.pk) + same_application = StandardApplication.objects.get(pk=original_application.pk) + + exporter_user = ExporterUser.objects.first() + + def _create_amendment(application): + return application.create_amendment(exporter_user) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_1 = executor.submit(_create_amendment, original_application) + future_2 = executor.submit(_create_amendment, same_application) + + amendment_1 = future_1.result() + amendment_2 = future_2.result() + + self.assertEqual( + StandardApplication.objects.count(), + 2, + ) + self.assertEqual(amendment_1, amendment_2) + self.assertEqual(amendment_1.amendment_of.get_case(), original_application.get_case()) + self.assertEqual(amendment_2.amendment_of.get_case(), original_application.get_case()) diff --git a/pytest.ini b/pytest.ini index 76fc6e4346..4be4053f05 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,3 +9,4 @@ markers = elasticsearch: Tests that use elasticsearch seeding: tests that check seed commands performance: tests that check performance + requires_transactions: tests that require fine grained controls of transactions From 3c45680954e765fe606fc5981d63ca4b2518d376 Mon Sep 17 00:00:00 2001 From: "mark.j0hnst0n" Date: Thu, 5 Sep 2024 13:09:05 +0100 Subject: [PATCH 6/6] removed gov notifications and added system queue counts back in add comment back in remove gov notifications removing a bit more add system queue counts back in test update remove commented out code and add back in code which sends notification to other exporters remove unused variable update application_helpers to remove notification update tests remove updated queue tests add comment back in --- .../libraries/application_helpers.py | 1 - api/applications/views/goods.py | 1 - api/audit_trail/managers.py | 14 --- api/audit_trail/models.py | 5 +- api/audit_trail/service.py | 2 - api/cases/helpers.py | 13 +-- api/cases/libraries/delete_notifications.py | 6 +- api/cases/managers.py | 12 +-- api/cases/serializers.py | 13 +-- api/cases/tests/test_case_search.py | 91 ------------------- api/cases/views/search/activity.py | 5 - .../tests/test_gov_user_notifications.py | 77 ---------------- api/open_general_licences/views.py | 5 - api/queues/constants.py | 1 - api/queues/service.py | 2 - .../migrations/0007_delete_govnotification.py | 16 ++++ api/users/models.py | 18 ---- test_helpers/clients.py | 2 - 18 files changed, 21 insertions(+), 263 deletions(-) create mode 100644 api/users/migrations/0007_delete_govnotification.py diff --git a/api/applications/libraries/application_helpers.py b/api/applications/libraries/application_helpers.py index b6ff562965..398b148b1a 100644 --- a/api/applications/libraries/application_helpers.py +++ b/api/applications/libraries/application_helpers.py @@ -59,5 +59,4 @@ def create_submitted_audit(user, application, old_status: str, additional_payloa target=application.get_case(), payload=payload, ignore_case_status=True, - send_notification=False, ) diff --git a/api/applications/views/goods.py b/api/applications/views/goods.py index 936cc519da..e52199c2fb 100644 --- a/api/applications/views/goods.py +++ b/api/applications/views/goods.py @@ -363,7 +363,6 @@ def put(self, request, pk, good_on_application_pk): "good_name": good_on_application.good.name, }, ignore_case_status=True, - send_notification=False, ) return JsonResponse( diff --git a/api/audit_trail/managers.py b/api/audit_trail/managers.py index fc4f11c6ae..c0295a52f9 100644 --- a/api/audit_trail/managers.py +++ b/api/audit_trail/managers.py @@ -5,8 +5,6 @@ from api.cases.models import Case from api.staticdata.statuses.libraries.case_status_validate import is_case_status_draft -from api.users.models import ExporterUser -from api.users.models import GovUser class AuditQuerySet(GFKQuerySet): @@ -26,28 +24,16 @@ def create(self, *args, **kwargs): """ Create an audit entry for a model target: the target object (such as a case) - actor: the object causing the audit entry (such as a user) - send_notification: certain scenarios alert internal users, default is True ignore_case_status: draft cases become audited, default is False """ - # TODO: decouple notifications and audit (signals?) target = kwargs.get("target") - actor = kwargs.get("actor") - send_notification = kwargs.pop("send_notification", True) ignore_case_status = kwargs.pop("ignore_case_status", False) if isinstance(target, Case): # Only audit cases if their status is not draft if not is_case_status_draft(target.status.status) or ignore_case_status: audit = super(AuditManager, self).create(*args, **kwargs) - - # Notify gov users when an exporter updates a case - if isinstance(actor, ExporterUser) and send_notification: - for gov_user in GovUser.objects.all(): - gov_user.send_notification(content_object=audit, case=target) - return audit - return None return super(AuditManager, self).create(*args, **kwargs) diff --git a/api/audit_trail/models.py b/api/audit_trail/models.py index 2090e7720b..38a9142415 100644 --- a/api/audit_trail/models.py +++ b/api/audit_trail/models.py @@ -1,6 +1,6 @@ import uuid -from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.db import models from django.utils import timesince @@ -9,7 +9,6 @@ from api.audit_trail.managers import AuditManager from api.audit_trail.enums import AuditType from api.common.models import TimestampableModel -from api.users.models import GovNotification class Audit(TimestampableModel): @@ -51,8 +50,6 @@ class Audit(TimestampableModel): objects = AuditManager() - notifications = GenericRelation(GovNotification, related_query_name="audit") - class Meta: ordering = ("-created_at",) diff --git a/api/audit_trail/service.py b/api/audit_trail/service.py index 6b062dcc90..f0df6a9081 100644 --- a/api/audit_trail/service.py +++ b/api/audit_trail/service.py @@ -28,7 +28,6 @@ def create( target: Optional[Case] = None, payload=None, ignore_case_status: bool = False, - send_notification: bool = True, ) -> Optional[Audit]: if not payload: payload = {} @@ -50,7 +49,6 @@ def create( target=target, payload=payload, ignore_case_status=ignore_case_status, - send_notification=send_notification, ) diff --git a/api/cases/helpers.py b/api/cases/helpers.py index dd3854334e..fa98795b33 100644 --- a/api/cases/helpers.py +++ b/api/cases/helpers.py @@ -2,7 +2,7 @@ from api.audit_trail.enums import AuditType from api.common.dates import is_bank_holiday, is_weekend -from api.users.models import BaseUser, GovUser, GovNotification +from api.users.models import BaseUser, GovUser from api.users.enums import SystemUser @@ -22,17 +22,6 @@ def get_assigned_as_case_officer_case_ids(user: GovUser): return Case.objects.filter(case_officer=user).values_list("id", flat=True) -def get_updated_case_ids(user: GovUser): - """ - Get the cases that have raised notifications when updated by an exporter - """ - assigned_to_user_case_ids = get_assigned_to_user_case_ids(user) - assigned_as_case_officer_case_ids = get_assigned_as_case_officer_case_ids(user) - cases = assigned_to_user_case_ids.union(assigned_as_case_officer_case_ids) - - return GovNotification.objects.filter(user_id=user.pk, case__id__in=cases).values_list("case__id", flat=True) - - def working_days_in_range(start_date, end_date): dates_in_range = [start_date + timedelta(n) for n in range((end_date - start_date).days)] return len([date for date in dates_in_range if (not is_bank_holiday(date) and not is_weekend(date))]) diff --git a/api/cases/libraries/delete_notifications.py b/api/cases/libraries/delete_notifications.py index be78d14fa1..8fb8dd05f0 100644 --- a/api/cases/libraries/delete_notifications.py +++ b/api/cases/libraries/delete_notifications.py @@ -1,4 +1,4 @@ -from api.users.models import ExporterNotification, ExporterUser, GovNotification, GovUser +from api.users.models import ExporterNotification, ExporterUser def delete_exporter_notifications(user: ExporterUser, organisation_id, objects: list): @@ -6,7 +6,3 @@ def delete_exporter_notifications(user: ExporterUser, organisation_id, objects: ExporterNotification.objects.filter( user=user.baseuser_ptr, organisation_id=organisation_id, object_id__in=id_list ).delete() - - -def delete_gov_user_notifications(user: GovUser, id_list: list): - GovNotification.objects.filter(user=user.baseuser_ptr, object_id__in=id_list).delete() diff --git a/api/cases/managers.py b/api/cases/managers.py index 2e9b3bae34..406cd1b3e9 100644 --- a/api/cases/managers.py +++ b/api/cases/managers.py @@ -12,7 +12,7 @@ ) from api.cases.enums import AdviceLevel, CaseTypeEnum -from api.cases.helpers import get_updated_case_ids, get_assigned_to_user_case_ids, get_assigned_as_case_officer_case_ids +from api.cases.helpers import get_assigned_to_user_case_ids, get_assigned_as_case_officer_case_ids from api.common.enums import SortOrder from api.cases.enums import AdviceType from api.compliance.enums import COMPLIANCE_CASE_ACCEPTABLE_GOOD_CONTROL_CODES @@ -21,7 +21,6 @@ ALL_CASES_QUEUE_ID, MY_TEAMS_QUEUES_CASES_ID, OPEN_CASES_QUEUE_ID, - UPDATED_CASES_QUEUE_ID, MY_ASSIGNED_CASES_QUEUE_ID, MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID, ) @@ -52,13 +51,6 @@ def in_queue(self, queue_id): def in_team(self, team_id): return self.filter(queues__team_id=team_id).distinct() - def is_updated(self, user): - """ - Get the cases that have raised notifications when updated by an exporter - """ - updated_case_ids = get_updated_case_ids(user) - return self.filter(id__in=updated_case_ids) - def assigned_to_user(self, user, queue_id=None): assigned_to_user_case_ids = get_assigned_to_user_case_ids(user, queue_id) return self.filter(id__in=assigned_to_user_case_ids) @@ -223,8 +215,6 @@ def filter_based_on_queue(self, queue_id, team_id, user): return self.in_team(team_id=team_id) elif queue_id == OPEN_CASES_QUEUE_ID: return self.is_open() - elif queue_id == UPDATED_CASES_QUEUE_ID: - return self.is_updated(user=user) elif queue_id == MY_ASSIGNED_CASES_QUEUE_ID: return self.assigned_to_user(user=user).not_terminal() elif queue_id == MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID: diff --git a/api/cases/serializers.py b/api/cases/serializers.py index 5fed53b11c..af87bcb402 100644 --- a/api/cases/serializers.py +++ b/api/cases/serializers.py @@ -1,4 +1,3 @@ -from django.contrib.contenttypes.models import ContentType from rest_framework import serializers from api.applications.libraries.get_applications import get_application @@ -6,7 +5,6 @@ from api.applications.serializers.advice import AdviceViewSerializer, CountersignDecisionAdviceViewSerializer from api.staticdata.statuses.serializers import CaseSubStatusSerializer -from api.audit_trail.models import Audit from api.cases.enums import ( CaseTypeTypeEnum, AdviceType, @@ -49,7 +47,7 @@ from api.staticdata.statuses.enums import CaseStatusEnum from api.teams.serializers import TeamSerializer from api.users.enums import UserStatuses -from api.users.models import BaseUser, GovUser, GovNotification, ExporterUser +from api.users.models import BaseUser, GovUser, ExporterUser from api.users.serializers import ( BaseUserViewSerializer, GovUserViewSerializer, @@ -268,7 +266,6 @@ class CaseDetailSerializer(serializers.ModelSerializer): all_flags = serializers.SerializerMethodField() case_officer = GovUserSimpleSerializer(read_only=True) copy_of = serializers.SerializerMethodField() - audit_notification = serializers.SerializerMethodField() sla_days = serializers.IntegerField() sla_remaining_days = serializers.IntegerField() advice = AdviceViewSerializer(many=True) @@ -295,7 +292,6 @@ class Meta: "countersign_advice", "all_flags", "case_officer", - "audit_notification", "reference_code", "copy_of", "sla_days", @@ -380,13 +376,6 @@ def get_all_flags(self, instance): """ return get_ordered_flags(instance, self.team, distinct=True) - def get_audit_notification(self, instance): - content_type = ContentType.objects.get_for_model(Audit) - queryset = GovNotification.objects.filter(user_id=self.user.pk, content_type=content_type, case=instance) - - if queryset.exists(): - return {"audit_id": queryset.first().object_id} - def get_copy_of(self, instance): if instance.copy_of and instance.copy_of.status.status != CaseStatusEnum.DRAFT: return CaseCopyOfSerializer(instance.copy_of).data diff --git a/api/cases/tests/test_case_search.py b/api/cases/tests/test_case_search.py index b0596915b7..6b20508dc8 100644 --- a/api/cases/tests/test_case_search.py +++ b/api/cases/tests/test_case_search.py @@ -5,8 +5,6 @@ from parameterized import parameterized from rest_framework import status -from api.audit_trail.models import Audit -from api.audit_trail.enums import AuditType from api.applications.tests.factories import ( DenialMatchOnApplicationFactory, DenialEntityFactory, @@ -22,7 +20,6 @@ from api.picklists.enums import PicklistType from api.cases.tests.factories import FinalAdviceFactory from api.queues.constants import ( - UPDATED_CASES_QUEUE_ID, SYSTEM_QUEUES, ALL_CASES_QUEUE_ID, ) @@ -33,10 +30,8 @@ from api.staticdata.report_summaries.tests.factories import ReportSummaryPrefixFactory, ReportSummarySubjectFactory from api.staticdata.statuses.libraries.get_case_status import get_case_status_by_status from api.staticdata.statuses.models import CaseSubStatus -from api.users.tests.factories import GovUserFactory from test_helpers.clients import DataTestClient from api.users.enums import UserStatuses -from api.users.libraries.user_to_token import user_to_token from api.users.models import GovUser from api.cases.tests import factories from api.cases.enums import AdviceType @@ -821,92 +816,6 @@ def test_get_cases_filter_by_includes_refusal_recommendation_not_met(self): self.assertEqual(len(response_data), 0) -class UpdatedCasesQueueTests(DataTestClient): - def setUp(self): - super().setUp() - - self.case = self.create_standard_application_case(self.organisation).get_case() - self.old_status = self.case.status.status - self.case.queues.set([self.queue]) - self.case_assignment = CaseAssignment.objects.create(case=self.case, queue=self.queue, user=self.gov_user) - self.case.status = get_case_status_by_status(CaseStatusEnum.APPLICANT_EDITING) - self.case.save() - - self.audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_STATUS, - target=self.case, - payload={"status": {"new": CaseStatusEnum.APPLICANT_EDITING, "old": self.old_status}}, - ) - self.gov_user.send_notification(content_object=self.audit, case=self.case) - - self.url = f'{reverse("cases:search")}?queue_id={UPDATED_CASES_QUEUE_ID}' - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_to_a_case_returns_expected_cases(self): - # Create another case that does not have an update - case = self.create_standard_application_case(self.organisation).get_case() - case.queues.set([self.queue]) - case_assignment = CaseAssignment.objects.create(case=case, queue=self.queue, user=self.gov_user) - self.gov_user.send_notification(content_object=self.audit, case=case) - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 2) # Count is 2 as another case is created in setup - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_non_team_queue(self): - other_team = self.create_team("other_team") - self.gov_user.team = other_team - - response = self.client.get(self.url, **self.gov_headers) - response_data = response.json()["results"]["cases"] - - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_when_user_is_not_assigned_to_a_case_returns_no_cases(self): - other_user = GovUserFactory( - baseuser_ptr__email="test2@mail.com", - baseuser_ptr__first_name="John", - baseuser_ptr__last_name="Smith", - team=self.team, - ) - gov_headers = {"HTTP_GOV_USER_TOKEN": user_to_token(other_user.baseuser_ptr)} - - response = self.client.get(self.url, **gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 0) - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_as_case_officer_returns_expected_cases(self): - CaseAssignment.objects.filter(case=self.case, queue=self.queue).delete() - self.case.case_officer = self.gov_user - self.case.save() - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_to_case_and_as_case_officer_returns_expected_cases( - self, - ): - self.case.case_officer = self.gov_user - self.case.save() - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - class CaseOrderingOnQueueTests(DataTestClient): def test_all_cases_queue_returns_cases_in_expected_order(self): """Test All cases queue returns cases in expected order (newest first).""" diff --git a/api/cases/views/search/activity.py b/api/cases/views/search/activity.py index 4f179eaae1..8a3e9bb7e3 100644 --- a/api/cases/views/search/activity.py +++ b/api/cases/views/search/activity.py @@ -5,10 +5,8 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.serializers import AuditSerializer -from api.cases.libraries.delete_notifications import delete_gov_user_notifications from api.cases.models import Case from api.core.authentication import GovAuthentication -from api.users.models import GovUser class CaseActivityView(APIView): @@ -23,9 +21,6 @@ def get(self, request, pk): data = AuditSerializer(audit_trail_qs, many=True).data - # Delete notifications related to audits - if isinstance(request.user, GovUser): - delete_gov_user_notifications(request.user, [obj["id"] for obj in data]) return JsonResponse(data={"activity": data}, status=status.HTTP_200_OK) diff --git a/api/gov_users/tests/test_gov_user_notifications.py b/api/gov_users/tests/test_gov_user_notifications.py index bd9ed17dc0..6b46796302 100644 --- a/api/gov_users/tests/test_gov_user_notifications.py +++ b/api/gov_users/tests/test_gov_user_notifications.py @@ -3,10 +3,7 @@ from rest_framework import status from api.audit_trail.models import Audit -from api.audit_trail.enums import AuditType from api.staticdata.statuses.enums import CaseStatusEnum -from api.staticdata.statuses.libraries.get_case_status import get_case_status_by_status -from api.users.models import GovNotification from test_helpers.clients import DataTestClient @@ -23,79 +20,5 @@ def test_edit_application_creates_new_audit_notification_success(self): response = self.client.put(url, data, **self.exporter_headers) self.case.refresh_from_db() - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() self.assertEqual(status.HTTP_200_OK, response.status_code) - # There can only be one notification per gov user's case - # (the notification for updating the name overwrites any prior notification) - self.assertEqual(case_audit_notification_count, 1) - - def test_edit_application_updates_previous_audit_notification_success(self): - audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_APPLICATION_NAME, - target=self.case.get_case(), - payload={"old_name": "old_app_name", "new_name": "new_app_name"}, - ) - - self.gov_user.send_notification(content_object=audit, case=self.case) - prev_case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - url = reverse("applications:application", kwargs={"pk": self.case.id}) - data = {"name": "even newer app name!"} - - response = self.client.put(url, data, **self.exporter_headers) - self.case.refresh_from_db() - case_audit_notification = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ) - - self.assertEqual(status.HTTP_200_OK, response.status_code) - self.assertEqual(case_audit_notification.count(), prev_case_audit_notification_count) - self.assertEqual(data["name"], case_audit_notification.last().content_object.payload["new_name"]) - self.assertNotEqual(case_audit_notification.last().content_object, audit) - - def test_edit_application_as_gov_user_does_not_create_an_audit_notification_success(self): - prev_case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - url = reverse("caseworker_applications:change_status", kwargs={"pk": self.case.id}) - data = {"status": "under_review"} - - response = self.client.post(url, data, **self.gov_headers) - self.case.refresh_from_db() - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - self.assertEqual(status.HTTP_200_OK, response.status_code) - self.assertEqual(case_audit_notification_count, prev_case_audit_notification_count) - - def test_get_case_activity_deletes_audit_notification_success(self): - self.case = self.create_standard_application_case(self.organisation).get_case() - old_status = self.case.status.status - self.case.status = get_case_status_by_status(CaseStatusEnum.APPLICANT_EDITING) - self.case.save() - audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_STATUS, - target=self.case, - payload={"status": {"new": CaseStatusEnum.APPLICANT_EDITING, "old": old_status}}, - ) - - self.gov_user.send_notification(content_object=audit, case=self.case) - url = reverse("cases:activity", kwargs={"pk": self.case.id}) - - response = self.client.get(url, **self.gov_headers) - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - self.assertEqual(response.status_code, status.HTTP_200_OK) - case_activity = response.json()["activity"] - self.assertEqual(len(case_activity), 2) - self.assertEqual(case_audit_notification_count, 1) diff --git a/api/open_general_licences/views.py b/api/open_general_licences/views.py index 29c9fc5827..c5317ae896 100644 --- a/api/open_general_licences/views.py +++ b/api/open_general_licences/views.py @@ -19,7 +19,6 @@ from api.organisations.models import Site from api.staticdata.statuses.enums import CaseStatusEnum from api.users.enums import UserType -from api.users.models import GovUser, GovNotification class OpenGeneralLicenceList(ListCreateAPIView): @@ -191,10 +190,6 @@ def get(self, request, pk): data = AuditSerializer(audit_trail_qs, many=True).data - if isinstance(request.user, GovUser): - # Delete notifications related to audits - GovNotification.objects.filter(user_id=request.user.pk, object_id__in=[obj["id"] for obj in data]).delete() - filters = audit_trail_service.get_objects_activity_filters(pk, content_type) return JsonResponse(data={"activity": data, "filters": filters}, status=status.HTTP_200_OK) diff --git a/api/queues/constants.py b/api/queues/constants.py index 5f677d4705..7d75e3afd9 100644 --- a/api/queues/constants.py +++ b/api/queues/constants.py @@ -23,7 +23,6 @@ ALL_CASES_QUEUE_ID: ALL_CASES_QUEUE_NAME, OPEN_CASES_QUEUE_ID: OPEN_CASES_QUEUE_NAME, MY_TEAMS_QUEUES_CASES_ID: MY_TEAMS_QUEUES_CASES_NAME, - UPDATED_CASES_QUEUE_ID: UPDATED_CASES_QUEUE_NAME, } NON_WORK_QUEUES = { diff --git a/api/queues/service.py b/api/queues/service.py index ec2b28719d..d39394601f 100644 --- a/api/queues/service.py +++ b/api/queues/service.py @@ -10,7 +10,6 @@ MY_TEAMS_QUEUES_CASES_ID, MY_ASSIGNED_CASES_QUEUE_ID, MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID, - UPDATED_CASES_QUEUE_ID, SYSTEM_QUEUES, ) from api.queues.models import Queue @@ -85,7 +84,6 @@ def _get_system_queues_case_count(user) -> Dict: MY_TEAMS_QUEUES_CASES_ID: case_qs.in_team(team_id=user.team.id).count(), MY_ASSIGNED_CASES_QUEUE_ID: case_qs.assigned_to_user(user=user).not_terminal().count(), MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID: case_qs.assigned_as_case_officer(user=user).not_terminal().count(), - UPDATED_CASES_QUEUE_ID: case_qs.is_updated(user=user).count(), } return cases_count diff --git a/api/users/migrations/0007_delete_govnotification.py b/api/users/migrations/0007_delete_govnotification.py new file mode 100644 index 0000000000..3f55716f4a --- /dev/null +++ b/api/users/migrations/0007_delete_govnotification.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.14 on 2024-09-06 10:22 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0006_alter_userorganisationrelationship_user"), + ] + + operations = [ + migrations.DeleteModel( + name="GovNotification", + ), + ] diff --git a/api/users/models.py b/api/users/models.py index 84e406f459..2d714a3497 100644 --- a/api/users/models.py +++ b/api/users/models.py @@ -115,10 +115,6 @@ class ExporterNotification(BaseNotification): organisation = models.ForeignKey("organisations.Organisation", on_delete=models.CASCADE, null=False) -class GovNotification(BaseNotification): - pass - - class BaseUserCompatMixin: baseuser_ptr: BaseUser @@ -212,20 +208,6 @@ def unassign_from_cases(self): """ self.case_assignments.filter(user=self).delete() - def send_notification(self, content_object, case): - from api.audit_trail.models import Audit - - if isinstance(content_object, Audit): - # There can only be one notification per gov user's case - # If a notification for that gov user's case already exists, update the case activity it points to - try: - content_type = ContentType.objects.get_for_model(Audit) - notification = GovNotification.objects.get(user=self.baseuser_ptr, content_type=content_type, case=case) - notification.content_object = content_object - notification.save() - except GovNotification.DoesNotExist: - GovNotification.objects.create(user=self.baseuser_ptr, content_object=content_object, case=case) - def has_permission(self, permission): user_permissions = self.role.permissions.values_list("id", flat=True) return permission.name in user_permissions diff --git a/test_helpers/clients.py b/test_helpers/clients.py index 668724eba8..933f1f4cb2 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -886,7 +886,6 @@ def create_audit( target=None, payload=None, ignore_case_status=False, - send_notification=True, ): if not payload: payload = {} @@ -898,7 +897,6 @@ def create_audit( target=target, payload=payload, ignore_case_status=ignore_case_status, - send_notification=send_notification, ) def add_users(self, count=3):