-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Stop a file from being deleted from S3 if the file is also pointed to…
… by another Document object The only case where this won't happen is if a file is found to be unsafe and in that case we still remove the file
- Loading branch information
1 parent
96d848a
commit 7496a2c
Showing
5 changed files
with
218 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from unittest.mock import patch | ||
|
||
from freezegun import freeze_time | ||
|
||
from moto import mock_aws | ||
|
||
from django.test import TestCase | ||
from django.utils import timezone | ||
|
||
from api.documents.tests.factories import DocumentFactory | ||
from test_helpers.s3 import S3TesterHelper | ||
|
||
|
||
@mock_aws | ||
class DocumentModelTests(TestCase): | ||
def setUp(self, *args, **kwargs): | ||
super().setUp(*args, **kwargs) | ||
|
||
self.s3_test_helper = S3TesterHelper() | ||
|
||
def test_delete_s3_removes_object(self): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
document.delete_s3() | ||
|
||
self.s3_test_helper.assert_file_not_in_s3("s3-key") | ||
|
||
def test_delete_s3_on_shared_file_retains_object(self): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
another_document = DocumentFactory(s3_key="s3-key") | ||
|
||
document.delete_s3() | ||
self.s3_test_helper.assert_file_in_s3("s3-key") | ||
|
||
another_document.delete_s3() | ||
self.s3_test_helper.assert_file_in_s3("s3-key") | ||
|
||
def test_force_delete_s3_file_on_shared_file_retains_object(self): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
DocumentFactory(s3_key="s3-key") | ||
|
||
document.delete_s3(force_delete=True) | ||
self.s3_test_helper.assert_file_not_in_s3("s3-key") | ||
|
||
@freeze_time("2020-01-01 12:00:01") | ||
@patch("api.documents.models.av_operations.scan_file_for_viruses") | ||
def test_scan_for_viruses_safe_file(self, mock_scan_file_for_viruses): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
mock_scan_file_for_viruses.return_value = False | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
is_safe = document.scan_for_viruses() | ||
|
||
self.assertIs(is_safe, True) | ||
mock_scan_file_for_viruses.assert_called() | ||
document.refresh_from_db() | ||
self.assertIs(document.safe, True) | ||
self.assertEqual(document.virus_scanned_at, timezone.now()) | ||
self.s3_test_helper.assert_file_in_s3("s3-key") | ||
|
||
@freeze_time("2020-01-01 12:00:01") | ||
@patch("api.documents.models.av_operations.scan_file_for_viruses") | ||
def test_scan_for_viruses_unsafe_file(self, mock_scan_file_for_viruses): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
mock_scan_file_for_viruses.return_value = True | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
is_safe = document.scan_for_viruses() | ||
|
||
self.assertIs(is_safe, False) | ||
mock_scan_file_for_viruses.assert_called() | ||
document.refresh_from_db() | ||
self.assertIs(document.safe, False) | ||
self.assertEqual(document.virus_scanned_at, timezone.now()) | ||
self.s3_test_helper.assert_file_not_in_s3("s3-key") | ||
|
||
@freeze_time("2020-01-01 12:00:01") | ||
@patch("api.documents.models.av_operations.scan_file_for_viruses") | ||
def test_scan_for_viruses_unsafe_shared_file(self, mock_scan_file_for_viruses): | ||
self.s3_test_helper.add_test_file("s3-key", b"test") | ||
mock_scan_file_for_viruses.return_value = True | ||
|
||
document = DocumentFactory(s3_key="s3-key") | ||
another_document = DocumentFactory(s3_key="s3-key") | ||
is_safe = document.scan_for_viruses() | ||
|
||
self.assertFalse(is_safe) | ||
mock_scan_file_for_viruses.assert_called() | ||
document.refresh_from_db() | ||
self.assertIs(document.safe, False) | ||
self.assertEqual(document.virus_scanned_at, timezone.now()) | ||
self.s3_test_helper.assert_file_not_in_s3("s3-key") | ||
|
||
another_document.refresh_from_db() | ||
self.assertIs(another_document.safe, False) | ||
self.assertEqual(another_document.virus_scanned_at, timezone.now()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from django.conf import settings | ||
|
||
from api.documents.libraries.s3_operations import init_s3_client | ||
|
||
|
||
class S3TesterHelper: | ||
def __init__(self): | ||
self.client = init_s3_client() | ||
self.bucket_name = settings.AWS_STORAGE_BUCKET_NAME | ||
|
||
self.setup() | ||
|
||
def setup(self): | ||
self.client.create_bucket( | ||
Bucket=self.bucket_name, | ||
CreateBucketConfiguration={ | ||
"LocationConstraint": settings.AWS_REGION, | ||
}, | ||
) | ||
|
||
def _get_keys(self): | ||
objs = self.client.list_objects(Bucket=self.bucket_name) | ||
keys = [o["Key"] for o in objs.get("Contents", [])] | ||
return keys | ||
|
||
def get_object(self, s3_key): | ||
return self.client.get_object(Bucket=self.bucket_name, Key=s3_key) | ||
|
||
def add_test_file(self, key, body): | ||
return self.client.put_object( | ||
Bucket=self.bucket_name, | ||
Key=key, | ||
Body=body, | ||
) | ||
|
||
def assert_file_in_s3(self, s3_key): | ||
assert s3_key in self._get_keys(), f"`{s3_key}` not found in S3" | ||
|
||
def assert_file_not_in_s3(self, s3_key): | ||
assert s3_key not in self._get_keys(), f"`{s3_key}` found in S3" | ||
|
||
def assert_file_body(self, s3_key, body): | ||
obj = self.client.get_object( | ||
Bucket=self.bucket_name, | ||
Key=s3_key, | ||
) | ||
assert obj["Body"].read() == body, f"`{s3_key}` body doesn't match" |