diff --git a/api/src/backend/api/base_views.py b/api/src/backend/api/base_views.py index bcaaeb5a651..11f1d4bf5f5 100644 --- a/api/src/backend/api/base_views.py +++ b/api/src/backend/api/base_views.py @@ -7,10 +7,10 @@ from rest_framework_json_api.views import ModelViewSet from rest_framework_simplejwt.authentication import JWTAuthentication -from api.db_utils import POSTGRES_USER_VAR, tenant_transaction +from api.db_router import MainRouter +from api.db_utils import POSTGRES_USER_VAR, rls_transaction from api.filters import CustomDjangoFilterBackend from api.models import Role, Tenant -from api.db_router import MainRouter class BaseViewSet(ModelViewSet): @@ -48,7 +48,7 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) @@ -102,7 +102,7 @@ def initial(self, request, *args, **kwargs): ): user_id = str(request.user.id) - with tenant_transaction(value=user_id, parameter=POSTGRES_USER_VAR): + with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR): return super().initial(request, *args, **kwargs) # TODO: DRY this when we have time @@ -113,7 +113,7 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) @@ -134,6 +134,6 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index ecb2f6c455c..466f78aea6b 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -48,14 +48,14 @@ def psycopg_connection(database_alias: str): @contextmanager -def tenant_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR): +def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR): """ - Creates a new database transaction setting the given configuration value. It validates the - if the value is a valid UUID to be used for Postgres RLS. + Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the + if the value is a valid UUID. Args: value (str): Database configuration parameter value. - parameter (str): Database configuration parameter name + parameter (str): Database configuration parameter name, by default is 'api.tenant_id'. """ with transaction.atomic(): with connection.cursor() as cursor: diff --git a/api/src/backend/api/renderers.py b/api/src/backend/api/renderers.py index ccca52c26f1..ee03f247582 100644 --- a/api/src/backend/api/renderers.py +++ b/api/src/backend/api/renderers.py @@ -2,7 +2,7 @@ from rest_framework_json_api.renderers import JSONRenderer -from api.db_utils import tenant_transaction +from api.db_utils import rls_transaction class APIJSONRenderer(JSONRenderer): @@ -13,9 +13,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None): tenant_id = getattr(request, "tenant_id", None) if request else None include_param_present = "include" in request.query_params if request else False - # Use tenant_transaction if needed for included resources, otherwise do nothing + # Use rls_transaction if needed for included resources, otherwise do nothing context_manager = ( - tenant_transaction(tenant_id) + rls_transaction(tenant_id) if tenant_id and include_param_present else nullcontext() ) diff --git a/api/src/backend/config/celery.py b/api/src/backend/config/celery.py index 2206a947971..e5b0c304921 100644 --- a/api/src/backend/config/celery.py +++ b/api/src/backend/config/celery.py @@ -35,10 +35,10 @@ def apply_async( **options, ) task_result_instance = TaskResult.objects.get(task_id=result.task_id) - from api.db_utils import tenant_transaction + from api.db_utils import rls_transaction tenant_id = kwargs.get("tenant_id") - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): APITask.objects.create( id=task_result_instance.task_id, tenant_id=tenant_id, diff --git a/api/src/backend/conftest.py b/api/src/backend/conftest.py index f0e76257e1a..273daa6366d 100644 --- a/api/src/backend/conftest.py +++ b/api/src/backend/conftest.py @@ -1,39 +1,38 @@ import logging +from datetime import datetime, timedelta, timezone +from unittest.mock import patch import pytest from django.conf import settings -from datetime import datetime, timezone, timedelta -from django.db import connections as django_connections, connection as django_connection +from django.db import connection as django_connection +from django.db import connections as django_connections from django.urls import reverse from django_celery_results.models import TaskResult -from prowler.lib.check.models import Severity -from prowler.lib.outputs.finding import Status from rest_framework import status from rest_framework.test import APIClient -from unittest.mock import patch -from api.db_utils import tenant_transaction +from api.db_utils import rls_transaction from api.models import ( + ComplianceOverview, Finding, -) -from api.models import ( - User, + Invitation, + Membership, Provider, ProviderGroup, + ProviderSecret, Resource, ResourceTag, Role, Scan, StateChoices, Task, - Membership, - ProviderSecret, - Invitation, - ComplianceOverview, + User, UserRoleRelationship, ) from api.rls import Tenant from api.v1.serializers import TokenSerializer +from prowler.lib.check.models import Severity +from prowler.lib.outputs.finding import Status API_JSON_CONTENT_TYPE = "application/vnd.api+json" NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED @@ -281,7 +280,7 @@ def tenants_fixture(create_test_user): def set_user_admin_roles_fixture(create_test_user, tenants_fixture): user = create_test_user for tenant in tenants_fixture[:2]: - with tenant_transaction(str(tenant.id)): + with rls_transaction(str(tenant.id)): role = Role.objects.create( name="admin", tenant_id=tenant.id, @@ -757,9 +756,10 @@ def get_api_tokens( data=json_body, format="vnd.api+json", ) - return response.json()["data"]["attributes"]["access"], response.json()["data"][ - "attributes" - ]["refresh"] + return ( + response.json()["data"]["attributes"]["access"], + response.json()["data"]["attributes"]["refresh"], + ) def get_authorization_header(access_token: str) -> dict: diff --git a/api/src/backend/tasks/jobs/deletion.py b/api/src/backend/tasks/jobs/deletion.py index 9e1ba600bef..5ca08e70bb2 100644 --- a/api/src/backend/tasks/jobs/deletion.py +++ b/api/src/backend/tasks/jobs/deletion.py @@ -2,7 +2,7 @@ from django.db import transaction from api.db_router import MainRouter -from api.db_utils import batch_delete, tenant_transaction +from api.db_utils import batch_delete, rls_transaction from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant logger = get_task_logger(__name__) @@ -66,7 +66,7 @@ def delete_tenant(pk: str): deletion_summary = {} for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk): - with tenant_transaction(pk): + with rls_transaction(pk): summary = delete_provider(provider.id) deletion_summary.update(summary) diff --git a/api/src/backend/tasks/jobs/scan.py b/api/src/backend/tasks/jobs/scan.py index ffc75d1f372..48d1578ad63 100644 --- a/api/src/backend/tasks/jobs/scan.py +++ b/api/src/backend/tasks/jobs/scan.py @@ -11,7 +11,7 @@ PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE, generate_scan_compliance, ) -from api.db_utils import tenant_transaction +from api.db_utils import rls_transaction from api.models import ( ComplianceOverview, Finding, @@ -69,7 +69,7 @@ def _store_resources( - tuple[str, str]: A tuple containing the resource UID and region. """ - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): resource_instance, created = Resource.objects.get_or_create( tenant_id=tenant_id, provider=provider_instance, @@ -86,7 +86,7 @@ def _store_resources( resource_instance.service = finding.service_name resource_instance.type = finding.resource_type resource_instance.save() - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): tags = [ ResourceTag.objects.get_or_create( tenant_id=tenant_id, key=key, value=value @@ -122,7 +122,7 @@ def perform_prowler_scan( unique_resources = set() start_time = time.time() - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): provider_instance = Provider.objects.get(pk=provider_id) scan_instance = Scan.objects.get(pk=scan_id) scan_instance.state = StateChoices.EXECUTING @@ -130,7 +130,7 @@ def perform_prowler_scan( scan_instance.save() try: - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): try: prowler_provider = initialize_prowler_provider(provider_instance) provider_instance.connected = True @@ -156,7 +156,7 @@ def perform_prowler_scan( for finding in findings: for attempt in range(CELERY_DEADLOCK_ATTEMPTS): try: - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): # Process resource resource_uid = finding.resource_uid if resource_uid not in resource_cache: @@ -188,7 +188,7 @@ def perform_prowler_scan( resource_instance.type = finding.resource_type updated_fields.append("type") if updated_fields: - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): resource_instance.save(update_fields=updated_fields) except (OperationalError, IntegrityError) as db_err: if attempt < CELERY_DEADLOCK_ATTEMPTS - 1: @@ -203,7 +203,7 @@ def perform_prowler_scan( # Update tags tags = [] - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): for key, value in finding.resource_tags.items(): tag_key = (key, value) if tag_key not in tag_cache: @@ -219,7 +219,7 @@ def perform_prowler_scan( unique_resources.add((resource_instance.uid, resource_instance.region)) # Process finding - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): finding_uid = finding.uid if finding_uid not in last_status_cache: most_recent_finding = ( @@ -267,7 +267,7 @@ def perform_prowler_scan( region_dict[finding.check_id] = finding.status.value # Update scan progress - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): scan_instance.progress = progress scan_instance.save() @@ -279,7 +279,7 @@ def perform_prowler_scan( scan_instance.state = StateChoices.FAILED finally: - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): scan_instance.duration = time.time() - start_time scan_instance.completed_at = datetime.now(tz=timezone.utc) scan_instance.unique_resource_count = len(unique_resources) @@ -330,7 +330,7 @@ def perform_prowler_scan( total_requirements=compliance["total_requirements"], ) ) - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): ComplianceOverview.objects.bulk_create(compliance_overview_objects) if exception is not None: @@ -368,7 +368,7 @@ def aggregate_findings(tenant_id: str, scan_id: str): - muted_new: Muted findings with a delta of 'new'. - muted_changed: Muted findings with a delta of 'changed'. """ - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): findings = Finding.objects.filter(scan_id=scan_id) aggregation = findings.values( @@ -464,7 +464,7 @@ def aggregate_findings(tenant_id: str, scan_id: str): ), ) - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): scan_aggregations = { ScanSummary( tenant_id=tenant_id, diff --git a/api/src/backend/tasks/tasks.py b/api/src/backend/tasks/tasks.py index 68eedbb8eab..1f87e0fdeb5 100644 --- a/api/src/backend/tasks/tasks.py +++ b/api/src/backend/tasks/tasks.py @@ -7,7 +7,7 @@ from tasks.jobs.deletion import delete_provider, delete_tenant from tasks.jobs.scan import aggregate_findings, perform_prowler_scan -from api.db_utils import tenant_transaction +from api.db_utils import rls_transaction from api.decorators import set_tenant from api.models import Provider, Scan @@ -99,7 +99,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str): """ task_id = self.request.id - with tenant_transaction(tenant_id): + with rls_transaction(tenant_id): provider_instance = Provider.objects.get(pk=provider_id) periodic_task_instance = PeriodicTask.objects.get( name=f"scan-perform-scheduled-{provider_id}" diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index ba5802c1e9f..78ba36fde88 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -27,7 +27,7 @@ def test_perform_prowler_scan_success( providers_fixture, ): with ( - patch("api.db_utils.tenant_transaction"), + patch("api.db_utils.rls_transaction"), patch( "tasks.jobs.scan.initialize_prowler_provider" ) as mock_initialize_prowler_provider, @@ -166,10 +166,10 @@ def test_perform_prowler_scan_success( "tasks.jobs.scan.initialize_prowler_provider", side_effect=Exception("Connection error"), ) - @patch("api.db_utils.tenant_transaction") + @patch("api.db_utils.rls_transaction") def test_perform_prowler_scan_no_connection( self, - mock_tenant_transaction, + mock_rls_transaction, mock_initialize_prowler_provider, mock_prowler_scan_class, tenants_fixture, @@ -206,10 +206,10 @@ def test_create_finding_delta(self, last_status, new_status, expected_delta): @patch("api.models.ResourceTag.objects.get_or_create") @patch("api.models.Resource.objects.get_or_create") - @patch("api.db_utils.tenant_transaction") + @patch("api.db_utils.rls_transaction") def test_store_resources_new_resource( self, - mock_tenant_transaction, + mock_rls_transaction, mock_get_or_create_resource, mock_get_or_create_tag, ): @@ -254,10 +254,10 @@ def test_store_resources_new_resource( @patch("api.models.ResourceTag.objects.get_or_create") @patch("api.models.Resource.objects.get_or_create") - @patch("api.db_utils.tenant_transaction") + @patch("api.db_utils.rls_transaction") def test_store_resources_existing_resource( self, - mock_tenant_transaction, + mock_rls_transaction, mock_get_or_create_resource, mock_get_or_create_tag, ): @@ -311,10 +311,10 @@ def test_store_resources_existing_resource( @patch("api.models.ResourceTag.objects.get_or_create") @patch("api.models.Resource.objects.get_or_create") - @patch("api.db_utils.tenant_transaction") + @patch("api.db_utils.rls_transaction") def test_store_resources_with_tags( self, - mock_tenant_transaction, + mock_rls_transaction, mock_get_or_create_resource, mock_get_or_create_tag, ):