Skip to content

Commit

Permalink
chore(rls): rename tenant_transaction to rls_transaction (#6202)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfagoagas authored Dec 16, 2024
1 parent 9d7499b commit 57854f2
Show file tree
Hide file tree
Showing 9 changed files with 59 additions and 59 deletions.
12 changes: 6 additions & 6 deletions api/src/backend/api/base_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from rest_framework_json_api.views import ModelViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication

from api.db_utils import POSTGRES_USER_VAR, tenant_transaction
from api.db_router import MainRouter
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, Tenant
from api.db_router import MainRouter


class BaseViewSet(ModelViewSet):
Expand Down Expand Up @@ -48,7 +48,7 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand Down Expand Up @@ -102,7 +102,7 @@ def initial(self, request, *args, **kwargs):
):
user_id = str(request.user.id)

with tenant_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
with rls_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)

# TODO: DRY this when we have time
Expand All @@ -113,7 +113,7 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand All @@ -134,6 +134,6 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
8 changes: 4 additions & 4 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def psycopg_connection(database_alias: str):


@contextmanager
def tenant_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
def rls_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
"""
Creates a new database transaction setting the given configuration value. It validates the
if the value is a valid UUID to be used for Postgres RLS.
Creates a new database transaction setting the given configuration value for Postgres RLS. It validates the
if the value is a valid UUID.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name
parameter (str): Database configuration parameter name, by default is 'api.tenant_id'.
"""
with transaction.atomic():
with connection.cursor() as cursor:
Expand Down
6 changes: 3 additions & 3 deletions api/src/backend/api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from rest_framework_json_api.renderers import JSONRenderer

from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction


class APIJSONRenderer(JSONRenderer):
Expand All @@ -13,9 +13,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
tenant_id = getattr(request, "tenant_id", None) if request else None
include_param_present = "include" in request.query_params if request else False

# Use tenant_transaction if needed for included resources, otherwise do nothing
# Use rls_transaction if needed for included resources, otherwise do nothing
context_manager = (
tenant_transaction(tenant_id)
rls_transaction(tenant_id)
if tenant_id and include_param_present
else nullcontext()
)
Expand Down
4 changes: 2 additions & 2 deletions api/src/backend/config/celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def apply_async(
**options,
)
task_result_instance = TaskResult.objects.get(task_id=result.task_id)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction

tenant_id = kwargs.get("tenant_id")
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
APITask.objects.create(
id=task_result_instance.task_id,
tenant_id=tenant_id,
Expand Down
34 changes: 17 additions & 17 deletions api/src/backend/conftest.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,38 @@
import logging
from datetime import datetime, timedelta, timezone
from unittest.mock import patch

import pytest
from django.conf import settings
from datetime import datetime, timezone, timedelta
from django.db import connections as django_connections, connection as django_connection
from django.db import connection as django_connection
from django.db import connections as django_connections
from django.urls import reverse
from django_celery_results.models import TaskResult
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status
from rest_framework import status
from rest_framework.test import APIClient
from unittest.mock import patch
from api.db_utils import tenant_transaction

from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
Finding,
)
from api.models import (
User,
Invitation,
Membership,
Provider,
ProviderGroup,
ProviderSecret,
Resource,
ResourceTag,
Role,
Scan,
StateChoices,
Task,
Membership,
ProviderSecret,
Invitation,
ComplianceOverview,
User,
UserRoleRelationship,
)
from api.rls import Tenant
from api.v1.serializers import TokenSerializer
from prowler.lib.check.models import Severity
from prowler.lib.outputs.finding import Status

