diff --git a/django_downloadview/response.py b/django_downloadview/response.py index 85cacda..fa5c82c 100644 --- a/django_downloadview/response.py +++ b/django_downloadview/response.py @@ -117,6 +117,7 @@ class DownloadResponse(StreamingHttpResponse): attributes (size, name, ...). """ + def __init__(self, file_instance, attachment=True, basename=None, status=200, content_type=None, file_mimetype=None, file_encoding=None): diff --git a/django_downloadview/views/base.py b/django_downloadview/views/base.py index 29121c3..1e7e9b5 100644 --- a/django_downloadview/views/base.py +++ b/django_downloadview/views/base.py @@ -60,7 +60,7 @@ class DownloadMixin(object): #: :mod:`mimetypes`. encoding = None - def get_file(self): + def get_file(self, *args, **kwargs): """Return a file wrapper instance. Raises :class:`~django_downloadview.exceptions.FileNotFound` if file @@ -122,8 +122,13 @@ def was_modified_since(self, file_instance, since): return was_modified_since(since, modification_time, size) def not_modified_response(self, *response_args, **response_kwargs): - """Return :class:`django.http.HttpResponseNotModified` instance.""" - return HttpResponseNotModified(*response_args, **response_kwargs) + """Return :class:`django.http.HttpResponseNotModified` instance. + Only `django.http.HttpResponseBase.__init__()` kwargs will be used.""" + allowed_kwargs = {} + for k in ('content_type', 'status', 'reason', 'charset'): + if k in response_kwargs: + allowed_kwargs[k] = response_kwargs[k] + return HttpResponseNotModified(*response_args, **allowed_kwargs) def download_response(self, *response_args, **response_kwargs): """Return :class:`~django_downloadview.response.DownloadResponse`.""" @@ -151,7 +156,7 @@ def render_to_response(self, *response_args, **response_kwargs): """ try: - self.file_instance = self.get_file() + self.file_instance = self.get_file(*response_args, **response_kwargs) except exceptions.FileNotFound: return self.file_not_found_response() # Respect the If-Modified-Since header. @@ -165,6 +170,7 @@ def render_to_response(self, *response_args, **response_kwargs): class BaseDownloadView(DownloadMixin, View): """A base :class:`DownloadMixin` that implements :meth:`get`.""" + def get(self, request, *args, **kwargs): """Handle GET requests: stream a file.""" - return self.render_to_response() + return self.render_to_response(*args, **kwargs) diff --git a/tests/views.py b/tests/views.py index 6ad8e1d..cd47fc7 100644 --- a/tests/views.py +++ b/tests/views.py @@ -21,6 +21,7 @@ class DownloadMixinTestCase(unittest.TestCase): """Test suite around :class:`django_downloadview.views.DownloadMixin`.""" + def test_get_file(self): """DownloadMixin.get_file() raise NotImplementedError. @@ -218,21 +219,23 @@ def test_file_not_found_response(self): class BaseDownloadViewTestCase(unittest.TestCase): "Tests around :class:`django_downloadviews.views.base.BaseDownloadView`." + def test_get(self): """BaseDownloadView.get() calls render_to_response().""" request = django.test.RequestFactory().get('/dummy-url') - args = ['dummy-arg'] - kwargs = {'dummy': 'kwarg'} + args = [] + kwargs = {'content_type': 'application/pdf'} view = setup_view(views.BaseDownloadView(), request, *args, **kwargs) view.render_to_response = mock.Mock( return_value=mock.sentinel.response) response = view.get(request, *args, **kwargs) self.assertIs(response, mock.sentinel.response) - view.render_to_response.assert_called_once_with() + view.render_to_response.assert_called_once_with(*args, **kwargs) class PathDownloadViewTestCase(unittest.TestCase): "Tests for :class:`django_downloadviews.views.path.PathDownloadView`." + def test_get_file_ok(self): "PathDownloadView.get_file() returns ``File`` instance." view = setup_view(views.PathDownloadView(path=__file__), @@ -262,6 +265,7 @@ def test_get_file_is_directory(self): class ObjectDownloadViewTestCase(unittest.TestCase): "Tests for :class:`django_downloadviews.views.object.ObjectDownloadView`." + def test_get_file_ok(self): "ObjectDownloadView.get_file() returns ``file`` field by default." view = setup_view(views.ObjectDownloadView(), 'fake request') @@ -296,6 +300,7 @@ def test_get_file_empty_field(self): class VirtualDownloadViewTestCase(unittest.TestCase): """Test suite around :py:class:`django_downloadview.views.VirtualDownloadView`.""" + def test_was_modified_since_specific(self): """VirtualDownloadView.was_modified_since() delegates to file wrapper.