Skip to content

Commit

Permalink
ref(rbac): improve rbac implementation for views (#6226)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdriiiPRodri authored Dec 17, 2024
1 parent ec9455f commit fa400de
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 149 deletions.
15 changes: 14 additions & 1 deletion api/src/backend/api/base_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from api.db_utils import POSTGRES_USER_VAR, rls_transaction
from api.filters import CustomDjangoFilterBackend
from api.models import Role, Tenant
from api.rbac.permissions import HasPermissions


class BaseViewSet(ModelViewSet):
authentication_classes = [JWTAuthentication]
permission_classes = [permissions.IsAuthenticated]
required_permissions = []
permission_classes = [permissions.IsAuthenticated, HasPermissions]
filter_backends = [
filters.QueryParameterValidationFilter,
filters.OrderingFilter,
Expand All @@ -29,6 +31,17 @@ class BaseViewSet(ModelViewSet):
ordering_fields = "__all__"
ordering = ["id"]

def initial(self, request, *args, **kwargs):
"""
Sets required_permissions before permissions are checked.
"""
self.set_required_permissions()
super().initial(request, *args, **kwargs)

def set_required_permissions(self):
"""This is an abstract method that must be implemented by subclasses."""
NotImplemented

def get_queryset(self):
raise NotImplementedError

Expand Down
34 changes: 33 additions & 1 deletion api/src/backend/api/rbac/permissions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from enum import Enum
from rest_framework.permissions import BasePermission
from api.models import User
from api.models import Provider, Role, User
from api.db_router import MainRouter
from typing import Optional
from django.db.models import QuerySet


class Permissions(Enum):
Expand Down Expand Up @@ -36,3 +38,33 @@ def has_permission(self, request, view):
return False

return True


def get_role(user: User) -> Optional[Role]:
"""
Retrieve the first role assigned to the given user.
Returns:
The user's first Role instance if the user has any roles, otherwise None.
"""
return user.roles.first()


def get_providers(role: Role) -> QuerySet[Provider]:
"""
Return a distinct queryset of Providers accessible by the given role.
If the role has no associated provider groups, an empty queryset is returned.
Args:
role: A Role instance.
Returns:
A QuerySet of Provider objects filtered by the role's provider groups.
If the role has no provider groups, returns an empty queryset.
"""
provider_groups = role.provider_groups.all()
if not provider_groups.exists():
return Provider.objects.none()

return Provider.objects.filter(provider_groups__in=provider_groups).distinct()
9 changes: 4 additions & 5 deletions api/src/backend/api/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ def test_tenants_list_memberships_as_owner(
# Test user + 2 extra users for tenant 2
assert len(response.json()["data"]) == 3

@patch("api.v1.views.TenantMembersViewSet.required_permissions", [])
def test_tenants_list_memberships_as_member(
self, authenticated_client, tenants_fixture, extra_users
):
Expand Down Expand Up @@ -3274,9 +3275,7 @@ def test_role_list_filters(self, authenticated_client, roles_fixture):
assert len(data) == 1
assert data[0]["attributes"]["name"] == role.name

def test_role_list_sorting(
self, authenticated_client, set_user_admin_roles_fixture, roles_fixture
):
def test_role_list_sorting(self, authenticated_client, roles_fixture):
response = authenticated_client.get(reverse("role-list"), {"sort": "name"})
assert response.status_code == status.HTTP_200_OK
data = response.json()["data"]
Expand Down Expand Up @@ -3342,7 +3341,7 @@ def test_partial_update_relationship(
):
data = {
"data": [
{"type": "role", "id": str(roles_fixture[1].id)},
{"type": "role", "id": str(roles_fixture[2].id)},
]
}
response = authenticated_client.patch(
Expand All @@ -3353,7 +3352,7 @@ def test_partial_update_relationship(
assert response.status_code == status.HTTP_204_NO_CONTENT
relationships = UserRoleRelationship.objects.filter(user=create_test_user.id)
assert relationships.count() == 1
assert {rel.role.id for rel in relationships} == {roles_fixture[1].id}
assert {rel.role.id for rel in relationships} == {roles_fixture[2].id}

data = {
"data": [
Expand Down
Loading

0 comments on commit fa400de

Please sign in to comment.