Skip to content

Commit

Permalink
[change] Disable API operations on deactivated devices
Browse files Browse the repository at this point in the history
  • Loading branch information
pandafy committed Aug 9, 2024
1 parent 56e801c commit 5d7f122
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 35 deletions.
7 changes: 4 additions & 3 deletions openwisp_controller/config/base/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ def save(self, *args, **kwargs):
if not state_adding:
self._check_changed_fields()

def delete(self, using=None, keep_parents=False):
if not self.is_deactivated() or (
self._has_config() and not self.config.is_deactivated()
def delete(self, using=None, keep_parents=False, check_deactivated=True):
if check_deactivated and (
not self.is_deactivated()
or (self._has_config() and not self.config.is_deactivated())
):
raise PermissionDenied('The device should be deactivated before deleting')
return super().delete(using, keep_parents)
Expand Down
9 changes: 9 additions & 0 deletions openwisp_controller/config/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,15 @@ def test_download_device_config(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get('content-type'), 'application/octet-stream')

def test_download_deactivated_device_config(self):
device = self._create_device(name='download')
self._create_config(device=device)
device.deactivate()
path = reverse(f'admin:{self.app_label}_device_download', args=[device.pk.hex])
response = self.client.get(path)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get('content-type'), 'application/octet-stream')

def test_download_device_config_404(self):
d = self._create_device(name='download')
path = reverse(f'admin:{self.app_label}_device_download', args=[d.pk])
Expand Down
2 changes: 1 addition & 1 deletion openwisp_controller/config/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def test_remove_duplicate_files(self):
else:
self.assertIn('# path: /etc/vpnserver1', result)

config.device.delete()
config.device.delete(check_deactivated=False)
config.delete()
with self.subTest('Test template applied after creating config object'):
config = self._create_config(organization=org)
Expand Down
2 changes: 1 addition & 1 deletion openwisp_controller/config/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_device_get_object_cached(self):
self.assertEqual(obj.os, 'test_cache')

