Skip to content

Commit

Permalink
Merge pull request #1 from gianpieropa/new-field-mixin
Browse files Browse the repository at this point in the history
new read only mixin
  • Loading branch information
gianpieropa authored Apr 11, 2022
2 parents 66fba22 + 5d5cb3c commit 9cdc84e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 190 deletions.
5 changes: 2 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ drf_access_policy.egg-info
.vscode
site/
venv.sh
venv/
.idea/
.venv
.coverage

/db.sqlite3
3 changes: 1 addition & 2 deletions rest_access_policy/access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def get_user_group_values(self, user) -> List[str]:
def scope_queryset(cls, request, qs):
return qs.none()

@classmethod
def _get_invoked_action(cls, view) -> str:
def _get_invoked_action(self, view) -> str:
"""
If a CBV, the name of the method. If a regular function view,
the name of the function.
Expand Down
155 changes: 10 additions & 145 deletions rest_access_policy/field_access_mixin.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,13 @@
from typing import List

from rest_framework.request import Request

from rest_access_policy import AccessPolicyException
from .access_policy import AccessPolicy


class FieldAccessMixin(object):
def __init__(self, *args, **kwargs):
self.serializer_context = kwargs.get("context", {})
is_many = kwargs.get('many', False)
super().__init__(*args, **kwargs)
self.should_check_conditions = is_many is False and getattr(self,"instance",None) is not None
self._apply_fields_access()

if (
self.request.method
in [
"POST",
"PUT",
"PATCH",
]
and self.field_permissions.get("read_only")
):
self._set_read_only_fields()

@property
def access_policy(self) -> AccessPolicy:
Expand All @@ -36,6 +21,9 @@ def access_policy(self) -> AccessPolicy:
if not access_policy:
raise Exception("Must set access_policy inside Meta for FieldAccessMixin")

if getattr(access_policy, "scope_fields", None) is None:
raise Exception("Must define scope_fields method on access_policy")

return access_policy

@property
Expand All @@ -47,133 +35,10 @@ def request(self) -> Request:

return request

@property
def action(self) -> str:
view = self.serializer_context.get("view")
if not view:
raise Exception("Must pass context with view to FieldAccessMixin")
action = self.access_policy._get_invoked_action(view)
return action


@property
def field_permissions(self) -> dict:
access_policy = self.access_policy

field_permissions = getattr(access_policy, "field_permissions", {})

if not isinstance(field_permissions, dict):
raise Exception("Field permissions must be set on access_policy for FieldAccessMixin")

return field_permissions

def _get_statements_matching_conditions(self, statements: List[dict]):
"""
Filter statements and only return those that match all of their
custom context conditions; if no conditions are provided then
the statement should be returned.
"""
matched = []

for statement in statements:
conditions = statement["condition"]

if len(conditions) == 0:
matched.append(statement)
continue

fails = 0

for condition in conditions:
passed = self._check_condition(condition)

if not passed:
fails += 1
break

if fails == 0:
matched.append(statement)

return matched

def _get_statements_matching_action(self, statements: List[dict]):
"""
Filter statements and return only those that match the specified
action.
"""
matched = []
http_method = "<method:%s>" % self.request.method.lower()

for statement in statements:
if self.action in statement["action"] or "*" in statement["action"]:
matched.append(statement)
elif http_method in statement["action"]:
matched.append(statement)
return matched

def _check_condition(self, condition):
"""
Evaluate a custom context condition;
"""
result = condition(self.instance)

if type(result) is not bool:
raise AccessPolicyException(
"condition '%s' must return true/false, not %s" % (condition, type(result))
)

return result

def _set_read_only_fields(self):
read_only_statements = self._validate_and_clean_statements(
self.field_permissions["read_only"]
)

matched = self.access_policy._get_statements_matching_principal(
request=self.request, statements=read_only_statements
)
matched = self._get_statements_matching_action(statements=matched)
if self.should_check_conditions:
matched = self._get_statements_matching_conditions(statements=matched)
for statement in matched:
if "*" in statement["fields"]:
for field in self.fields.values():
field.read_only = True
break
else:
for field in statement["fields"]:
if self.fields.get(field, None) is not None:
self.fields[field].read_only = True

def _validate_and_clean_statements(self, statements: List[dict]) -> List[dict]:
for statement in statements:
if not isinstance(statement, dict):
raise Exception("Must pass a dict as statement")

if len(statement) == 0:
raise Exception("Cannot pass empty dict as statement")

if statement.get("principal", None) is None:
raise Exception("Must pass principal in statement")

if statement.get("fields", None) is None:
raise Exception("Must pass fields in statement")

if isinstance(statement["principal"], str):
statement["principal"] = [statement["principal"]]

if isinstance(statement["fields"], str):
statement["fields"] = [statement["fields"]]

if "action" not in statement:
statement["action"] = ["*"]

if isinstance(statement["action"], str):
statement["action"] = [statement["action"]]

if "condition" not in statement:
statement["condition"] = []
elif callable(statement["condition"]):
statement["condition"] = [statement["condition"]]
def _apply_fields_access(self):
fields = self.access_policy.scope_fields(self.request, self.fields, instance=self.instance)
if fields is None:
raise Exception("scope_fields method must return fields variable")
self.fields = self.access_policy.scope_fields(self.request, self.fields, instance=self.instance)

return statements

11 changes: 5 additions & 6 deletions test_project/testapp/access_policies.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from rest_access_policy import AccessPolicy


def is_account_mario(instance):
return instance.username == "mario"

class UserAccountAccessPolicy(AccessPolicy):
statements = [
{"principal": "group:admin", "action": ["create", "update"], "effect": "allow"},
Expand All @@ -20,9 +17,11 @@ class UserAccountAccessPolicy(AccessPolicy):
},
]

field_permissions = {"read_only": [{"principal": "group:dev", "fields": "status"},
{"principal": "*", "fields": "last_name","condition":[is_account_mario]}]}

@classmethod
def scope_fields(cls, request, fields: dict, instance=None):
if request.user.groups.filter(name="dev"):
fields["status"].read_only = True
return fields


class LogsAccessPolicy(AccessPolicy):
Expand Down
34 changes: 0 additions & 34 deletions test_project/testapp/tests/test_view_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,37 +83,3 @@ def test_partial_update_should_not_update_status_for_dev_group(self):
)
self.assertEqual(response.data["last_name"], "Mercury")
self.assertEqual(response.data["status"], "active")

def test_partial_update_should_not_update_status_for_account_mario(self):
account = UserAccount.objects.create(
username="mario", first_name="Mario", last_name="Rogers"
)

dev = Group.objects.create(name="dev")
user = User.objects.create()
user.groups.add(dev)
self.client.force_authenticate(user=user)

url = reverse("account-detail", args=[account.id])

response = self.client.patch(
url, data={"last_name": "Mercury"}, format="json"
)
self.assertEqual(response.data["last_name"], "Rogers")

def test_partial_update_should_not_update_status_for_account_pino(self):
account = UserAccount.objects.create(
username="pino", first_name="Mario", last_name="Rogers"
)

dev = Group.objects.create(name="dev")
user = User.objects.create()
user.groups.add(dev)
self.client.force_authenticate(user=user)

url = reverse("account-detail", args=[account.id])

response = self.client.patch(
url, data={"last_name": "Mercury"}, format="json"
)
self.assertEqual(response.data["last_name"], "Mercury")

0 comments on commit 9cdc84e

Please sign in to comment.