API_JSON_CONTENT_TYPE = "application/vnd.api+json"
NO_TENANT_HTTP_STATUS = status.HTTP_401_UNAUTHORIZED
Expand Down Expand Up @@ -281,7 +280,7 @@ def tenants_fixture(create_test_user):
def set_user_admin_roles_fixture(create_test_user, tenants_fixture):
user = create_test_user
for tenant in tenants_fixture[:2]:
with tenant_transaction(str(tenant.id)):
with rls_transaction(str(tenant.id)):
role = Role.objects.create(
name="admin",
tenant_id=tenant.id,
Expand Down Expand Up @@ -757,9 +756,10 @@ def get_api_tokens(
data=json_body,
format="vnd.api+json",
)
return response.json()["data"]["attributes"]["access"], response.json()["data"][
"attributes"
]["refresh"]
return (
response.json()["data"]["attributes"]["access"],
response.json()["data"]["attributes"]["refresh"],
)


def get_authorization_header(access_token: str) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions api/src/backend/tasks/jobs/deletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.db import transaction

from api.db_router import MainRouter
from api.db_utils import batch_delete, tenant_transaction
from api.db_utils import batch_delete, rls_transaction
from api.models import Finding, Provider, Resource, Scan, ScanSummary, Tenant

logger = get_task_logger(__name__)
Expand Down Expand Up @@ -66,7 +66,7 @@ def delete_tenant(pk: str):
deletion_summary = {}

for provider in Provider.objects.using(MainRouter.admin_db).filter(tenant_id=pk):
with tenant_transaction(pk):
with rls_transaction(pk):
summary = delete_provider(provider.id)
deletion_summary.update(summary)

Expand Down
28 changes: 14 additions & 14 deletions api/src/backend/tasks/jobs/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
PROWLER_COMPLIANCE_OVERVIEW_TEMPLATE,
generate_scan_compliance,
)
from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.models import (
ComplianceOverview,
Finding,
Expand Down Expand Up @@ -69,7 +69,7 @@ def _store_resources(
- tuple[str, str]: A tuple containing the resource UID and region.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance, created = Resource.objects.get_or_create(
tenant_id=tenant_id,
provider=provider_instance,
Expand All @@ -86,7 +86,7 @@ def _store_resources(
resource_instance.service = finding.service_name
resource_instance.type = finding.resource_type
resource_instance.save()
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
tags = [
ResourceTag.objects.get_or_create(
tenant_id=tenant_id, key=key, value=value
Expand Down Expand Up @@ -122,15 +122,15 @@ def perform_prowler_scan(
unique_resources = set()
start_time = time.time()

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
scan_instance = Scan.objects.get(pk=scan_id)
scan_instance.state = StateChoices.EXECUTING
scan_instance.started_at = datetime.now(tz=timezone.utc)
scan_instance.save()

try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
try:
prowler_provider = initialize_prowler_provider(provider_instance)
provider_instance.connected = True
Expand All @@ -156,7 +156,7 @@ def perform_prowler_scan(
for finding in findings:
for attempt in range(CELERY_DEADLOCK_ATTEMPTS):
try:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
# Process resource
resource_uid = finding.resource_uid
if resource_uid not in resource_cache:
Expand Down Expand Up @@ -188,7 +188,7 @@ def perform_prowler_scan(
resource_instance.type = finding.resource_type
updated_fields.append("type")
if updated_fields:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
resource_instance.save(update_fields=updated_fields)
except (OperationalError, IntegrityError) as db_err:
if attempt < CELERY_DEADLOCK_ATTEMPTS - 1:
Expand All @@ -203,7 +203,7 @@ def perform_prowler_scan(

# Update tags
tags = []
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
for key, value in finding.resource_tags.items():
tag_key = (key, value)
if tag_key not in tag_cache:
Expand All @@ -219,7 +219,7 @@ def perform_prowler_scan(
unique_resources.add((resource_instance.uid, resource_instance.region))

# Process finding
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
finding_uid = finding.uid
if finding_uid not in last_status_cache:
most_recent_finding = (
Expand Down Expand Up @@ -267,7 +267,7 @@ def perform_prowler_scan(
region_dict[finding.check_id] = finding.status.value

# Update scan progress
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.progress = progress
scan_instance.save()

Expand All @@ -279,7 +279,7 @@ def perform_prowler_scan(
scan_instance.state = StateChoices.FAILED

finally:
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_instance.duration = time.time() - start_time
scan_instance.completed_at = datetime.now(tz=timezone.utc)
scan_instance.unique_resource_count = len(unique_resources)
Expand Down Expand Up @@ -330,7 +330,7 @@ def perform_prowler_scan(
total_requirements=compliance["total_requirements"],
)
)
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
ComplianceOverview.objects.bulk_create(compliance_overview_objects)

if exception is not None:
Expand Down Expand Up @@ -368,7 +368,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
- muted_new: Muted findings with a delta of 'new'.
- muted_changed: Muted findings with a delta of 'changed'.
"""
with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
findings = Finding.objects.filter(scan_id=scan_id)

aggregation = findings.values(
Expand Down Expand Up @@ -464,7 +464,7 @@ def aggregate_findings(tenant_id: str, scan_id: str):
),
)

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
scan_aggregations = {
ScanSummary(
tenant_id=tenant_id,
Expand Down
4 changes: 2 additions & 2 deletions api/src/backend/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tasks.jobs.deletion import delete_provider, delete_tenant
from tasks.jobs.scan import aggregate_findings, perform_prowler_scan

from api.db_utils import tenant_transaction
from api.db_utils import rls_transaction
from api.decorators import set_tenant
from api.models import Provider, Scan

Expand Down Expand Up @@ -99,7 +99,7 @@ def perform_scheduled_scan_task(self, tenant_id: str, provider_id: str):
"""
task_id = self.request.id

with tenant_transaction(tenant_id):
with rls_transaction(tenant_id):
provider_instance = Provider.objects.get(pk=provider_id)
periodic_task_instance = PeriodicTask.objects.get(
name=f"scan-perform-scheduled-{provider_id}"
Expand Down
18 changes: 9 additions & 9 deletions api/src/backend/tasks/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_perform_prowler_scan_success(
providers_fixture,
):
with (
patch("api.db_utils.tenant_transaction"),
patch("api.db_utils.rls_transaction"),
patch(
"tasks.jobs.scan.initialize_prowler_provider"
) as mock_initialize_prowler_provider,
Expand Down Expand Up @@ -166,10 +166,10 @@ def test_perform_prowler_scan_success(
"tasks.jobs.scan.initialize_prowler_provider",
side_effect=Exception("Connection error"),
)
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_perform_prowler_scan_no_connection(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_initialize_prowler_provider,
mock_prowler_scan_class,
tenants_fixture,
Expand Down Expand Up @@ -206,10 +206,10 @@ def test_create_finding_delta(self, last_status, new_status, expected_delta):

@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_new_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
Expand Down Expand Up @@ -254,10 +254,10 @@ def test_store_resources_new_resource(

@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_existing_resource(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
Expand Down Expand Up @@ -311,10 +311,10 @@ def test_store_resources_existing_resource(

@patch("api.models.ResourceTag.objects.get_or_create")
@patch("api.models.Resource.objects.get_or_create")
@patch("api.db_utils.tenant_transaction")
@patch("api.db_utils.rls_transaction")
def test_store_resources_with_tags(
self,
mock_tenant_transaction,
mock_rls_transaction,
mock_get_or_create_resource,
mock_get_or_create_tag,
):
Expand Down

0 comments on commit 57854f2

Please sign in to comment.