with self.subTest('test cache invalidation on device delete'):
d.delete()
d.delete(check_deactivated=False)
with self.assertNumQueries(1):
with self.assertRaises(Http404):
view.get_device()
Expand Down
2 changes: 1 addition & 1 deletion openwisp_controller/config/tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_default_template(self):
org = self._get_org()
c = self._create_config(organization=org)
self.assertEqual(c.templates.count(), 0)
c.device.delete()
c.device.delete(check_deactivated=False)
# create default templates for different backends
t1 = self._create_template(
name='default-openwrt', backend='netjsonconfig.OpenWrt', default=True
Expand Down
8 changes: 4 additions & 4 deletions openwisp_controller/config/tests/test_vpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def test_ip_deleted_when_vpnclient_deleted(self):
def test_ip_deleted_when_device_deleted(self):
device, vpn, template = self._create_wireguard_vpn_template()
self.assertEqual(device.config.vpnclient_set.count(), 1)
device.delete()
device.delete(check_deactivated=False)
self.assertEqual(IpAddress.objects.count(), 1)

def test_delete_vpnclient_ip(self):
Expand Down Expand Up @@ -745,7 +745,7 @@ def test_auto_peer_configuration(self):
self.assertEqual(len(vpn_config.get('peers', [])), 2)

with self.subTest('cache updated when a new peer is deleted'):
device2.delete()
device2.delete(check_deactivated=False)
# cache is invalidated but not updated
# hence we expect queries to be generated
with self.assertNumQueries(1):
Expand Down Expand Up @@ -954,7 +954,7 @@ def test_auto_peer_configuration(self):
self.assertEqual(len(peers), 2)

with self.subTest('cache updated when a new peer is deleted'):
device2.delete()
device2.delete(check_deactivated=False)
# cache is invalidated but not updated
# hence we expect queries to be generated
with self.assertNumQueries(2):
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def test_ip_deleted_when_device_deleted(self, mock_requests, mock_subprocess):
self.assertEqual(mock_subprocess.run.call_count, 1)
self.assertEqual(device.config.vpnclient_set.count(), 1)
self.assertEqual(IpAddress.objects.count(), 2)
device.delete()
device.delete(check_deactivated=False)
self.assertEqual(IpAddress.objects.count(), 1)

@mock.patch(_ZT_GENERATE_IDENTITY_SUBPROCESS)
Expand Down
20 changes: 15 additions & 5 deletions openwisp_controller/connection/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from openwisp_users.api.mixins import FilterByParentManaged
from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin

from ...mixins import ProtectedAPIMixin
from ...mixins import (
ProtectedAPIMixin,
RelatedDeviceModelPermission,
RelatedDeviceProtectedAPIMixin,
)
from .serializer import (
CommandSerializer,
CredentialSerializer,
Expand All @@ -32,11 +36,17 @@ class ListViewPagination(pagination.PageNumberPagination):
max_page_size = 100


class BaseCommandView(FilterByParentManaged, BaseProtectedAPIMixin):
class BaseCommandView(
BaseProtectedAPIMixin,
FilterByParentManaged,
):
model = Command
queryset = Command.objects.prefetch_related('device')
serializer_class = CommandSerializer

def get_permissions(self):
return super().get_permissions() + [RelatedDeviceModelPermission()]

def get_parent_queryset(self):
return Device.objects.filter(
pk=self.kwargs['id'],
Expand Down Expand Up @@ -82,7 +92,7 @@ class CredentialDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView):
serializer_class = CredentialSerializer


class BaseDeviceConection(ProtectedAPIMixin, GenericAPIView):
class BaseDeviceConnection(RelatedDeviceProtectedAPIMixin, GenericAPIView):
model = DeviceConnection
serializer_class = DeviceConnectionSerializer

Expand All @@ -109,7 +119,7 @@ def get_parent_queryset(self):
return Device.objects.filter(pk=self.kwargs['pk'])


class DeviceConnenctionListCreateView(BaseDeviceConection, ListCreateAPIView):
class DeviceConnenctionListCreateView(BaseDeviceConnection, ListCreateAPIView):
pagination_class = ListViewPagination

def get_queryset(self):
Expand All @@ -121,7 +131,7 @@ def get_queryset(self):
)


class DeviceConnectionDetailView(BaseDeviceConection, RetrieveUpdateDestroyAPIView):
class DeviceConnectionDetailView(BaseDeviceConnection, RetrieveUpdateDestroyAPIView):
def get_object(self):
queryset = self.filter_queryset(self.get_queryset())
filter_kwargs = {
Expand Down
101 changes: 96 additions & 5 deletions openwisp_controller/connection/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,35 @@ def test_endpoints_for_non_existent_device(self):
self.assertEqual(response.status_code, 404)
self.assertDictEqual(response.data, device_not_found)

def test_endpoints_for_deactivated_device(self):
self.device_conn.device.deactivate()

with self.subTest('Test listing commands'):
url = self._get_path('device_command_list', self.device_id)
response = self.client.get(
url,
)
self.assertEqual(response.status_code, 200)

with self.subTest('Test creating commands'):
url = self._get_path('device_command_list', self.device_id)
payload = {
'type': 'custom',
'input': {'command': 'echo test'},
}
response = self.client.post(
url, data=payload, content_type='application/json'
)
self.assertEqual(response.status_code, 403)

with self.subTest('Test retrieving commands'):
command = self._create_command(device_conn=self.device_conn)
url = self._get_path('device_command_details', self.device_id, command.id)
response = self.client.get(
url,
)
self.assertEqual(response.status_code, 200)

def test_non_superuser(self):
list_url = self._get_path('device_command_list', self.device_id)
command = self._create_command(device_conn=self.device_conn)
Expand Down Expand Up @@ -424,7 +453,7 @@ def test_post_deviceconnection_list(self):
'enabled': True,
'failure_reason': '',
}
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.post(path, data, content_type='application/json')
self.assertEqual(response.status_code, 201)

Expand All @@ -437,7 +466,7 @@ def test_post_deviceconenction_with_no_config_device(self):
'enabled': True,
'failure_reason': '',
}
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.post(path, data, content_type='application/json')
error_msg = '''
the update strategy can be determined automatically only if
Expand Down Expand Up @@ -469,7 +498,7 @@ def test_put_devceconnection_detail(self):
'enabled': False,
'failure_reason': '',
}
with self.assertNumQueries(14):
with self.assertNumQueries(16):
response = self.client.put(path, data, content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(
Expand All @@ -483,7 +512,7 @@ def test_patch_deviceconnectoin_detail(self):
path = reverse('connection_api:deviceconnection_detail', args=(d1, dc.pk))
self.assertEqual(dc.update_strategy, app_settings.UPDATE_STRATEGIES[0][0])
data = {'update_strategy': app_settings.UPDATE_STRATEGIES[1][0]}
with self.assertNumQueries(13):
with self.assertNumQueries(15):
response = self.client.patch(path, data, content_type='application/json')
self.assertEqual(response.status_code, 200)
self.assertEqual(
Expand All @@ -494,7 +523,7 @@ def test_delete_deviceconnection_detail(self):
dc = self._create_device_connection()
d1 = dc.device.id
path = reverse('connection_api:deviceconnection_detail', args=(d1, dc.pk))
with self.assertNumQueries(9):
with self.assertNumQueries(11):
response = self.client.delete(path)
self.assertEqual(response.status_code, 204)

Expand Down Expand Up @@ -535,3 +564,65 @@ def test_bearer_authentication(self):
HTTP_AUTHORIZATION=f'Bearer {token}',
)
self.assertEqual(response.status_code, 200)

def test_deactivated_device(self):
credentials = self._create_credentials(auto_add=True)
device = self._create_config(organization=credentials.organization).device
device_conn = device.deviceconnection_set.first()
create_api_path = reverse(
'connection_api:deviceconnection_list', args=(device.pk,)
)
detail_api_path = reverse(
'connection_api:deviceconnection_detail',
args=[device.id, device_conn.id],
)
device.deactivate()

with self.subTest('Test creating DeviceConnection'):
response = self.client.post(
create_api_path,
data={
'credentials': credentials.pk,
'update_strategy': app_settings.UPDATE_STRATEGIES[0][0],
'enabled': True,
'failure_reason': '',
},
content_type='application/json',
)
self.assertEqual(response.status_code, 403)

with self.subTest('Test listing DeviceConnection'):
response = self.client.get(
create_api_path,
)
self.assertEqual(response.status_code, 200)

with self.subTest('Test retrieving DeviceConnection detail'):
response = self.client.get(
detail_api_path,
)
self.assertEqual(response.status_code, 200)

with self.subTest('Test updating DeviceConnection'):
response = self.client.put(
detail_api_path,
{
'credentials': credentials.pk,
'update_strategy': app_settings.UPDATE_STRATEGIES[1][0],
'enabled': False,
'failure_reason': '',
},
content_type='application/json',
)
self.assertEqual(response.status_code, 403)

response = self.client.patch(
detail_api_path, {'enabled': False}, content_type='application/json'
)
self.assertEqual(response.status_code, 403)

with self.subTest('Test deleting DeviceConnection'):
response = self.client.delete(
detail_api_path,
)
self.assertEqual(response.status_code, 403)
16 changes: 13 additions & 3 deletions openwisp_controller/geo/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.http import Http404
from django_filters import rest_framework as filters
from rest_framework import generics, pagination, status
from rest_framework.exceptions import NotFound
from rest_framework.exceptions import NotFound, PermissionDenied
from rest_framework.permissions import BasePermission
from rest_framework.request import clone_request
from rest_framework.response import Response
Expand All @@ -14,7 +14,10 @@
from openwisp_users.api.filters import OrganizationManagedFilter
from openwisp_users.api.mixins import FilterByOrganizationManaged, FilterByParentManaged

from ...mixins import ProtectedAPIMixin
from ...mixins import (
ProtectedAPIMixin,
RelatedDeviceProtectedAPIMixin,
)
from .filters import DeviceListFilter
from .serializers import (
DeviceCoordinatesSerializer,
Expand Down Expand Up @@ -77,6 +80,8 @@ def get_location(self, device):

def get_object(self, *args, **kwargs):
device = super().get_object()
if self.request.method not in ('GET', 'HEAD') and device.is_deactivated():
raise PermissionDenied
location = self.get_location(device)
if location:
return location
Expand All @@ -102,7 +107,7 @@ def create_location(self, device):


class DeviceLocationView(
ProtectedAPIMixin,
RelatedDeviceProtectedAPIMixin,
generics.RetrieveUpdateDestroyAPIView,
):
serializer_class = DeviceLocationSerializer
Expand All @@ -120,6 +125,11 @@ def get_queryset(self):
except ValidationError:
return qs.none()

def get_parent_queryset(self):
return Device.objects.filter(
pk=self.kwargs['pk'],
)

def get_serializer_context(self):
context = super().get_serializer_context()
context.update({'device_id': self.kwargs['pk']})
Expand Down
Loading

0 comments on commit 5d7f122

Please sign in to comment.