# Process-wide memo of computed ETags, keyed by the ``filepath`` value
# exactly as it was passed to calculate_etag().
_etag_cache = {}


async def calculate_etag(filepath, chunk_size=4096):
    """Return a quoted MD5 ETag for the contents of ``filepath``.

    The digest is computed at most once per ``filepath`` key and stored in
    the module-level ``_etag_cache``; later calls with the same key return
    the cached string without opening the file.

    Args:
        filepath: Location of the file to hash. It is also used, as given,
            as the cache key.
        chunk_size: Number of bytes requested per read while hashing.

    Returns:
        The hexadecimal MD5 digest of the file wrapped in double quotes.

    NOTE(review): cache entries are never invalidated here, so a file that
    is modified while the process is running keeps its previous ETag.
    """
    if filepath in _etag_cache:
        return _etag_cache[filepath]

    try:
        # The digest is only a content fingerprint, not a security
        # primitive; saying so keeps FIPS-restricted builds from refusing
        # MD5, the same way md5_not_usedforsecurity() does.
        hasher = hashlib.md5(usedforsecurity=False)
    except TypeError:
        # For Python 3.8 which does not support usedforsecurity=False
        hasher = hashlib.md5()

    async with aiofiles.open(filepath, "rb") as f:
        while True:
            chunk = await f.read(chunk_size)
            if not chunk:
                # An empty read means end of file
                break
            hasher.update(chunk)

    etag = f'"{hasher.hexdigest()}"'
    _etag_cache[filepath] = etag

    return etag
headers["content-disposition"] = f'attachment; filename="{filename}"' + first = True headers["content-length"] = str((await aiofiles.os.stat(str(filepath))).st_size) async with aiofiles.open(str(filepath), mode="rb") as fp: @@ -307,9 +309,14 @@ async def asgi_send_file( def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): root_path = Path(root_path) + static_headers = {} + + if headers: + static_headers = headers.copy() async def inner_static(request, send): path = request.scope["url_route"]["kwargs"]["path"] + headers = static_headers.copy() try: full_path = (root_path / path).resolve().absolute() except FileNotFoundError: @@ -325,7 +332,15 @@ async def inner_static(request, send): await asgi_send_html(send, "404: Path not inside root path", 404) return try: - await asgi_send_file(send, full_path, chunk_size=chunk_size) + # Calculate ETag for filepath + etag = await calculate_etag(full_path, chunk_size=chunk_size) + headers["ETag"] = etag + if_none_match = request.headers.get("if-none-match") + if if_none_match and if_none_match == etag: + return await asgi_send(send, "", 304) + await asgi_send_file( + send, full_path, chunk_size=chunk_size, headers=headers + ) except FileNotFoundError: await asgi_send_html(send, "404: File not found", 404) return diff --git a/tests/test_html.py b/tests/test_html.py index 8229b166fa..42b290c86e 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -78,6 +78,10 @@ async def test_static(ds_client): response = await ds_client.get("/-/static/app.css") assert response.status_code == 200 assert "text/css" == response.headers["content-type"] + assert "etag" in response.headers + etag = response.headers.get("etag") + response = await ds_client.get("/-/static/app.css", headers={"if-none-match": etag}) + assert response.status_code == 304 def test_static_mounts(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 515776156c..254b130055 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ 
@pytest.mark.asyncio
async def test_calculate_etag(tmp_path):
    """calculate_etag() hashes a file once, then answers from its cache."""
    path = tmp_path / "test.txt"
    path.write_text("hello")
    # Quoted MD5 hex digest of the bytes b"hello"
    etag = '"5d41402abc4b2a76b9719d911017c592"'
    try:
        assert etag == await utils.calculate_etag(path)
        assert utils._etag_cache[path] == etag
        # Plant a sentinel: if the next call returns it, the value came
        # from the cache rather than from re-reading the file.
        utils._etag_cache[path] = "hash"
        assert "hash" == await utils.calculate_etag(path)
    finally:
        # Reset the module-level cache even when an assertion fails, so the
        # sentinel entry cannot leak into other tests.
        utils._etag_cache.clear()