From e2b60f57ae2fd20baa2308464843cf11039efc92 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Mon, 7 Aug 2023 13:34:28 -0700 Subject: [PATCH] Fix OPTIONS bug by porting DatbaseView to be a View subclass --- datasette/app.py | 4 +- datasette/views/database.py | 213 ++++++++++++++++++------------------ tests/test_api.py | 1 - 3 files changed, 110 insertions(+), 108 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 39c2bb6de9..595ce78092 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -36,7 +36,7 @@ from .views import Context from .views.base import ureg -from .views.database import database_download, database_view, TableCreateView +from .views.database import database_download, DatabaseView, TableCreateView from .views.index import IndexView from .views.special import ( JsonDataView, @@ -1376,7 +1376,7 @@ def add_route(view, regex): r"/(?P[^\/\.]+)\.db$", ) add_route( - wrap_view(database_view, self), + wrap_view(DatabaseView, self), r"/(?P[^\/\.]+)(\.(?P\w+))?$", ) add_route(TableCreateView.as_view(self), r"/(?P[^\/\.]+)/-/create$") diff --git a/datasette/views/database.py b/datasette/views/database.py index 77f729d26d..ea3cafb3c3 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -32,128 +32,131 @@ from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden from datasette.plugins import pm -from .base import BaseView, DatasetteError, DataView, _error, stream_csv +from .base import BaseView, DatasetteError, DataView, View, _error, stream_csv -async def database_view(request, datasette): - format_ = request.url_vars.get("format") or "html" - if format_ not in ("html", "json"): - raise NotFound("Invalid format: {}".format(format_)) - - await datasette.refresh_schemas() - - db = await datasette.resolve_database(request) - database = db.name - - visible, private = await datasette.check_visibility( - request.actor, - permissions=[ - ("view-database", database), - "view-instance", - ], - ) - if not visible: - raise Forbidden("You do not have permission to view this database") +class DatabaseView(View): + async def get(self, request, datasette): + format_ = request.url_vars.get("format") or "html" + if format_ not in ("html", "json"): + raise NotFound("Invalid format: {}".format(format_)) - sql = (request.args.get("sql") or "").strip() - if sql: - return await query_view(request, datasette) + await datasette.refresh_schemas() - metadata = (datasette.metadata("databases") or {}).get(database, {}) - datasette.update_with_inherited_metadata(metadata) - - table_counts = await db.table_counts(5) - hidden_table_names = set(await db.hidden_table_names()) - all_foreign_keys = await db.get_all_foreign_keys() + db = await datasette.resolve_database(request) + database = db.name - sql_views = [] - for view_name in await db.view_names(): - view_visible, view_private = await datasette.check_visibility( + visible, private = await datasette.check_visibility( request.actor, permissions=[ - ("view-table", (database, view_name)), ("view-database", database), "view-instance", ], ) - if view_visible: - sql_views.append( - { - "name": view_name, - "private": view_private, - } + if not visible: + raise Forbidden("You do not have permission to view this database") + + sql = (request.args.get("sql") or "").strip() + if sql: + return await query_view(request, datasette) + + metadata = (datasette.metadata("databases") or {}).get(database, {}) + datasette.update_with_inherited_metadata(metadata) + + table_counts = await db.table_counts(5) + hidden_table_names = set(await db.hidden_table_names()) + all_foreign_keys = await db.get_all_foreign_keys() + + sql_views = [] + for view_name in await db.view_names(): + view_visible, view_private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-table", (database, view_name)), + ("view-database", database), + "view-instance", + ], ) + if view_visible: + sql_views.append( + { + "name": view_name, + "private": view_private, + } + ) - tables = await get_tables(datasette, request, db) - canned_queries = [] - for query in (await datasette.get_canned_queries(database, request.actor)).values(): - query_visible, query_private = await datasette.check_visibility( - request.actor, - permissions=[ - ("view-query", (database, query["name"])), - ("view-database", database), - "view-instance", - ], - ) - if query_visible: - canned_queries.append(dict(query, private=query_private)) + tables = await get_tables(datasette, request, db) + canned_queries = [] + for query in ( + await datasette.get_canned_queries(database, request.actor) + ).values(): + query_visible, query_private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-query", (database, query["name"])), + ("view-database", database), + "view-instance", + ], + ) + if query_visible: + canned_queries.append(dict(query, private=query_private)) - async def database_actions(): - links = [] - for hook in pm.hook.database_actions( - datasette=datasette, - database=database, - actor=request.actor, - request=request, - ): - extra_links = await await_me_maybe(hook) - if extra_links: - links.extend(extra_links) - return links + async def database_actions(): + links = [] + for hook in pm.hook.database_actions( + datasette=datasette, + database=database, + actor=request.actor, + request=request, + ): + extra_links = await await_me_maybe(hook) + if extra_links: + links.extend(extra_links) + return links - attached_databases = [d.name for d in await db.attached_databases()] + attached_databases = [d.name for d in await db.attached_databases()] - allow_execute_sql = await datasette.permission_allowed( - request.actor, "execute-sql", database - ) - json_data = { - "database": database, - "private": private, - "path": datasette.urls.database(database), - "size": db.size, - "tables": tables, - "hidden_count": len([t for t in tables if t["hidden"]]), - "views": sql_views, - "queries": canned_queries, - "allow_execute_sql": allow_execute_sql, - "table_columns": await _table_columns(datasette, database) - if allow_execute_sql - else {}, - } + allow_execute_sql = await datasette.permission_allowed( + request.actor, "execute-sql", database + ) + json_data = { + "database": database, + "private": private, + "path": datasette.urls.database(database), + "size": db.size, + "tables": tables, + "hidden_count": len([t for t in tables if t["hidden"]]), + "views": sql_views, + "queries": canned_queries, + "allow_execute_sql": allow_execute_sql, + "table_columns": await _table_columns(datasette, database) + if allow_execute_sql + else {}, + } - if format_ == "json": - response = Response.json(json_data) - if datasette.cors: - add_cors_headers(response.headers) - return response - - assert format_ == "html" - context = { - **json_data, - "database_actions": database_actions, - "show_hidden": request.args.get("_show_hidden"), - "editable": True, - "metadata": metadata, - "allow_download": datasette.setting("allow_download") - and not db.is_mutable - and not db.is_memory, - "attached_databases": attached_databases, - "database_color": lambda _: "#ff0000", - } - templates = (f"database-{to_css_class(database)}.html", "database.html") - return Response.html( - await datasette.render_template(templates, context, request=request) - ) + if format_ == "json": + response = Response.json(json_data) + if datasette.cors: + add_cors_headers(response.headers) + return response + + assert format_ == "html" + context = { + **json_data, + "database_actions": database_actions, + "show_hidden": request.args.get("_show_hidden"), + "editable": True, + "metadata": metadata, + "allow_download": datasette.setting("allow_download") + and not db.is_mutable + and not db.is_memory, + "attached_databases": attached_databases, + "database_color": lambda _: "#ff0000", + } + templates = (f"database-{to_css_class(database)}.html", "database.html") + return Response.html( + await datasette.render_template(templates, context, request=request) + ) from dataclasses import dataclass, field diff --git a/tests/test_api.py b/tests/test_api.py index c136e433c3..9d5966ba6c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -988,7 +988,6 @@ def test_inspect_file_used_for_count(app_client_immutable_and_inspect_file): @pytest.mark.asyncio -@pytest.mark.xfail # TODO: Fix this feature async def test_http_options_request(ds_client): response = await ds_client.options("/fixtures") assert response.status_code == 200