From da4f9b8e5fad26a53dddd76f0c7e56edfd67e70f Mon Sep 17 00:00:00 2001 From: Pepe Fagoaga Date: Fri, 13 Dec 2024 12:55:09 +0100 Subject: [PATCH] fix(RLS): enforce config security (#6066) --- api/src/backend/api/base_views.py | 33 ++++---------------- api/src/backend/api/db_utils.py | 21 +++++++++++-- api/src/backend/api/decorators.py | 13 ++++++-- api/src/backend/api/tests/test_decorators.py | 8 +++-- api/src/backend/tasks/tests/test_scan.py | 7 +++-- 5 files changed, 44 insertions(+), 38 deletions(-) diff --git a/api/src/backend/api/base_views.py b/api/src/backend/api/base_views.py index ba225c14791..5872203cb75 100644 --- a/api/src/backend/api/base_views.py +++ b/api/src/backend/api/base_views.py @@ -1,14 +1,12 @@ -import uuid - -from django.db import connection, transaction +from django.db import transaction from rest_framework import permissions from rest_framework.exceptions import NotAuthenticated from rest_framework.filters import SearchFilter from rest_framework_json_api import filters -from rest_framework_json_api.serializers import ValidationError from rest_framework_json_api.views import ModelViewSet from rest_framework_simplejwt.authentication import JWTAuthentication +from api.db_utils import POSTGRES_USER_VAR, tenant_transaction from api.filters import CustomDjangoFilterBackend @@ -47,13 +45,7 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - try: - uuid.UUID(tenant_id) - except ValueError: - raise ValidationError("Tenant ID must be a valid UUID") - - with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + with tenant_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) @@ -75,8 +67,7 @@ def initial(self, request, *args, **kwargs): ): user_id = str(request.user.id) - with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);") + with tenant_transaction(value=user_id, parameter=POSTGRES_USER_VAR): return super().initial(request, *args, **kwargs) # TODO: DRY this when we have time @@ -87,13 +78,7 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - try: - uuid.UUID(tenant_id) - except ValueError: - raise ValidationError("Tenant ID must be a valid UUID") - - with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + with tenant_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) @@ -114,12 +99,6 @@ def initial(self, request, *args, **kwargs): if tenant_id is None: raise NotAuthenticated("Tenant ID is not present in token") - try: - uuid.UUID(tenant_id) - except ValueError: - raise ValidationError("Tenant ID must be a valid UUID") - - with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + with tenant_transaction(tenant_id): self.request.tenant_id = tenant_id return super().initial(request, *args, **kwargs) diff --git a/api/src/backend/api/db_utils.py b/api/src/backend/api/db_utils.py index d9255d193dd..ecb2f6c455c 100644 --- a/api/src/backend/api/db_utils.py +++ b/api/src/backend/api/db_utils.py @@ -1,4 +1,5 @@ import secrets +import uuid from contextlib import contextmanager from datetime import datetime, timedelta, timezone @@ -8,6 +9,7 @@ from django.db import connection, models, transaction from psycopg2 import connect as psycopg2_connect from psycopg2.extensions import AsIs, new_type, register_adapter, register_type +from rest_framework_json_api.serializers import ValidationError DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test" DB_PASSWORD = ( @@ -23,6 +25,8 @@ POSTGRES_TENANT_VAR = "api.tenant_id" POSTGRES_USER_VAR = "api.user_id" +SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);" + @contextmanager def psycopg_connection(database_alias: str): @@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str): @contextmanager -def tenant_transaction(tenant_id: str): +def tenant_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR): + """ + Creates a new database transaction setting the given configuration value. It validates the + if the value is a valid UUID to be used for Postgres RLS. + + Args: + value (str): Database configuration parameter value. + parameter (str): Database configuration parameter name + """ with transaction.atomic(): with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + try: + # just in case the value is an UUID object + uuid.UUID(str(value)) + except ValueError: + raise ValidationError("Must be a valid UUID") + cursor.execute(SET_CONFIG_QUERY, [parameter, value]) yield cursor diff --git a/api/src/backend/api/decorators.py b/api/src/backend/api/decorators.py index f30cb2458ee..a98b38b5ebe 100644 --- a/api/src/backend/api/decorators.py +++ b/api/src/backend/api/decorators.py @@ -1,6 +1,10 @@ +import uuid from functools import wraps from django.db import connection, transaction +from rest_framework_json_api.serializers import ValidationError + +from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY def set_tenant(func): @@ -31,7 +35,7 @@ def some_task(arg1, **kwargs): pass # When calling the task - some_task.delay(arg1, tenant_id="1234-abcd-5678") + some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5") # The tenant context will be set before the task logic executes. """ @@ -43,9 +47,12 @@ def wrapper(*args, **kwargs): tenant_id = kwargs.pop("tenant_id") except KeyError: raise KeyError("This task requires the tenant_id") - + try: + uuid.UUID(tenant_id) + except ValueError: + raise ValidationError("Tenant ID must be a valid UUID") with connection.cursor() as cursor: - cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id]) return func(*args, **kwargs) diff --git a/api/src/backend/api/tests/test_decorators.py b/api/src/backend/api/tests/test_decorators.py index a9a333bb7b6..8ac31b1d499 100644 --- a/api/src/backend/api/tests/test_decorators.py +++ b/api/src/backend/api/tests/test_decorators.py @@ -1,7 +1,9 @@ -from unittest.mock import patch, call +import uuid +from unittest.mock import call, patch import pytest +from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY from api.decorators import set_tenant @@ -15,12 +17,12 @@ def test_set_tenant(self, mock_cursor): def random_func(arg): return arg - tenant_id = "1234-abcd-5678" + tenant_id = str(uuid.uuid4()) result = random_func("test_arg", tenant_id=tenant_id) assert ( - call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);") + call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id]) in mock_cursor.execute.mock_calls ) assert result == "test_arg" diff --git a/api/src/backend/tasks/tests/test_scan.py b/api/src/backend/tasks/tests/test_scan.py index da79f785553..ba5802c1e9f 100644 --- a/api/src/backend/tasks/tests/test_scan.py +++ b/api/src/backend/tasks/tests/test_scan.py @@ -1,3 +1,4 @@ +import uuid from unittest.mock import MagicMock, patch import pytest @@ -212,7 +213,7 @@ def test_store_resources_new_resource( mock_get_or_create_resource, mock_get_or_create_tag, ): - tenant_id = "tenant123" + tenant_id = uuid.uuid4() provider_instance = MagicMock() provider_instance.id = "provider456" @@ -260,7 +261,7 @@ def test_store_resources_existing_resource( mock_get_or_create_resource, mock_get_or_create_tag, ): - tenant_id = "tenant123" + tenant_id = uuid.uuid4() provider_instance = MagicMock() provider_instance.id = "provider456" @@ -317,7 +318,7 @@ def test_store_resources_with_tags( mock_get_or_create_resource, mock_get_or_create_tag, ): - tenant_id = "tenant123" + tenant_id = uuid.uuid4() provider_instance = MagicMock() provider_instance.id = "provider456"