Skip to content

Commit

Permalink
fix(RLS): enforce config security (#6066)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfagoagas authored Dec 13, 2024
1 parent 32f69d2 commit da4f9b8
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 38 deletions.
33 changes: 6 additions & 27 deletions api/src/backend/api/base_views.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import uuid

from django.db import connection, transaction
from django.db import transaction
from rest_framework import permissions
from rest_framework.exceptions import NotAuthenticated
from rest_framework.filters import SearchFilter
from rest_framework_json_api import filters
from rest_framework_json_api.serializers import ValidationError
from rest_framework_json_api.views import ModelViewSet
from rest_framework_simplejwt.authentication import JWTAuthentication

from api.db_utils import POSTGRES_USER_VAR, tenant_transaction
from api.filters import CustomDjangoFilterBackend


Expand Down Expand Up @@ -47,13 +45,7 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand All @@ -75,8 +67,7 @@ def initial(self, request, *args, **kwargs):
):
user_id = str(request.user.id)

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.user_id', '{user_id}', TRUE);")
with tenant_transaction(value=user_id, parameter=POSTGRES_USER_VAR):
return super().initial(request, *args, **kwargs)

# TODO: DRY this when we have time
Expand All @@ -87,13 +78,7 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)

Expand All @@ -114,12 +99,6 @@ def initial(self, request, *args, **kwargs):
if tenant_id is None:
raise NotAuthenticated("Tenant ID is not present in token")

try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")

with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
with tenant_transaction(tenant_id):
self.request.tenant_id = tenant_id
return super().initial(request, *args, **kwargs)
21 changes: 19 additions & 2 deletions api/src/backend/api/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import secrets
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone

Expand All @@ -8,6 +9,7 @@
from django.db import connection, models, transaction
from psycopg2 import connect as psycopg2_connect
from psycopg2.extensions import AsIs, new_type, register_adapter, register_type
from rest_framework_json_api.serializers import ValidationError

DB_USER = settings.DATABASES["default"]["USER"] if not settings.TESTING else "test"
DB_PASSWORD = (
Expand All @@ -23,6 +25,8 @@
POSTGRES_TENANT_VAR = "api.tenant_id"
POSTGRES_USER_VAR = "api.user_id"

SET_CONFIG_QUERY = "SELECT set_config(%s, %s::text, TRUE);"


@contextmanager
def psycopg_connection(database_alias: str):
Expand All @@ -44,10 +48,23 @@ def psycopg_connection(database_alias: str):


@contextmanager
def tenant_transaction(tenant_id: str):
def tenant_transaction(value: str, parameter: str = POSTGRES_TENANT_VAR):
"""
Creates a new database transaction setting the given configuration value. It validates the
if the value is a valid UUID to be used for Postgres RLS.
Args:
value (str): Database configuration parameter value.
parameter (str): Database configuration parameter name
"""
with transaction.atomic():
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
try:
# just in case the value is an UUID object
uuid.UUID(str(value))
except ValueError:
raise ValidationError("Must be a valid UUID")
cursor.execute(SET_CONFIG_QUERY, [parameter, value])
yield cursor


Expand Down
13 changes: 10 additions & 3 deletions api/src/backend/api/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import uuid
from functools import wraps

from django.db import connection, transaction
from rest_framework_json_api.serializers import ValidationError

from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY


def set_tenant(func):
Expand Down Expand Up @@ -31,7 +35,7 @@ def some_task(arg1, **kwargs):
pass
# When calling the task
some_task.delay(arg1, tenant_id="1234-abcd-5678")
some_task.delay(arg1, tenant_id="8db7ca86-03cc-4d42-99f6-5e480baf6ab5")
# The tenant context will be set before the task logic executes.
"""
Expand All @@ -43,9 +47,12 @@ def wrapper(*args, **kwargs):
tenant_id = kwargs.pop("tenant_id")
except KeyError:
raise KeyError("This task requires the tenant_id")

try:
uuid.UUID(tenant_id)
except ValueError:
raise ValidationError("Tenant ID must be a valid UUID")
with connection.cursor() as cursor:
cursor.execute(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
cursor.execute(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])

return func(*args, **kwargs)

Expand Down
8 changes: 5 additions & 3 deletions api/src/backend/api/tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import patch, call
import uuid
from unittest.mock import call, patch

import pytest

from api.db_utils import POSTGRES_TENANT_VAR, SET_CONFIG_QUERY
from api.decorators import set_tenant


Expand All @@ -15,12 +17,12 @@ def test_set_tenant(self, mock_cursor):
def random_func(arg):
return arg

tenant_id = "1234-abcd-5678"
tenant_id = str(uuid.uuid4())

result = random_func("test_arg", tenant_id=tenant_id)

assert (
call(f"SELECT set_config('api.tenant_id', '{tenant_id}', TRUE);")
call(SET_CONFIG_QUERY, [POSTGRES_TENANT_VAR, tenant_id])
in mock_cursor.execute.mock_calls
)
assert result == "test_arg"
Expand Down
7 changes: 4 additions & 3 deletions api/src/backend/tasks/tests/test_scan.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -212,7 +213,7 @@ def test_store_resources_new_resource(
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"

Expand Down Expand Up @@ -260,7 +261,7 @@ def test_store_resources_existing_resource(
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"

Expand Down Expand Up @@ -317,7 +318,7 @@ def test_store_resources_with_tags(
mock_get_or_create_resource,
mock_get_or_create_tag,
):
tenant_id = "tenant123"
tenant_id = uuid.uuid4()
provider_instance = MagicMock()
provider_instance.id = "provider456"

Expand Down

0 comments on commit da4f9b8

Please sign in to comment.