Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom validators, permission classes, and PermissionTestModel fo… #9567

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,48 @@ def to_internal_value(self, data):
return super().to_internal_value(data)


class AlphabeticFieldValidator:
"""
Custom validator to ensure that a field only contains alphabetic characters and spaces.
"""
def __call__(self, value):
if not isinstance(value, str):
raise ValueError("This field must be a string.")
if value == "":
raise ValueError("This field must contain only alphabetic characters and spaces.")
if not re.match(r'^[A-Za-z ]*$', value):
raise ValueError("This field must contain only alphabetic characters and spaces.")


class AlphanumericFieldValidator:
"""
Custom validator to ensure the field contains only alphanumeric characters (letters and numbers).
"""
def __call__(self, value):
if not isinstance(value, str):
raise ValueError("This field must be a string.")
if value == "":
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")
if not re.match(r'^[A-Za-z0-9]*$', value):
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")


class CustomLengthValidator:
"""
Custom validator to ensure the length of a string is within specified limits.
"""
def __init__(self, min_length=0, max_length=None):
self.min_length = min_length
self.max_length = max_length

def __call__(self, value):
if len(value) < self.min_length:
raise ValueError(f"This field must be at least {self.min_length} characters long.")

if self.max_length is not None and len(value) > self.max_length:
raise ValueError(f"This field must be no more than {self.max_length} characters long.")


# Number types...

class IntegerField(Field):
Expand Down
26 changes: 26 additions & 0 deletions rest_framework/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,32 @@ def has_permission(self, request, view):
)


class IsAdminUserOrReadOnly(BasePermission):
"""
Custom permission to only allow admin users to edit an object.
"""

def has_permission(self, request, view):
# Allow any user to view the object
if request.method in ['GET', 'HEAD', 'OPTIONS']:
return True
# Only allow admin users to modify the object
return request.user and request.user.is_staff


class IsOwner(BasePermission):
"""
Custom permission to only allow owners of an object to edit it.
"""

def has_object_permission(self, request, view, obj):
# Allow read-only access to any request
if request.method in ['GET', 'HEAD', 'OPTIONS']:
return True
# Write permissions are only allowed to the owner of the object
return obj.owner == request.user


