diff --git a/.circleci/config.yml b/.circleci/config.yml index aefa286960..a2238499a9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -134,7 +134,7 @@ jobs: - run: name: Run tests command: | - pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser" + pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser and not requires_transactions" - upload_code_coverage: alias: tests @@ -154,7 +154,7 @@ jobs: - run: name: Run tests on Postgres 13 command: | - pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser" + pipenv run pytest --circleci-parallelize --cov=. --cov-report xml --cov-config=.coveragerc --ignore lite_routing --ignore api/anonymised_db_dumps -k "not seeding and not elasticsearch and not performance and not migration and not db_anonymiser and not requires_transactions" - upload_code_coverage: alias: tests_dbt_platform @@ -232,6 +232,44 @@ jobs: - upload_code_coverage: alias: anonymised_db_dump_tests_dbt_platform + requires_transactions_tests: + docker: + - <<: *image_python + - <<: *image_postgres + - <<: *image_opensearch_v1 + - <<: *image_redis + working_directory: ~/lite-api + environment: + <<: *common_env_vars + LITE_API_ENABLE_ES: True + steps: + - setup + - run: + name: Run requiring transactions tests + command: | + pipenv run pytest --cov=. --cov-report xml --cov-config=.coveragerc -k requires_transactions + - upload_code_coverage: + alias: requires_transactions_tests + + requires_transactions_tests_dbt_platform: + docker: + - <<: *image_python + - <<: *image_postgres13 + - <<: *image_opensearch + - <<: *image_redis + working_directory: ~/lite-api + environment: + <<: *common_env_vars + LITE_API_ENABLE_ES: True + steps: + - setup + - run: + name: Run requiring transactions tests on Postgres 13 + command: | + pipenv run pytest --cov=. --cov-report xml --cov-config=.coveragerc -k requires_transactions + - upload_code_coverage: + alias: requires_transactions_tests_dbt_platform + migration_tests: docker: - <<: *image_python @@ -615,7 +653,10 @@ workflows: - open_search_tests - migration_tests - lite_routing_tests + - requires_transactions_tests - check-lite-routing-sha - e2e_tests - anonymised_db_dump_tests - anonymised_db_dump_tests_dbt_platform + - requires_transactions_tests + - requires_transactions_tests_dbt_platform diff --git a/api/applications/exceptions.py b/api/applications/exceptions.py new file mode 100644 index 0000000000..218443f118 --- /dev/null +++ b/api/applications/exceptions.py @@ -0,0 +1,2 @@ +class AmendmentError(Exception): + pass diff --git a/api/applications/libraries/application_helpers.py b/api/applications/libraries/application_helpers.py index b6ff562965..398b148b1a 100644 --- a/api/applications/libraries/application_helpers.py +++ b/api/applications/libraries/application_helpers.py @@ -59,5 +59,4 @@ def create_submitted_audit(user, application, old_status: str, additional_payloa target=application.get_case(), payload=payload, ignore_case_status=True, - send_notification=False, ) diff --git a/api/applications/models.py b/api/applications/models.py index cf86da3a11..484a4e32ac 100644 --- a/api/applications/models.py +++ b/api/applications/models.py @@ -1,3 +1,4 @@ +import logging import uuid from django.contrib.postgres.fields import ArrayField @@ -14,6 +15,7 @@ NSGListType, ) from api.appeals.models import Appeal +from api.applications.exceptions import AmendmentError from api.applications.managers import BaseApplicationManager from api.applications.libraries.application_helpers import create_submitted_audit from api.audit_trail.models import AuditType @@ -49,6 +51,9 @@ from lite_routing.routing_rules_internal.enums import QueuesEnum +logger = logging.getLogger(__name__) + + class ApplicationException(APIException): def __init__(self, data): super().__init__(data) @@ -375,6 +380,30 @@ def clone(self, exclusions=None, **overrides): @transaction.atomic def create_amendment(self, user): + # It's possible that we've arrived here multiple times simultaneously because multiple requests to create an + # amendment have been made at the same time. + # The views that call this may have already checked the status but it's possible that when they checked it was + # before the status has actually been updated on the application so we've got a potential race condition going. + # What we want to do here is lock the row to make sure that only one request can be trying to create an + # amendment at any one time and if we have multiple requests clashing we'll re-check the status once the lock + # has been removed from any other racing requests. + # To be helpful to the caller we'll return the amendment that did happen even if there are multiple requests. + original = StandardApplication.objects.select_for_update().get(pk=self.pk) + if not CaseStatusEnum.can_invoke_major_edit(original.status.status): + logger.warning( + "Attempted to create an amendment from an already amended application %s with status %s", + original.pk, + original.status, + ) + if original.superseded_by: + logger.info( + "Found an amendment already: %s for the application: %s", + original.superseded_by.pk, + original.pk, + ) + return StandardApplication.objects.get(pk=original.superseded_by.pk) + raise AmendmentError(f"Failed to create an amendment from {original.pk}") + amendment_application = self.clone(amendment_of=self) CaseQueue.objects.filter(case=self.case_ptr).delete() audit_trail_service.create( diff --git a/api/applications/tests/test_models.py b/api/applications/tests/test_models.py index 6d9a75bd5e..ba797065cd 100644 --- a/api/applications/tests/test_models.py +++ b/api/applications/tests/test_models.py @@ -1,9 +1,16 @@ +import concurrent.futures +import pytest + from django.forms import model_to_dict +from django.test import TransactionTestCase from django.utils import timezone +from parameterized import parameterized + from test_helpers.clients import DataTestClient from api.appeals.tests.factories import AppealFactory +from api.applications.exceptions import AmendmentError from api.audit_trail.models import Audit from api.cases.models import CaseType, Queue from api.flags.models import Flag @@ -32,8 +39,12 @@ from api.staticdata.control_list_entries.models import ControlListEntry from api.staticdata.report_summaries.models import ReportSummary, ReportSummaryPrefix, ReportSummarySubject from api.staticdata.statuses.models import CaseStatus, CaseSubStatus -from api.staticdata.statuses.enums import CaseStatusEnum +from api.staticdata.statuses.enums import ( + CaseStatusEnum, +) +from api.users.enums import SystemUser from api.users.models import ExporterUser +from api.users.tests.factories import BaseUserFactory class TestBaseApplication(DataTestClient): @@ -79,9 +90,10 @@ def test_on_submit_amendment_application(self): class TestStandardApplication(DataTestClient): - def test_create_amendment(self): + @parameterized.expand(CaseStatusEnum.can_invoke_major_edit_statuses()) + def test_create_amendment(self, major_editable_status): original_application = StandardApplicationFactory( - status=CaseStatus.objects.get(status="ogd_advice"), + status=CaseStatus.objects.get(status=major_editable_status), ) original_application.queues.add(Queue.objects.first()) original_application.save() @@ -109,10 +121,22 @@ def test_create_amendment(self): assert amendment_audit_entry.actor == exporter_user status_change_audit_entry = audit_entries[0] assert status_change_audit_entry.payload == { - "status": {"new": "superseded_by_exporter_edit", "old": "ogd_advice"} + "status": {"new": "superseded_by_exporter_edit", "old": major_editable_status} } assert status_change_audit_entry.verb == "updated_status" + @parameterized.expand(CaseStatusEnum.can_not_invoke_major_edit_statuses()) + def test_create_amendment_failure(self, non_major_editable_status): + original_application = StandardApplicationFactory( + status=CaseStatus.objects.get(status=non_major_editable_status), + ) + original_application.queues.add(Queue.objects.first()) + original_application.save() + + exporter_user = ExporterUser.objects.first() + with self.assertRaises(AmendmentError): + amendment_application = original_application.create_amendment(exporter_user) + def test_clone(self): original_application = StandardApplicationFactory( activity="Trade", @@ -574,3 +598,34 @@ def test_clone_with_party_override(self): cloned by default or not and adjust PartyOnApplication.clone_* attributes accordingly. """ + + +@pytest.mark.requires_transactions +class TestStandardApplicationRaceConditions(TransactionTestCase): + def test_create_amendment_race_condition_success(self): + BaseUserFactory(id=SystemUser.id) + + original_application = StandardApplicationFactory() + + original_application = StandardApplication.objects.get(pk=original_application.pk) + same_application = StandardApplication.objects.get(pk=original_application.pk) + + exporter_user = ExporterUser.objects.first() + + def _create_amendment(application): + return application.create_amendment(exporter_user) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_1 = executor.submit(_create_amendment, original_application) + future_2 = executor.submit(_create_amendment, same_application) + + amendment_1 = future_1.result() + amendment_2 = future_2.result() + + self.assertEqual( + StandardApplication.objects.count(), + 2, + ) + self.assertEqual(amendment_1, amendment_2) + self.assertEqual(amendment_1.amendment_of.get_case(), original_application.get_case()) + self.assertEqual(amendment_2.amendment_of.get_case(), original_application.get_case()) diff --git a/api/applications/views/goods.py b/api/applications/views/goods.py index 936cc519da..e52199c2fb 100644 --- a/api/applications/views/goods.py +++ b/api/applications/views/goods.py @@ -363,7 +363,6 @@ def put(self, request, pk, good_on_application_pk): "good_name": good_on_application.good.name, }, ignore_case_status=True, - send_notification=False, ) return JsonResponse( diff --git a/api/audit_trail/managers.py b/api/audit_trail/managers.py index fc4f11c6ae..c0295a52f9 100644 --- a/api/audit_trail/managers.py +++ b/api/audit_trail/managers.py @@ -5,8 +5,6 @@ from api.cases.models import Case from api.staticdata.statuses.libraries.case_status_validate import is_case_status_draft -from api.users.models import ExporterUser -from api.users.models import GovUser class AuditQuerySet(GFKQuerySet): @@ -26,28 +24,16 @@ def create(self, *args, **kwargs): """ Create an audit entry for a model target: the target object (such as a case) - actor: the object causing the audit entry (such as a user) - send_notification: certain scenarios alert internal users, default is True ignore_case_status: draft cases become audited, default is False """ - # TODO: decouple notifications and audit (signals?) target = kwargs.get("target") - actor = kwargs.get("actor") - send_notification = kwargs.pop("send_notification", True) ignore_case_status = kwargs.pop("ignore_case_status", False) if isinstance(target, Case): # Only audit cases if their status is not draft if not is_case_status_draft(target.status.status) or ignore_case_status: audit = super(AuditManager, self).create(*args, **kwargs) - - # Notify gov users when an exporter updates a case - if isinstance(actor, ExporterUser) and send_notification: - for gov_user in GovUser.objects.all(): - gov_user.send_notification(content_object=audit, case=target) - return audit - return None return super(AuditManager, self).create(*args, **kwargs) diff --git a/api/audit_trail/models.py b/api/audit_trail/models.py index 2090e7720b..38a9142415 100644 --- a/api/audit_trail/models.py +++ b/api/audit_trail/models.py @@ -1,6 +1,6 @@ import uuid -from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation +from django.contrib.contenttypes.fields import GenericForeignKey from django.contrib.contenttypes.models import ContentType from django.db import models from django.utils import timesince @@ -9,7 +9,6 @@ from api.audit_trail.managers import AuditManager from api.audit_trail.enums import AuditType from api.common.models import TimestampableModel -from api.users.models import GovNotification class Audit(TimestampableModel): @@ -51,8 +50,6 @@ class Audit(TimestampableModel): objects = AuditManager() - notifications = GenericRelation(GovNotification, related_query_name="audit") - class Meta: ordering = ("-created_at",) diff --git a/api/audit_trail/service.py b/api/audit_trail/service.py index 6b062dcc90..f0df6a9081 100644 --- a/api/audit_trail/service.py +++ b/api/audit_trail/service.py @@ -28,7 +28,6 @@ def create( target: Optional[Case] = None, payload=None, ignore_case_status: bool = False, - send_notification: bool = True, ) -> Optional[Audit]: if not payload: payload = {} @@ -50,7 +49,6 @@ def create( target=target, payload=payload, ignore_case_status=ignore_case_status, - send_notification=send_notification, ) diff --git a/api/cases/helpers.py b/api/cases/helpers.py index dd3854334e..fa98795b33 100644 --- a/api/cases/helpers.py +++ b/api/cases/helpers.py @@ -2,7 +2,7 @@ from api.audit_trail.enums import AuditType from api.common.dates import is_bank_holiday, is_weekend -from api.users.models import BaseUser, GovUser, GovNotification +from api.users.models import BaseUser, GovUser from api.users.enums import SystemUser @@ -22,17 +22,6 @@ def get_assigned_as_case_officer_case_ids(user: GovUser): return Case.objects.filter(case_officer=user).values_list("id", flat=True) -def get_updated_case_ids(user: GovUser): - """ - Get the cases that have raised notifications when updated by an exporter - """ - assigned_to_user_case_ids = get_assigned_to_user_case_ids(user) - assigned_as_case_officer_case_ids = get_assigned_as_case_officer_case_ids(user) - cases = assigned_to_user_case_ids.union(assigned_as_case_officer_case_ids) - - return GovNotification.objects.filter(user_id=user.pk, case__id__in=cases).values_list("case__id", flat=True) - - def working_days_in_range(start_date, end_date): dates_in_range = [start_date + timedelta(n) for n in range((end_date - start_date).days)] return len([date for date in dates_in_range if (not is_bank_holiday(date) and not is_weekend(date))]) diff --git a/api/cases/libraries/delete_notifications.py b/api/cases/libraries/delete_notifications.py index be78d14fa1..8fb8dd05f0 100644 --- a/api/cases/libraries/delete_notifications.py +++ b/api/cases/libraries/delete_notifications.py @@ -1,4 +1,4 @@ -from api.users.models import ExporterNotification, ExporterUser, GovNotification, GovUser +from api.users.models import ExporterNotification, ExporterUser def delete_exporter_notifications(user: ExporterUser, organisation_id, objects: list): @@ -6,7 +6,3 @@ def delete_exporter_notifications(user: ExporterUser, organisation_id, objects: ExporterNotification.objects.filter( user=user.baseuser_ptr, organisation_id=organisation_id, object_id__in=id_list ).delete() - - -def delete_gov_user_notifications(user: GovUser, id_list: list): - GovNotification.objects.filter(user=user.baseuser_ptr, object_id__in=id_list).delete() diff --git a/api/cases/managers.py b/api/cases/managers.py index 2e9b3bae34..406cd1b3e9 100644 --- a/api/cases/managers.py +++ b/api/cases/managers.py @@ -12,7 +12,7 @@ ) from api.cases.enums import AdviceLevel, CaseTypeEnum -from api.cases.helpers import get_updated_case_ids, get_assigned_to_user_case_ids, get_assigned_as_case_officer_case_ids +from api.cases.helpers import get_assigned_to_user_case_ids, get_assigned_as_case_officer_case_ids from api.common.enums import SortOrder from api.cases.enums import AdviceType from api.compliance.enums import COMPLIANCE_CASE_ACCEPTABLE_GOOD_CONTROL_CODES @@ -21,7 +21,6 @@ ALL_CASES_QUEUE_ID, MY_TEAMS_QUEUES_CASES_ID, OPEN_CASES_QUEUE_ID, - UPDATED_CASES_QUEUE_ID, MY_ASSIGNED_CASES_QUEUE_ID, MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID, ) @@ -52,13 +51,6 @@ def in_queue(self, queue_id): def in_team(self, team_id): return self.filter(queues__team_id=team_id).distinct() - def is_updated(self, user): - """ - Get the cases that have raised notifications when updated by an exporter - """ - updated_case_ids = get_updated_case_ids(user) - return self.filter(id__in=updated_case_ids) - def assigned_to_user(self, user, queue_id=None): assigned_to_user_case_ids = get_assigned_to_user_case_ids(user, queue_id) return self.filter(id__in=assigned_to_user_case_ids) @@ -223,8 +215,6 @@ def filter_based_on_queue(self, queue_id, team_id, user): return self.in_team(team_id=team_id) elif queue_id == OPEN_CASES_QUEUE_ID: return self.is_open() - elif queue_id == UPDATED_CASES_QUEUE_ID: - return self.is_updated(user=user) elif queue_id == MY_ASSIGNED_CASES_QUEUE_ID: return self.assigned_to_user(user=user).not_terminal() elif queue_id == MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID: diff --git a/api/cases/serializers.py b/api/cases/serializers.py index 5fed53b11c..af87bcb402 100644 --- a/api/cases/serializers.py +++ b/api/cases/serializers.py @@ -1,4 +1,3 @@ -from django.contrib.contenttypes.models import ContentType from rest_framework import serializers from api.applications.libraries.get_applications import get_application @@ -6,7 +5,6 @@ from api.applications.serializers.advice import AdviceViewSerializer, CountersignDecisionAdviceViewSerializer from api.staticdata.statuses.serializers import CaseSubStatusSerializer -from api.audit_trail.models import Audit from api.cases.enums import ( CaseTypeTypeEnum, AdviceType, @@ -49,7 +47,7 @@ from api.staticdata.statuses.enums import CaseStatusEnum from api.teams.serializers import TeamSerializer from api.users.enums import UserStatuses -from api.users.models import BaseUser, GovUser, GovNotification, ExporterUser +from api.users.models import BaseUser, GovUser, ExporterUser from api.users.serializers import ( BaseUserViewSerializer, GovUserViewSerializer, @@ -268,7 +266,6 @@ class CaseDetailSerializer(serializers.ModelSerializer): all_flags = serializers.SerializerMethodField() case_officer = GovUserSimpleSerializer(read_only=True) copy_of = serializers.SerializerMethodField() - audit_notification = serializers.SerializerMethodField() sla_days = serializers.IntegerField() sla_remaining_days = serializers.IntegerField() advice = AdviceViewSerializer(many=True) @@ -295,7 +292,6 @@ class Meta: "countersign_advice", "all_flags", "case_officer", - "audit_notification", "reference_code", "copy_of", "sla_days", @@ -380,13 +376,6 @@ def get_all_flags(self, instance): """ return get_ordered_flags(instance, self.team, distinct=True) - def get_audit_notification(self, instance): - content_type = ContentType.objects.get_for_model(Audit) - queryset = GovNotification.objects.filter(user_id=self.user.pk, content_type=content_type, case=instance) - - if queryset.exists(): - return {"audit_id": queryset.first().object_id} - def get_copy_of(self, instance): if instance.copy_of and instance.copy_of.status.status != CaseStatusEnum.DRAFT: return CaseCopyOfSerializer(instance.copy_of).data diff --git a/api/cases/tests/test_case_search.py b/api/cases/tests/test_case_search.py index b0596915b7..6b20508dc8 100644 --- a/api/cases/tests/test_case_search.py +++ b/api/cases/tests/test_case_search.py @@ -5,8 +5,6 @@ from parameterized import parameterized from rest_framework import status -from api.audit_trail.models import Audit -from api.audit_trail.enums import AuditType from api.applications.tests.factories import ( DenialMatchOnApplicationFactory, DenialEntityFactory, @@ -22,7 +20,6 @@ from api.picklists.enums import PicklistType from api.cases.tests.factories import FinalAdviceFactory from api.queues.constants import ( - UPDATED_CASES_QUEUE_ID, SYSTEM_QUEUES, ALL_CASES_QUEUE_ID, ) @@ -33,10 +30,8 @@ from api.staticdata.report_summaries.tests.factories import ReportSummaryPrefixFactory, ReportSummarySubjectFactory from api.staticdata.statuses.libraries.get_case_status import get_case_status_by_status from api.staticdata.statuses.models import CaseSubStatus -from api.users.tests.factories import GovUserFactory from test_helpers.clients import DataTestClient from api.users.enums import UserStatuses -from api.users.libraries.user_to_token import user_to_token from api.users.models import GovUser from api.cases.tests import factories from api.cases.enums import AdviceType @@ -821,92 +816,6 @@ def test_get_cases_filter_by_includes_refusal_recommendation_not_met(self): self.assertEqual(len(response_data), 0) -class UpdatedCasesQueueTests(DataTestClient): - def setUp(self): - super().setUp() - - self.case = self.create_standard_application_case(self.organisation).get_case() - self.old_status = self.case.status.status - self.case.queues.set([self.queue]) - self.case_assignment = CaseAssignment.objects.create(case=self.case, queue=self.queue, user=self.gov_user) - self.case.status = get_case_status_by_status(CaseStatusEnum.APPLICANT_EDITING) - self.case.save() - - self.audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_STATUS, - target=self.case, - payload={"status": {"new": CaseStatusEnum.APPLICANT_EDITING, "old": self.old_status}}, - ) - self.gov_user.send_notification(content_object=self.audit, case=self.case) - - self.url = f'{reverse("cases:search")}?queue_id={UPDATED_CASES_QUEUE_ID}' - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_to_a_case_returns_expected_cases(self): - # Create another case that does not have an update - case = self.create_standard_application_case(self.organisation).get_case() - case.queues.set([self.queue]) - case_assignment = CaseAssignment.objects.create(case=case, queue=self.queue, user=self.gov_user) - self.gov_user.send_notification(content_object=self.audit, case=case) - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 2) # Count is 2 as another case is created in setup - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_non_team_queue(self): - other_team = self.create_team("other_team") - self.gov_user.team = other_team - - response = self.client.get(self.url, **self.gov_headers) - response_data = response.json()["results"]["cases"] - - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_when_user_is_not_assigned_to_a_case_returns_no_cases(self): - other_user = GovUserFactory( - baseuser_ptr__email="test2@mail.com", - baseuser_ptr__first_name="John", - baseuser_ptr__last_name="Smith", - team=self.team, - ) - gov_headers = {"HTTP_GOV_USER_TOKEN": user_to_token(other_user.baseuser_ptr)} - - response = self.client.get(self.url, **gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 0) - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_as_case_officer_returns_expected_cases(self): - CaseAssignment.objects.filter(case=self.case, queue=self.queue).delete() - self.case.case_officer = self.gov_user - self.case.save() - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - def test_get_cases_on_updated_cases_queue_when_user_is_assigned_to_case_and_as_case_officer_returns_expected_cases( - self, - ): - self.case.case_officer = self.gov_user - self.case.save() - - response = self.client.get(self.url, **self.gov_headers) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - response_data = response.json()["results"]["cases"] - self.assertEqual(len(response_data), 1) - self.assertEqual(response_data[0]["id"], str(self.case.id)) - - class CaseOrderingOnQueueTests(DataTestClient): def test_all_cases_queue_returns_cases_in_expected_order(self): """Test All cases queue returns cases in expected order (newest first).""" diff --git a/api/cases/views/search/activity.py b/api/cases/views/search/activity.py index 4f179eaae1..8a3e9bb7e3 100644 --- a/api/cases/views/search/activity.py +++ b/api/cases/views/search/activity.py @@ -5,10 +5,8 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.serializers import AuditSerializer -from api.cases.libraries.delete_notifications import delete_gov_user_notifications from api.cases.models import Case from api.core.authentication import GovAuthentication -from api.users.models import GovUser class CaseActivityView(APIView): @@ -23,9 +21,6 @@ def get(self, request, pk): data = AuditSerializer(audit_trail_qs, many=True).data - # Delete notifications related to audits - if isinstance(request.user, GovUser): - delete_gov_user_notifications(request.user, [obj["id"] for obj in data]) return JsonResponse(data={"activity": data}, status=status.HTTP_200_OK) diff --git a/api/core/tests/test_validators.py b/api/core/tests/test_validators.py new file mode 100644 index 0000000000..aea29f02d0 --- /dev/null +++ b/api/core/tests/test_validators.py @@ -0,0 +1,20 @@ +import pytest +from rest_framework.exceptions import ValidationError +from api.core.validators import EdifactStringValidator + + +@pytest.mark.parametrize( + "value", + (("random value"), ("random-value"), ("random!value"), ("random-!.<>/%&*;+'(),.value")), +) +def test_edifactstringvalidator_valid(value): + validator = EdifactStringValidator() + result = validator(value) + assert result is None + + +@pytest.mark.parametrize("value", (("\r\n"), ("random_value"), ("random$value"), ("random@value"))) +def test_edifactstringvalidator_invalid(value): + validator = EdifactStringValidator() + with pytest.raises(ValidationError): + results = validator(value) diff --git a/api/core/validators.py b/api/core/validators.py index 3be8a04dac..9b4b159cc4 100644 --- a/api/core/validators.py +++ b/api/core/validators.py @@ -1,6 +1,6 @@ from django.utils.deconstruct import deconstructible from rest_framework.exceptions import ValidationError - +import re from api.staticdata.control_list_entries.models import ControlListEntry @@ -22,3 +22,23 @@ def __call__(self, value): ControlListEntry.objects.get(rating=value) except ControlListEntry.DoesNotExist: raise ValidationError(self.message, code=self.code) + + +class EdifactStringValidator: + message = "Undefined Error" + regex_string = r"^[a-zA-Z0-9 .,\-\)\(\/'+:=\?\!\"%&\*;\<\>]+$" + + def __call__(self, value): + match_regex = re.compile(self.regex_string) + is_value_valid = bool(match_regex.match(value)) + if not is_value_valid: + raise ValidationError(self.message) + + +class GoodNameValidator(EdifactStringValidator): + message = "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes" + + +class PartyAddressValidator(EdifactStringValidator): + regex_string = re.compile(r"^[a-zA-Z0-9 .,\-\)\(\/'+:=\?\!\"%&\*;\<\>\r\n]+$") + message = "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes" diff --git a/api/goods/serializers.py b/api/goods/serializers.py index 679a74d677..122c66e779 100644 --- a/api/goods/serializers.py +++ b/api/goods/serializers.py @@ -1,8 +1,8 @@ from rest_framework import serializers from rest_framework.relations import PrimaryKeyRelatedField - from api.core.helpers import str_to_bool from api.core.serializers import KeyValueChoiceField, ControlListEntryField, GoodControlReviewSerializer +from api.core.validators import GoodNameValidator from api.documents.libraries.process_document import process_document from api.goods.enums import ( FirearmCategory, @@ -308,7 +308,7 @@ def update(self, instance, validated_data): class GoodListSerializer(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() control_list_entries = ControlListEntrySerializer(many=True, allow_null=True) part_number = serializers.CharField() @@ -357,7 +357,7 @@ class GoodCreateSerializer(serializers.ModelSerializer): Because of this, each 'get' override must check the instance type before creating queries """ - name = serializers.CharField(error_messages={"blank": "Enter a product name"}) + name = serializers.CharField(error_messages={"blank": "Enter a product name"}, validators=[GoodNameValidator()]) description = serializers.CharField(max_length=280, allow_blank=True, required=False) is_good_controlled = KeyValueChoiceField(choices=GoodControlled.choices, allow_null=True) control_list_entries = ControlListEntryField(required=False, many=True, allow_null=True, allow_empty=True) @@ -690,7 +690,7 @@ def create(self, validated_data): class GoodDocumentViewSerializer(serializers.Serializer): id = serializers.UUIDField() created_at = serializers.DateTimeField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() user = ExporterUserSimpleSerializer() s3_key = serializers.SerializerMethodField() @@ -788,7 +788,7 @@ class Meta: class GoodSerializerInternal(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() part_number = serializers.CharField() no_part_number_comments = serializers.CharField() @@ -871,7 +871,7 @@ def get_user(self, instance): class GoodSerializerExporter(serializers.Serializer): id = serializers.UUIDField() - name = serializers.CharField() + name = serializers.CharField(validators=[GoodNameValidator()]) description = serializers.CharField() control_list_entries = ControlListEntryField(many=True) part_number = serializers.CharField() diff --git a/api/goods/tests/test_serializers.py b/api/goods/tests/test_serializers.py index 01f7026e3c..6f17034ab9 100644 --- a/api/goods/tests/test_serializers.py +++ b/api/goods/tests/test_serializers.py @@ -1,4 +1,5 @@ from datetime import datetime +from parameterized import parameterized from django.urls import reverse from django.utils import timezone @@ -150,6 +151,61 @@ def test_report_summary_present(self): self.assertEqual(actual_subject["id"], str(self.good.report_summary_subject.id)) self.assertEqual(actual_subject["name"], self.good.report_summary_subject.name) + @parameterized.expand( + [ + "random good", + "good-name", + "good!name", + "good-!.<>/%&*;+'(),.name", + ] + ) + def test_validate_good_internal_name_valid(self, name): + serializer = GoodSerializerInternal( + data={"name": name}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("", "This field may not be blank."), + ("\r\n", "This field may not be blank."), + ( + "good\rname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\r\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good_name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good$name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good@name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_good_internal_name_invalid(self, name, error_message): + serializer = GoodSerializerInternal(data={"name": name}, partial=True) + self.assertFalse(serializer.is_valid()) + name_errors = serializer.errors["name"] + self.assertEqual(len(name_errors), 1) + self.assertEqual( + str(name_errors[0]), + error_message, + ) + class GoodSerializerExporterFullDetailTests(DataTestClient): @@ -182,3 +238,58 @@ def test_exporter_has_archive_history(self): archive_history["actioned_on"], datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.get_current_timezone()), ) + + @parameterized.expand( + [ + "random good", + "good-name", + "good!name", + "good-!.<>/%&*;+'(),.name", + ] + ) + def test_validate_good_exporter_name_valid(self, address): + serializer = GoodSerializerInternal( + data={"address": address}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("", "This field may not be blank."), + ("\r\n", "This field may not be blank."), + ( + "good\rname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good\r\nname", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good_name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good$name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "good@name", + "Product name must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_good_exporter_name_invalid(self, name, error_message): + serializer = GoodSerializerExporterFullDetail(data={"name": name}, partial=True) + self.assertFalse(serializer.is_valid()) + name_errors = serializer.errors["name"] + self.assertEqual(len(name_errors), 1) + self.assertEqual( + str(name_errors[0]), + error_message, + ) diff --git a/api/gov_users/tests/test_gov_user_notifications.py b/api/gov_users/tests/test_gov_user_notifications.py index bd9ed17dc0..6b46796302 100644 --- a/api/gov_users/tests/test_gov_user_notifications.py +++ b/api/gov_users/tests/test_gov_user_notifications.py @@ -3,10 +3,7 @@ from rest_framework import status from api.audit_trail.models import Audit -from api.audit_trail.enums import AuditType from api.staticdata.statuses.enums import CaseStatusEnum -from api.staticdata.statuses.libraries.get_case_status import get_case_status_by_status -from api.users.models import GovNotification from test_helpers.clients import DataTestClient @@ -23,79 +20,5 @@ def test_edit_application_creates_new_audit_notification_success(self): response = self.client.put(url, data, **self.exporter_headers) self.case.refresh_from_db() - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() self.assertEqual(status.HTTP_200_OK, response.status_code) - # There can only be one notification per gov user's case - # (the notification for updating the name overwrites any prior notification) - self.assertEqual(case_audit_notification_count, 1) - - def test_edit_application_updates_previous_audit_notification_success(self): - audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_APPLICATION_NAME, - target=self.case.get_case(), - payload={"old_name": "old_app_name", "new_name": "new_app_name"}, - ) - - self.gov_user.send_notification(content_object=audit, case=self.case) - prev_case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - url = reverse("applications:application", kwargs={"pk": self.case.id}) - data = {"name": "even newer app name!"} - - response = self.client.put(url, data, **self.exporter_headers) - self.case.refresh_from_db() - case_audit_notification = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ) - - self.assertEqual(status.HTTP_200_OK, response.status_code) - self.assertEqual(case_audit_notification.count(), prev_case_audit_notification_count) - self.assertEqual(data["name"], case_audit_notification.last().content_object.payload["new_name"]) - self.assertNotEqual(case_audit_notification.last().content_object, audit) - - def test_edit_application_as_gov_user_does_not_create_an_audit_notification_success(self): - prev_case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - url = reverse("caseworker_applications:change_status", kwargs={"pk": self.case.id}) - data = {"status": "under_review"} - - response = self.client.post(url, data, **self.gov_headers) - self.case.refresh_from_db() - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - self.assertEqual(status.HTTP_200_OK, response.status_code) - self.assertEqual(case_audit_notification_count, prev_case_audit_notification_count) - - def test_get_case_activity_deletes_audit_notification_success(self): - self.case = self.create_standard_application_case(self.organisation).get_case() - old_status = self.case.status.status - self.case.status = get_case_status_by_status(CaseStatusEnum.APPLICANT_EDITING) - self.case.save() - audit = Audit.objects.create( - actor=self.exporter_user, - verb=AuditType.UPDATED_STATUS, - target=self.case, - payload={"status": {"new": CaseStatusEnum.APPLICANT_EDITING, "old": old_status}}, - ) - - self.gov_user.send_notification(content_object=audit, case=self.case) - url = reverse("cases:activity", kwargs={"pk": self.case.id}) - - response = self.client.get(url, **self.gov_headers) - case_audit_notification_count = GovNotification.objects.filter( - user=self.gov_user.baseuser_ptr, content_type=self.audit_content_type, case=self.case - ).count() - - self.assertEqual(response.status_code, status.HTTP_200_OK) - case_activity = response.json()["activity"] - self.assertEqual(len(case_activity), 2) - self.assertEqual(case_audit_notification_count, 1) diff --git a/api/open_general_licences/views.py b/api/open_general_licences/views.py index 29c9fc5827..c5317ae896 100644 --- a/api/open_general_licences/views.py +++ b/api/open_general_licences/views.py @@ -19,7 +19,6 @@ from api.organisations.models import Site from api.staticdata.statuses.enums import CaseStatusEnum from api.users.enums import UserType -from api.users.models import GovUser, GovNotification class OpenGeneralLicenceList(ListCreateAPIView): @@ -191,10 +190,6 @@ def get(self, request, pk): data = AuditSerializer(audit_trail_qs, many=True).data - if isinstance(request.user, GovUser): - # Delete notifications related to audits - GovNotification.objects.filter(user_id=request.user.pk, object_id__in=[obj["id"] for obj in data]).delete() - filters = audit_trail_service.get_objects_activity_filters(pk, content_type) return JsonResponse(data={"activity": data, "filters": filters}, status=status.HTTP_200_OK) diff --git a/api/parties/serializers.py b/api/parties/serializers.py index 68dcb98201..a4eb7edb43 100644 --- a/api/parties/serializers.py +++ b/api/parties/serializers.py @@ -3,6 +3,7 @@ from api.cases.enums import CaseTypeSubTypeEnum from api.core.serializers import KeyValueChoiceField, CountrySerializerField +from api.core.validators import PartyAddressValidator from api.documents.libraries.process_document import process_document from api.flags.serializers import FlagSerializer from api.goods.enums import PvGrading @@ -15,7 +16,7 @@ class PartySerializer(serializers.ModelSerializer): name = serializers.CharField(error_messages=PartyErrors.NAME) - address = serializers.CharField(error_messages=PartyErrors.ADDRESS) + address = serializers.CharField(error_messages=PartyErrors.ADDRESS, validators=[PartyAddressValidator()]) country = CountrySerializerField() website = serializers.CharField(required=False, allow_blank=True) signatory_name_euu = serializers.CharField(allow_blank=True) @@ -116,8 +117,7 @@ def validate(self, data): return validated_data - @staticmethod - def validate_website(value): + def validate_website(self, value): """ Custom validation for URL that makes use of django URLValidator but makes the passing of http:// or https:// optional by prepending diff --git a/api/parties/tests/test_serializers.py b/api/parties/tests/test_serializers.py index 509c439f04..903f293ce8 100644 --- a/api/parties/tests/test_serializers.py +++ b/api/parties/tests/test_serializers.py @@ -1,24 +1,94 @@ -import pytest from parameterized import parameterized from api.parties.serializers import PartySerializer -from django.core.exceptions import ValidationError +from test_helpers.clients import DataTestClient -@parameterized.expand( - [ - ("http://workingexample.com", "http://workingexample.com"), - ("http://www.workingexample.com", "http://www.workingexample.com"), - ("http://WWW.workingexample.com", "http://WWW.workingexample.com"), - ("http://workingexample.com", "http://workingexample.com"), - ("workingexample.com", "https://workingexample.com"), - ("HTTPS://workingexample.com", "HTTPS://workingexample.com"), - ] -) -def test_party_serializer_validate_website_valid(url_input, url_output): - assert url_output == PartySerializer.validate_website(url_input) +class TestPartySerializer(DataTestClient): + @parameterized.expand( + [ + ("http://workingexample.com", "http://workingexample.com"), + ("http://www.workingexample.com", "http://www.workingexample.com"), + ("http://WWW.workingexample.com", "http://WWW.workingexample.com"), + ("http://workingexample.com", "http://workingexample.com"), + ("workingexample.com", "https://workingexample.com"), + ("HTTPS://workingexample.com", "HTTPS://workingexample.com"), + ] + ) + def test_party_serializer_validate_website_valid(self, url_input, url_output): + serializer = PartySerializer( + data={"website": url_input}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + self.assertEqual( + serializer.validated_data["website"], + url_output, + ) + def test_party_serializer_validate_website_invalid(self): + serializer = PartySerializer( + data={"website": "invalid@ur&l-i.am"}, + partial=True, + ) + self.assertFalse(serializer.is_valid()) + website_errors = serializer.errors["website"] + self.assertEqual(len(website_errors), 1) + self.assertEqual( + str(website_errors[0]), + "Enter a valid URL.", + ) -def test_party_serializer_validate_website_invalid(): - with pytest.raises(ValidationError): - PartySerializer.validate_website("invalid@ur&l-i.am") + @parameterized.expand( + [ + "random party", + "party-address", + "party!address", + "party-!.<>/%&*;+'(),.address", + "party\r\naddress", + ] + ) + def test_validate_party_address_valid(self, address): + serializer = PartySerializer( + data={"address": address}, + partial=True, + ) + self.assertTrue(serializer.is_valid()) + + @parameterized.expand( + [ + ("\r\n", "Enter an address"), + ( + "party\address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party-\waddress", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party_address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party$address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ( + "party@address", + "Address must only include letters, numbers, and common special characters such as hyphens, brackets and apostrophes", + ), + ] + ) + def test_validate_party_address_invalid(self, address, error_message): + serializer = PartySerializer( + data={"address": address}, + partial=True, + ) + self.assertFalse(serializer.is_valid()) + address_errors = serializer.errors["address"] + self.assertEqual(len(address_errors), 1) + self.assertEqual( + str(address_errors[0]), + error_message, + ) diff --git a/api/queues/constants.py b/api/queues/constants.py index 5f677d4705..7d75e3afd9 100644 --- a/api/queues/constants.py +++ b/api/queues/constants.py @@ -23,7 +23,6 @@ ALL_CASES_QUEUE_ID: ALL_CASES_QUEUE_NAME, OPEN_CASES_QUEUE_ID: OPEN_CASES_QUEUE_NAME, MY_TEAMS_QUEUES_CASES_ID: MY_TEAMS_QUEUES_CASES_NAME, - UPDATED_CASES_QUEUE_ID: UPDATED_CASES_QUEUE_NAME, } NON_WORK_QUEUES = { diff --git a/api/queues/service.py b/api/queues/service.py index ec2b28719d..d39394601f 100644 --- a/api/queues/service.py +++ b/api/queues/service.py @@ -10,7 +10,6 @@ MY_TEAMS_QUEUES_CASES_ID, MY_ASSIGNED_CASES_QUEUE_ID, MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID, - UPDATED_CASES_QUEUE_ID, SYSTEM_QUEUES, ) from api.queues.models import Queue @@ -85,7 +84,6 @@ def _get_system_queues_case_count(user) -> Dict: MY_TEAMS_QUEUES_CASES_ID: case_qs.in_team(team_id=user.team.id).count(), MY_ASSIGNED_CASES_QUEUE_ID: case_qs.assigned_to_user(user=user).not_terminal().count(), MY_ASSIGNED_AS_CASE_OFFICER_CASES_QUEUE_ID: case_qs.assigned_as_case_officer(user=user).not_terminal().count(), - UPDATED_CASES_QUEUE_ID: case_qs.is_updated(user=user).count(), } return cases_count diff --git a/api/staticdata/management/commands/seedinternalusers.py b/api/staticdata/management/commands/seedinternalusers.py index c59e21ee1c..b54e6b2702 100644 --- a/api/staticdata/management/commands/seedinternalusers.py +++ b/api/staticdata/management/commands/seedinternalusers.py @@ -3,6 +3,7 @@ from api.core.constants import Teams, Roles from api.conf.settings import env +from api.queues.constants import ALL_CASES_QUEUE_ID from api.staticdata.management.SeedCommand import SeedCommand from api.users.enums import UserType from api.users.models import Role, GovUser, BaseUser @@ -23,6 +24,8 @@ def operation(self, *args, **options): for admin_user in admin_users: email = admin_user["email"] + team_id = admin_user.get("team_id", Teams.ADMIN_TEAM_ID) + default_queue = admin_user.get("default_queue", ALL_CASES_QUEUE_ID) role = Role.objects.get( name=admin_user.get("role", Roles.INTERNAL_SUPER_USER_ROLE_NAME), type=UserType.INTERNAL @@ -31,14 +34,15 @@ def operation(self, *args, **options): email__iexact=email, defaults={"email": email}, type=UserType.INTERNAL ) admin_user, created = GovUser.objects.get_or_create( - baseuser_ptr=base_user, defaults={"team_id": Teams.ADMIN_TEAM_ID, "role": role} + baseuser_ptr=base_user, + defaults={"team_id": team_id, "role": role, "default_queue": default_queue}, ) if created or admin_user.role != role: admin_user.role = role admin_user.save() - admin_data = dict(email=email, team=Teams.ADMIN_TEAM_NAME, role=role.name) + admin_data = dict(email=email, team=admin_user.team.name, role=role.name) self.print_created_or_updated(GovUser, admin_data, is_created=created) if not GovUser.objects.count() >= 1: diff --git a/api/staticdata/management/commands/tests/test_commands.py b/api/staticdata/management/commands/tests/test_commands.py index ca85c4fff8..d6d861b289 100644 --- a/api/staticdata/management/commands/tests/test_commands.py +++ b/api/staticdata/management/commands/tests/test_commands.py @@ -1,15 +1,18 @@ +import json import os import pytest +from parameterized import parameterized from tempfile import NamedTemporaryFile from api.cases.enums import CaseTypeEnum from api.cases.models import CaseType -from api.core.constants import GovPermissions, ExporterPermissions +from api.core.constants import GovPermissions, ExporterPermissions, Teams from api.conf.settings import BASE_DIR from api.letter_templates.models import LetterTemplate from api.staticdata.countries.models import Country from api.cases.enums import AdviceType +from api.queues.constants import ALL_CASES_QUEUE_ID from api.staticdata.decisions.models import Decision from api.staticdata.denial_reasons.models import DenialReason from api.staticdata.letter_layouts.models import LetterLayout @@ -23,12 +26,22 @@ seedlettertemplates, seedrolepermissions, seedfinaldecisions, + seedinternalusers, ) from api.staticdata.statuses.models import CaseStatus, CaseStatusCaseType -from api.users.models import Permission +from api.teams.enums import TeamIdEnum +from api.users.enums import UserType +from api.users.models import GovUser, Permission +from api.users.tests.factories import RoleFactory class SeedingTests(SeedCommandTest): + def setUp(self) -> None: + super().setUp() + role_names = ["Super User", "Case officer", "Case adviser", "Manager", "Senior Manager"] + for name in role_names: + RoleFactory(name=name, type=UserType.INTERNAL) + @pytest.mark.seeding def test_seed_case_types(self): self.seed_command(seedcasetypes.Command) @@ -111,3 +124,40 @@ def test_seed_letter_templates(self): # running again with existing templates does nothing self.seed_command(seedlettertemplates.Command) self.assertEqual(LetterTemplate.objects.count(), 2) + + @pytest.mark.seeding + @parameterized.expand( + [ + ([{"email": "admin@example.co.uk", "role": "Super User"}],), + ([{"email": "manager@example.co.uk", "role": "Manager", "first_name": "LU", "last_name": "Manager"}],), + ( + [ + { + "email": "senior_manager@example.co.uk", + "role": "Senior Manager", + "team_id": TeamIdEnum.LICENSING_UNIT, + } + ], + ), + ( + [ + { + "email": "case_officer@example.co.uk", + "role": "Case officer", + "team_id": TeamIdEnum.LICENSING_UNIT, + "default_queue": "00000000-0000-0000-0000-000000000004", + } + ], + ), + ] + ) + def test_seed_internal_users(self, data): + os.environ["INTERNAL_USERS"] = json.dumps(data) + + self.seed_command(seedinternalusers.Command) + + for item in data: + user = GovUser.objects.get(baseuser_ptr__email=item["email"]) + self.assertEqual(user.role.name, item["role"]) + self.assertEqual(str(user.team_id), item.get("team_id", Teams.ADMIN_TEAM_ID)) + self.assertEqual(str(user.default_queue), item.get("default_queue", ALL_CASES_QUEUE_ID)) diff --git a/api/users/migrations/0007_delete_govnotification.py b/api/users/migrations/0007_delete_govnotification.py new file mode 100644 index 0000000000..3f55716f4a --- /dev/null +++ b/api/users/migrations/0007_delete_govnotification.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.14 on 2024-09-06 10:22 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0006_alter_userorganisationrelationship_user"), + ] + + operations = [ + migrations.DeleteModel( + name="GovNotification", + ), + ] diff --git a/api/users/models.py b/api/users/models.py index 84e406f459..2d714a3497 100644 --- a/api/users/models.py +++ b/api/users/models.py @@ -115,10 +115,6 @@ class ExporterNotification(BaseNotification): organisation = models.ForeignKey("organisations.Organisation", on_delete=models.CASCADE, null=False) -class GovNotification(BaseNotification): - pass - - class BaseUserCompatMixin: baseuser_ptr: BaseUser @@ -212,20 +208,6 @@ def unassign_from_cases(self): """ self.case_assignments.filter(user=self).delete() - def send_notification(self, content_object, case): - from api.audit_trail.models import Audit - - if isinstance(content_object, Audit): - # There can only be one notification per gov user's case - # If a notification for that gov user's case already exists, update the case activity it points to - try: - content_type = ContentType.objects.get_for_model(Audit) - notification = GovNotification.objects.get(user=self.baseuser_ptr, content_type=content_type, case=case) - notification.content_object = content_object - notification.save() - except GovNotification.DoesNotExist: - GovNotification.objects.create(user=self.baseuser_ptr, content_object=content_object, case=case) - def has_permission(self, permission): user_permissions = self.role.permissions.values_list("id", flat=True) return permission.name in user_permissions diff --git a/pytest.ini b/pytest.ini index 76fc6e4346..4be4053f05 100644 --- a/pytest.ini +++ b/pytest.ini @@ -9,3 +9,4 @@ markers = elasticsearch: Tests that use elasticsearch seeding: tests that check seed commands performance: tests that check performance + requires_transactions: tests that require fine grained controls of transactions diff --git a/test_helpers/clients.py b/test_helpers/clients.py index 668724eba8..933f1f4cb2 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -886,7 +886,6 @@ def create_audit( target=None, payload=None, ignore_case_status=False, - send_notification=True, ): if not payload: payload = {} @@ -898,7 +897,6 @@ def create_audit( target=target, payload=payload, ignore_case_status=ignore_case_status, - send_notification=send_notification, ) def add_users(self, count=3):