Stop a file from being deleted from S3 if the file is also pointed to by another Document object

The only case where this doesn't apply is when a file is found to be unsafe; in that case we still remove the file.
kevincarrogan committed Sep 30, 2024
1 parent 96d848a commit 7496a2c
Showing 5 changed files with 218 additions and 57 deletions.
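
To orient the reader, here is a rough sketch of the behaviour this change aims for; it is not part of the commit, it assumes the Document model changes shown below, and the keys used are made up for illustration:

# Illustrative only -- mirrors the new delete_s3() behaviour added to api/documents/models.py below.
doc_a = DocumentFactory(s3_key="shared-key")
doc_b = DocumentFactory(s3_key="shared-key")

doc_a.delete_s3()                    # file kept: doc_b still points at "shared-key"
doc_a.delete_s3(force_delete=True)   # file removed regardless of other references

The unsafe-file exception from the description is handled by scan_for_viruses(), which calls delete_s3(force_delete=True) when a scan flags the file.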
82 changes: 35 additions & 47 deletions api/documents/libraries/tests/test_s3_operations.py
@@ -1,7 +1,9 @@
import logging

-from contextlib import contextmanager
-from unittest.mock import Mock, patch
+from unittest.mock import (
+    Mock,
+    patch,
+)

from moto import mock_aws

@@ -10,10 +12,11 @@
    ReadTimeoutError,
)

-from django.conf import settings
from django.http import FileResponse
from django.test import override_settings, SimpleTestCase

+from test_helpers.s3 import S3TesterHelper
+
from ..s3_operations import (
    delete_file,
    document_download_stream,
@@ -134,33 +137,19 @@ def test_get_object_boto_core_error(self, mock_client):
        )


-@contextmanager
-def _create_bucket(s3):
-    s3.create_bucket(
-        Bucket=settings.AWS_STORAGE_BUCKET_NAME,
-        CreateBucketConfiguration={
-            "LocationConstraint": settings.AWS_REGION,
-        },
-    )
-    yield
-
-
@mock_aws
class S3OperationsDeleteFileTests(SimpleTestCase):
+    def setUp(self, *args, **kwargs):
+        super().setUp(*args, **kwargs)
+
+        self.s3_test_helper = S3TesterHelper()
+
    def test_delete_file(self):
-        s3 = init_s3_client()
-        with _create_bucket(s3):
-            s3.put_object(
-                Bucket=settings.AWS_STORAGE_BUCKET_NAME,
-                Key="s3-key",
-                Body=b"test",
-            )
+        self.s3_test_helper.add_test_file("s3-key", b"test")

-            delete_file("document-id", "s3-key")
+        delete_file("document-id", "s3-key")

-            objs = s3.list_objects(Bucket=settings.AWS_STORAGE_BUCKET_NAME)
-            keys = [o["Key"] for o in objs.get("Contents", [])]
-            self.assertNotIn("s3-key", keys)
+        self.s3_test_helper.assert_file_not_in_s3("s3-key")

    @patch("api.documents.libraries.s3_operations._client")
    def test_delete_file_read_timeout_error(self, mock_client):
@@ -197,35 +186,34 @@ def test_delete_file_boto_core_error(self, mock_client):

@mock_aws
class S3OperationsUploadBytesFileTests(SimpleTestCase):
+    def setUp(self, *args, **kwargs):
+        super().setUp(*args, **kwargs)
+
+        self.s3_test_helper = S3TesterHelper()
+
    def test_upload_bytes_file(self):
-        s3 = init_s3_client()
-        with _create_bucket(s3):
-            upload_bytes_file(b"test", "s3-key")
+        upload_bytes_file(b"test", "s3-key")

-            obj = s3.get_object(
-                Bucket=settings.AWS_STORAGE_BUCKET_NAME,
-                Key="s3-key",
-            )
-            self.assertEqual(obj["Body"].read(), b"test")
+        self.s3_test_helper.assert_file_in_s3("s3-key")
+        self.s3_test_helper.assert_file_body("s3-key", b"test")


@mock_aws
class S3OperationsDocumentDownloadStreamTests(SimpleTestCase):
+    def setUp(self, *args, **kwargs):
+        super().setUp(*args, **kwargs)
+
+        self.s3_test_helper = S3TesterHelper()
+
    def test_document_download_stream(self):
-        s3 = init_s3_client()
-        with _create_bucket(s3):
-            s3.put_object(
-                Bucket=settings.AWS_STORAGE_BUCKET_NAME,
-                Key="s3-key",
-                Body=b"test",
-            )
-
-            mock_document = Mock()
-            mock_document.id = "document-id"
-            mock_document.s3_key = "s3-key"
-            mock_document.name = "test.doc"
-
-            response = document_download_stream(mock_document)
+        self.s3_test_helper.add_test_file("s3-key", b"test")

+        mock_document = Mock()
+        mock_document.id = "document-id"
+        mock_document.s3_key = "s3-key"
+        mock_document.name = "test.doc"
+
+        response = document_download_stream(mock_document)

        self.assertIsInstance(response, FileResponse)
        self.assertEqual(response.status_code, 200)
43 changes: 33 additions & 10 deletions api/documents/models.py
@@ -2,12 +2,15 @@
import uuid

from django.db import models
-from django.utils.timezone import now
+from django.utils import timezone

from api.common.models import TimestampableModel
from api.documents.libraries import s3_operations, av_operations


+logger = logging.getLogger(__name__)
+
+
class Document(TimestampableModel):
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=1000, null=False, blank=False)
@@ -20,25 +23,45 @@ class Document(TimestampableModel):
    def __str__(self):
        return self.name

-    def delete_s3(self):
+    def get_other_documents_sharing_file(self):
+        return Document.objects.filter(s3_key=self.s3_key).exclude(pk=self.pk)
+
+    def delete_s3(self, *, force_delete=False):
        """Removes the document's file from S3."""
+        file_shared_with_other_documents = self.get_other_documents_sharing_file().exists()
+        if not force_delete and file_shared_with_other_documents:
+            logger.info("Shared file %s was not deleted", self.s3_key)
+            return

        s3_operations.delete_file(self.id, self.s3_key)

+    def set_virus_scan_result(self, document, is_safe, virus_scanned_at):
+        document.safe = is_safe
+        document.virus_scanned_at = virus_scanned_at
+        document.save()
+
    def scan_for_viruses(self):
        """Retrieves the document's file from S3 and scans it for viruses."""

        file = s3_operations.get_object(self.id, self.s3_key)

        if not file:
-            logging.warning(f"Failed to retrieve file '{self.s3_key}' from S3 for document '{self.id}'")
+            logger.warning(
+                "Failed to retrieve file `%s` from S3 for document `%s` for virus scan",
+                self.s3_key,
+                self.id,
+            )

-        self.safe = not av_operations.scan_file_for_viruses(self.id, self.name, file)
-        self.virus_scanned_at = now()
-        self.save()
+        is_safe = not av_operations.scan_file_for_viruses(self.id, self.name, file)
+        virus_scanned_at = timezone.now()
+        self.set_virus_scan_result(self, is_safe, virus_scanned_at)

-        if not self.safe:
-            logging.warning(f"Document '{self.id}' is not safe")
-            self.delete_s3()
+        if not is_safe:
+            logger.warning("Document `%s` is not safe", self.id)
+            self.delete_s3(force_delete=True)
+            file_shared_with_other_documents = self.get_other_documents_sharing_file()
+            for other_document in file_shared_with_other_documents:
+                logger.warning("Other document `%s` is not safe because `%s` is not safe", other_document.id, self.id)
+                self.set_virus_scan_result(other_document, is_safe, virus_scanned_at)

-        return self.safe
+        return is_safe
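
A hedged walk-through of the unsafe-shared-file path above (illustrative only; it restates what the model code does rather than adding behaviour, and the keys are invented):

# Sketch: one of two documents sharing a key is scanned and found unsafe.
doc_a = DocumentFactory(s3_key="shared-key")
doc_b = DocumentFactory(s3_key="shared-key")

is_safe = doc_a.scan_for_viruses()   # suppose av_operations reports the file as unsafe
assert is_safe is False

doc_b.refresh_from_db()
assert doc_b.safe is False                              # propagated via set_virus_scan_result()
assert doc_b.virus_scanned_at == doc_a.virus_scanned_at
# The shared S3 object is still removed, because delete_s3(force_delete=True) is used on this path.

Passing the target document into set_virus_scan_result lets the same helper update both the scanned document and any siblings that share its file.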
2 changes: 2 additions & 0 deletions api/documents/tests/factories.py
@@ -4,6 +4,8 @@


class DocumentFactory(factory.django.DjangoModelFactory):
+    safe = None
+    virus_scanned_at = None
    s3_key = factory.Faker("file_name", category="office")

    class Meta:
101 changes: 101 additions & 0 deletions api/documents/tests/test_document_model.py
@@ -0,0 +1,101 @@
from unittest.mock import patch

from freezegun import freeze_time

from moto import mock_aws

from django.test import TestCase
from django.utils import timezone

from api.documents.tests.factories import DocumentFactory
from test_helpers.s3 import S3TesterHelper


@mock_aws
class DocumentModelTests(TestCase):
    def setUp(self, *args, **kwargs):
        super().setUp(*args, **kwargs)

        self.s3_test_helper = S3TesterHelper()

    def test_delete_s3_removes_object(self):
        self.s3_test_helper.add_test_file("s3-key", b"test")

        document = DocumentFactory(s3_key="s3-key")
        document.delete_s3()

        self.s3_test_helper.assert_file_not_in_s3("s3-key")

    def test_delete_s3_on_shared_file_retains_object(self):
        self.s3_test_helper.add_test_file("s3-key", b"test")

        document = DocumentFactory(s3_key="s3-key")
        another_document = DocumentFactory(s3_key="s3-key")

        document.delete_s3()
        self.s3_test_helper.assert_file_in_s3("s3-key")

        another_document.delete_s3()
        self.s3_test_helper.assert_file_in_s3("s3-key")

    def test_force_delete_s3_file_on_shared_file_retains_object(self):
        self.s3_test_helper.add_test_file("s3-key", b"test")

        document = DocumentFactory(s3_key="s3-key")
        DocumentFactory(s3_key="s3-key")

        document.delete_s3(force_delete=True)
        self.s3_test_helper.assert_file_not_in_s3("s3-key")

    @freeze_time("2020-01-01 12:00:01")
    @patch("api.documents.models.av_operations.scan_file_for_viruses")
    def test_scan_for_viruses_safe_file(self, mock_scan_file_for_viruses):
        self.s3_test_helper.add_test_file("s3-key", b"test")
        mock_scan_file_for_viruses.return_value = False

        document = DocumentFactory(s3_key="s3-key")
        is_safe = document.scan_for_viruses()

        self.assertIs(is_safe, True)
        mock_scan_file_for_viruses.assert_called()
        document.refresh_from_db()
        self.assertIs(document.safe, True)
        self.assertEqual(document.virus_scanned_at, timezone.now())
        self.s3_test_helper.assert_file_in_s3("s3-key")

    @freeze_time("2020-01-01 12:00:01")
    @patch("api.documents.models.av_operations.scan_file_for_viruses")
    def test_scan_for_viruses_unsafe_file(self, mock_scan_file_for_viruses):
        self.s3_test_helper.add_test_file("s3-key", b"test")
        mock_scan_file_for_viruses.return_value = True

        document = DocumentFactory(s3_key="s3-key")
        is_safe = document.scan_for_viruses()

        self.assertIs(is_safe, False)
        mock_scan_file_for_viruses.assert_called()
        document.refresh_from_db()
        self.assertIs(document.safe, False)
        self.assertEqual(document.virus_scanned_at, timezone.now())
        self.s3_test_helper.assert_file_not_in_s3("s3-key")

    @freeze_time("2020-01-01 12:00:01")
    @patch("api.documents.models.av_operations.scan_file_for_viruses")
    def test_scan_for_viruses_unsafe_shared_file(self, mock_scan_file_for_viruses):
        self.s3_test_helper.add_test_file("s3-key", b"test")
        mock_scan_file_for_viruses.return_value = True

        document = DocumentFactory(s3_key="s3-key")
        another_document = DocumentFactory(s3_key="s3-key")
        is_safe = document.scan_for_viruses()

        self.assertFalse(is_safe)
        mock_scan_file_for_viruses.assert_called()
        document.refresh_from_db()
        self.assertIs(document.safe, False)
        self.assertEqual(document.virus_scanned_at, timezone.now())
        self.s3_test_helper.assert_file_not_in_s3("s3-key")

        another_document.refresh_from_db()
        self.assertIs(another_document.safe, False)
        self.assertEqual(another_document.virus_scanned_at, timezone.now())
47 changes: 47 additions & 0 deletions test_helpers/s3.py
@@ -0,0 +1,47 @@
from django.conf import settings

from api.documents.libraries.s3_operations import init_s3_client


class S3TesterHelper:
    def __init__(self):
        self.client = init_s3_client()
        self.bucket_name = settings.AWS_STORAGE_BUCKET_NAME

        self.setup()

    def setup(self):
        self.client.create_bucket(
            Bucket=self.bucket_name,
            CreateBucketConfiguration={
                "LocationConstraint": settings.AWS_REGION,
            },
        )

    def _get_keys(self):
        objs = self.client.list_objects(Bucket=self.bucket_name)
        keys = [o["Key"] for o in objs.get("Contents", [])]
        return keys

    def get_object(self, s3_key):
        return self.client.get_object(Bucket=self.bucket_name, Key=s3_key)

    def add_test_file(self, key, body):
        return self.client.put_object(
            Bucket=self.bucket_name,
            Key=key,
            Body=body,
        )

    def assert_file_in_s3(self, s3_key):
        assert s3_key in self._get_keys(), f"`{s3_key}` not found in S3"

    def assert_file_not_in_s3(self, s3_key):
        assert s3_key not in self._get_keys(), f"`{s3_key}` found in S3"

    def assert_file_body(self, s3_key, body):
        obj = self.client.get_object(
            Bucket=self.bucket_name,
            Key=s3_key,
        )
        assert obj["Body"].read() == body, f"`{s3_key}` body doesn't match"
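
For context, a minimal sketch of how S3TesterHelper is meant to be used in a test, assuming moto's mock_aws and the bucket settings above; the test class and method names here are hypothetical:

# Illustrative usage only; mirrors the tests added in this commit.
from moto import mock_aws

from django.test import TestCase

from test_helpers.s3 import S3TesterHelper


@mock_aws
class ExampleS3Test(TestCase):
    def setUp(self):
        super().setUp()
        self.s3_test_helper = S3TesterHelper()  # creates the mocked bucket

    def test_round_trip(self):
        self.s3_test_helper.add_test_file("example-key", b"example body")
        self.s3_test_helper.assert_file_in_s3("example-key")
        self.s3_test_helper.assert_file_body("example-key", b"example body")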