class DjangoModelPermissions(BasePermission):
"""
The request is authenticated using `django.contrib.auth` permissions.
Expand Down
8 changes: 8 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,11 @@ def __new__(cls, *args, **kwargs):
help_text='OneToOneTarget',
verbose_name='OneToOneTarget',
on_delete=models.CASCADE)


class OwnershipTestModel(models.Model):
owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name='ownership_test_models')
title = models.CharField(max_length=100)

def __str__(self):
return self.title
112 changes: 110 additions & 2 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
import rest_framework
from rest_framework import exceptions, serializers
from rest_framework.fields import (
BuiltinSignatureError, DjangoImageField, SkipField, empty,
is_simple_callable
AlphabeticFieldValidator, AlphanumericFieldValidator,
BuiltinSignatureError, CustomLengthValidator, DjangoImageField, SkipField,
empty, is_simple_callable
)
from tests.models import UUIDForeignKeyTarget

Expand Down Expand Up @@ -1061,6 +1062,113 @@ class TestFilePathField(FieldValues):
)


class TestAlphabeticField:
valid_inputs = {
'John Doe': 'John Doe',
'Alice': 'Alice',
'Bob Marley': 'Bob Marley',
}
invalid_inputs = {
'John123': ['This field must contain only alphabetic characters and spaces.'],
'Alice!': ['This field must contain only alphabetic characters and spaces.'],
'': ['This field must contain only alphabetic characters and spaces.'],
}
non_string_inputs = [
123, # Integer
45.67, # Float
None, # NoneType
[], # Empty list
{}, # Empty dict
set() # Empty set
]

def test_valid_inputs(self):
validator = AlphabeticFieldValidator()
for value in self.valid_inputs.keys():
validator(value)

def test_invalid_inputs(self):
validator = AlphabeticFieldValidator()
for value, expected_errors in self.invalid_inputs.items():
with pytest.raises(ValueError) as excinfo:
validator(value)
assert str(excinfo.value) == expected_errors[0]

def test_non_string_inputs(self):
validator = AlphabeticFieldValidator()
for value in self.non_string_inputs:
with pytest.raises(ValueError) as excinfo:
validator(value)
assert str(excinfo.value) == "This field must be a string."


class TestAlphanumericField:
valid_inputs = {
'John123': 'John123',
'Alice007': 'Alice007',
'Bob1990': 'Bob1990',
}
invalid_inputs = {
'John!': ['This field must contain only alphanumeric characters (letters and numbers).'],
'Alice 007': ['This field must contain only alphanumeric characters (letters and numbers).'],
'': ['This field must contain only alphanumeric characters (letters and numbers).'],
}
non_string_inputs = [
123, # Integer
45.67, # Float
None, # NoneType
[], # Empty list
{}, # Empty dict
set() # Empty set
]

def test_valid_inputs(self):
validator = AlphanumericFieldValidator()
for value in self.valid_inputs.keys():
validator(value)

def test_invalid_inputs(self):
validator = AlphanumericFieldValidator()
for value, expected_errors in self.invalid_inputs.items():
with pytest.raises(ValueError) as excinfo:
validator(value)
assert str(excinfo.value) == expected_errors[0]

def test_non_string_inputs(self):
validator = AlphanumericFieldValidator()
for value in self.non_string_inputs:
with pytest.raises(ValueError) as excinfo:
validator(value)
assert str(excinfo.value) == "This field must be a string."


class TestCustomLengthField:
"""
Valid and invalid values for `CustomLengthValidator`.
"""
valid_inputs = {
'abc': 'abc', # 3 characters
'abcdefghij': 'abcdefghij', # 10 characters
}
invalid_inputs = {
'ab': ['This field must be at least 3 characters long.'], # Too short
'abcdefghijk': ['This field must be no more than 10 characters long.'], # Too long
}
field = str

def test_valid_inputs(self):
validator = CustomLengthValidator(min_length=3, max_length=10)
for value in self.valid_inputs.keys():
validator(value)

def test_invalid_inputs(self):
validator = CustomLengthValidator(min_length=3, max_length=10)
for value, expected_errors in self.invalid_inputs.items():
with pytest.raises(ValueError) as excinfo:
validator(value)
assert str(excinfo.value) == expected_errors[0]


# Number types...

class TestIntegerField(FieldValues):
Expand Down
50 changes: 49 additions & 1 deletion tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory
from tests.models import BasicModel
from tests.models import BasicModel, OwnershipTestModel

factory = APIRequestFactory()

Expand Down Expand Up @@ -772,3 +772,51 @@ def test_filtering_permissions(self):
]

assert filtered_permissions == expected_permissions


class PermissionTests(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
self.admin_user = User.objects.create_user(username='admin', password='password', is_staff=True)
self.regular_user = User.objects.create_user(username='user', password='password')
self.anonymous_user = AnonymousUser()

def test_is_admin_user_or_read_only_allow_read(self):
request = self.factory.get('/1', format='json')
request.user = self.anonymous_user
permission = permissions.IsAdminUserOrReadOnly()
self.assertTrue(permission.has_permission(request, None))

request.user = self.admin_user
self.assertTrue(permission.has_permission(request, None))

def test_is_admin_user_or_read_only_allow_write(self):
request = self.factory.post('/1', format='json')
request.user = self.admin_user
permission = permissions.IsAdminUserOrReadOnly()
self.assertTrue(permission.has_permission(request, None))

request.user = self.regular_user
self.assertFalse(permission.has_permission(request, None))

def test_is_owner_permission(self):
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')

request = self.factory.post('/1', format='json')
request.user = self.admin_user
permission = permissions.IsOwner()
self.assertTrue(permission.has_object_permission(request, None, obj))

request.user = self.regular_user
self.assertFalse(permission.has_object_permission(request, None, obj))

def test_is_owner_read_access(self):
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')

request = self.factory.get('/1', format='json')
request.user = self.regular_user
permission = permissions.IsOwner()
self.assertTrue(permission.has_object_permission(request, None, obj))

request.user = self.admin_user
self.assertTrue(permission.has_object_permission(request, None, obj))
Loading