diff --git a/datasette/__init__.py b/datasette/__init__.py index 64fb4ff7d0..271e09ada0 100644 --- a/datasette/__init__.py +++ b/datasette/__init__.py @@ -1,6 +1,7 @@ -from datasette.permissions import Permission +from datasette.permissions import Permission # noqa from datasette.version import __version_info__, __version__ # noqa from datasette.utils.asgi import Forbidden, NotFound, Request, Response # noqa from datasette.utils import actor_matches_allow # noqa +from datasette.views import Context # noqa from .hookspecs import hookimpl # noqa from .hookspecs import hookspec # noqa diff --git a/datasette/app.py b/datasette/app.py index b8b84168b0..b2644ace17 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,7 +1,8 @@ import asyncio -from typing import Sequence, Union, Tuple, Optional, Dict, Iterable +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union import asgi_csrf import collections +import dataclasses import datetime import functools import glob @@ -33,6 +34,7 @@ from jinja2.environment import Template from jinja2.exceptions import TemplateNotFound +from .views import Context from .views.base import ureg from .views.database import database_download, DatabaseView, TableCreateView from .views.index import IndexView @@ -1115,7 +1117,11 @@ def _register_renderers(self): ) async def render_template( - self, templates, context=None, request=None, view_name=None + self, + templates: Union[List[str], str, Template], + context: Optional[Union[Dict[str, Any], Context]] = None, + request: Optional[Request] = None, + view_name: Optional[str] = None, ): if not self._startup_invoked: raise Exception("render_template() called before await ds.invoke_startup()") @@ -1126,6 +1132,8 @@ async def render_template( if isinstance(templates, str): templates = [templates] template = self.jinja_env.select_template(templates) + if dataclasses.is_dataclass(context): + context = dataclasses.asdict(context) body_scripts = [] # pylint: disable=no-member for extra_script in pm.hook.extra_body_script( @@ -1368,7 +1376,8 @@ def add_route(view, regex): r"/(?P[^\/\.]+)\.db$", ) add_route( - DatabaseView.as_view(self), r"/(?P[^\/\.]+)(\.(?P\w+))?$" + wrap_view(DatabaseView, self), + r"/(?P[^\/\.]+)(\.(?P\w+))?$", ) add_route(TableCreateView.as_view(self), r"/(?P[^\/\.]+)/-/create$") add_route( @@ -1707,6 +1716,7 @@ async def async_view_for_class(request, send): datasette=datasette, ) + async_view_for_class.view_class = view_class return async_view_for_class diff --git a/datasette/renderer.py b/datasette/renderer.py index 5354f34852..0bd74e81ea 100644 --- a/datasette/renderer.py +++ b/datasette/renderer.py @@ -27,7 +27,7 @@ def convert_specific_columns_to_json(rows, columns, json_cols): return new_rows -def json_renderer(args, data, view_name): +def json_renderer(args, data, error, truncated=None): """Render a response as JSON""" status_code = 200 @@ -47,8 +47,15 @@ def json_renderer(args, data, view_name): # Deal with the _shape option shape = args.get("_shape", "objects") # if there's an error, ignore the shape entirely - if data.get("error"): + data["ok"] = True + if error: shape = "objects" + status_code = 400 + data["error"] = error + data["ok"] = False + + if truncated is not None: + data["truncated"] = truncated if shape == "arrayfirst": if not data["rows"]: @@ -64,13 +71,13 @@ def json_renderer(args, data, view_name): if rows and columns: data["rows"] = [dict(zip(columns, row)) for row in rows] if shape == "object": - error = None + shape_error = None if "primary_keys" not in data: - error = "_shape=object is only available on tables" + shape_error = "_shape=object is only available on tables" else: pks = data["primary_keys"] if not pks: - error = ( + shape_error = ( "_shape=object not available for tables with no primary keys" ) else: @@ -79,8 +86,8 @@ def json_renderer(args, data, view_name): pk_string = path_from_row_pks(row, pks, not pks) object_rows[pk_string] = row data = object_rows - if error: - data = {"ok": False, "error": error} + if shape_error: + data = {"ok": False, "error": shape_error} elif shape == "array": data = data["rows"] diff --git a/datasette/views/__init__.py b/datasette/views/__init__.py index e69de29bb2..e3b1b7f44b 100644 --- a/datasette/views/__init__.py +++ b/datasette/views/__init__.py @@ -0,0 +1,3 @@ +class Context: + "Base class for all documented contexts" + pass diff --git a/datasette/views/base.py b/datasette/views/base.py index 94645cd8dd..da5c55ad5c 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -309,6 +309,8 @@ async def get(self, request): table=data.get("table"), request=request, view_name=self.name, + truncated=False, # TODO: support this + error=data.get("error"), # These will be deprecated in Datasette 1.0: args=request.args, data=data, diff --git a/datasette/views/database.py b/datasette/views/database.py index ffa79e9643..77f3f5b04c 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,17 +1,22 @@ -import os +from asyncinject import Registry +from dataclasses import dataclass, field +from typing import Callable +from urllib.parse import parse_qsl, urlencode +import asyncio import hashlib import itertools import json -from markupsafe import Markup, escape -from urllib.parse import parse_qsl, urlencode +import markupsafe +import os import re import sqlite_utils +import textwrap -import markupsafe - +from datasette.database import QueryInterrupted from datasette.utils import ( add_cors_headers, await_me_maybe, + call_with_supported_arguments, derive_named_parameters, format_bytes, tilde_decode, @@ -28,17 +33,19 @@ from datasette.utils.asgi import AsgiFileDownload, NotFound, Response, Forbidden from datasette.plugins import pm -from .base import BaseView, DatasetteError, DataView, _error +from .base import BaseView, DatasetteError, DataView, View, _error, stream_csv -class DatabaseView(DataView): - name = "database" +class DatabaseView(View): + async def get(self, request, datasette): + format_ = request.url_vars.get("format") or "html" - async def data(self, request, default_labels=False, _size=None): - db = await self.ds.resolve_database(request) + await datasette.refresh_schemas() + + db = await datasette.resolve_database(request) database = db.name - visible, private = await self.ds.check_visibility( + visible, private = await datasette.check_visibility( request.actor, permissions=[ ("view-database", database), @@ -48,23 +55,23 @@ async def data(self, request, default_labels=False, _size=None): if not visible: raise Forbidden("You do not have permission to view this database") - metadata = (self.ds.metadata("databases") or {}).get(database, {}) - self.ds.update_with_inherited_metadata(metadata) + sql = (request.args.get("sql") or "").strip() + if sql: + return await query_view(request, datasette) - if request.args.get("sql"): - sql = request.args.get("sql") - validate_sql_select(sql) - return await QueryView(self.ds).data( - request, sql, _size=_size, metadata=metadata - ) + if format_ not in ("html", "json"): + raise NotFound("Invalid format: {}".format(format_)) + + metadata = (datasette.metadata("databases") or {}).get(database, {}) + datasette.update_with_inherited_metadata(metadata) table_counts = await db.table_counts(5) hidden_table_names = set(await db.hidden_table_names()) all_foreign_keys = await db.get_all_foreign_keys() - views = [] + sql_views = [] for view_name in await db.view_names(): - view_visible, view_private = await self.ds.check_visibility( + view_visible, view_private = await datasette.check_visibility( request.actor, permissions=[ ("view-table", (database, view_name)), @@ -73,45 +80,19 @@ async def data(self, request, default_labels=False, _size=None): ], ) if view_visible: - views.append( + sql_views.append( { "name": view_name, "private": view_private, } ) - tables = [] - for table in table_counts: - table_visible, table_private = await self.ds.check_visibility( - request.actor, - permissions=[ - ("view-table", (database, table)), - ("view-database", database), - "view-instance", - ], - ) - if not table_visible: - continue - table_columns = await db.table_columns(table) - tables.append( - { - "name": table, - "columns": table_columns, - "primary_keys": await db.primary_keys(table), - "count": table_counts[table], - "hidden": table in hidden_table_names, - "fts_table": await db.fts_table(table), - "foreign_keys": all_foreign_keys[table], - "private": table_private, - } - ) - - tables.sort(key=lambda t: (t["hidden"], t["name"])) + tables = await get_tables(datasette, request, db) canned_queries = [] for query in ( - await self.ds.get_canned_queries(database, request.actor) + await datasette.get_canned_queries(database, request.actor) ).values(): - query_visible, query_private = await self.ds.check_visibility( + query_visible, query_private = await datasette.check_visibility( request.actor, permissions=[ ("view-query", (database, query["name"])), @@ -125,7 +106,7 @@ async def data(self, request, default_labels=False, _size=None): async def database_actions(): links = [] for hook in pm.hook.database_actions( - datasette=self.ds, + datasette=datasette, database=database, actor=request.actor, request=request, @@ -137,36 +118,165 @@ async def database_actions(): attached_databases = [d.name for d in await db.attached_databases()] - allow_execute_sql = await self.ds.permission_allowed( + allow_execute_sql = await datasette.permission_allowed( request.actor, "execute-sql", database ) - return ( - { - "database": database, - "private": private, - "path": self.ds.urls.database(database), - "size": db.size, - "tables": tables, - "hidden_count": len([t for t in tables if t["hidden"]]), - "views": views, - "queries": canned_queries, - "allow_execute_sql": allow_execute_sql, - "table_columns": await _table_columns(self.ds, database) - if allow_execute_sql - else {}, + json_data = { + "database": database, + "private": private, + "path": datasette.urls.database(database), + "size": db.size, + "tables": tables, + "hidden_count": len([t for t in tables if t["hidden"]]), + "views": sql_views, + "queries": canned_queries, + "allow_execute_sql": allow_execute_sql, + "table_columns": await _table_columns(datasette, database) + if allow_execute_sql + else {}, + } + + if format_ == "json": + response = Response.json(json_data) + if datasette.cors: + add_cors_headers(response.headers) + return response + + assert format_ == "html" + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + templates = (f"database-{to_css_class(database)}.html", "database.html") + template = datasette.jinja_env.select_template(templates) + context = { + **json_data, + "database_actions": database_actions, + "show_hidden": request.args.get("_show_hidden"), + "editable": True, + "metadata": metadata, + "allow_download": datasette.setting("allow_download") + and not db.is_mutable + and not db.is_memory, + "attached_databases": attached_databases, + "database_color": lambda _: "#ff0000", + "alternate_url_json": alternate_url_json, + "select_templates": [ + f"{'*' if template_name == template.name else ''}{template_name}" + for template_name in templates + ], + } + return Response.html( + await datasette.render_template( + templates, + context, + request=request, + view_name="database", + ), + headers={ + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) }, + ) + + +@dataclass +class QueryContext: + database: str = field(metadata={"help": "The name of the database being queried"}) + query: dict = field( + metadata={"help": "The SQL query object containing the `sql` string"} + ) + canned_query: str = field( + metadata={"help": "The name of the canned query if this is a canned query"} + ) + private: bool = field( + metadata={"help": "Boolean indicating if this is a private database"} + ) + # urls: dict = field( + # metadata={"help": "Object containing URL helpers like `database()`"} + # ) + canned_write: bool = field( + metadata={"help": "Boolean indicating if this canned query allows writes"} + ) + db_is_immutable: bool = field( + metadata={"help": "Boolean indicating if this database is immutable"} + ) + error: str = field(metadata={"help": "Any query error message"}) + hide_sql: bool = field( + metadata={"help": "Boolean indicating if the SQL should be hidden"} + ) + show_hide_link: str = field( + metadata={"help": "The URL to toggle showing/hiding the SQL"} + ) + show_hide_text: str = field( + metadata={"help": "The text for the show/hide SQL link"} + ) + editable: bool = field( + metadata={"help": "Boolean indicating if the SQL can be edited"} + ) + allow_execute_sql: bool = field( + metadata={"help": "Boolean indicating if custom SQL can be executed"} + ) + tables: list = field(metadata={"help": "List of table objects in the database"}) + named_parameter_values: dict = field( + metadata={"help": "Dictionary of parameter names/values"} + ) + edit_sql_url: str = field( + metadata={"help": "URL to edit the SQL for a canned query"} + ) + display_rows: list = field(metadata={"help": "List of result rows to display"}) + columns: list = field(metadata={"help": "List of column names"}) + renderers: dict = field(metadata={"help": "Dictionary of renderer name to URL"}) + url_csv: str = field(metadata={"help": "URL for CSV export"}) + show_hide_hidden: str = field( + metadata={"help": "Hidden input field for the _show_sql parameter"} + ) + metadata: dict = field(metadata={"help": "Metadata about the query/database"}) + database_color: Callable = field( + metadata={"help": "Function that returns a color for a given database name"} + ) + table_columns: dict = field( + metadata={"help": "Dictionary of table name to list of column names"} + ) + alternate_url_json: str = field( + metadata={"help": "URL for alternate JSON version of this page"} + ) + + +async def get_tables(datasette, request, db): + tables = [] + database = db.name + table_counts = await db.table_counts(5) + hidden_table_names = set(await db.hidden_table_names()) + all_foreign_keys = await db.get_all_foreign_keys() + + for table in table_counts: + table_visible, table_private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-table", (database, table)), + ("view-database", database), + "view-instance", + ], + ) + if not table_visible: + continue + table_columns = await db.table_columns(table) + tables.append( { - "database_actions": database_actions, - "show_hidden": request.args.get("_show_hidden"), - "editable": True, - "metadata": metadata, - "allow_download": self.ds.setting("allow_download") - and not db.is_mutable - and not db.is_memory, - "attached_databases": attached_databases, - }, - (f"database-{to_css_class(database)}.html", "database.html"), + "name": table, + "columns": table_columns, + "primary_keys": await db.primary_keys(table), + "count": table_counts[table], + "hidden": table in hidden_table_names, + "fts_table": await db.fts_table(table), + "foreign_keys": all_foreign_keys[table], + "private": table_private, + } ) + tables.sort(key=lambda t: (t["hidden"], t["name"])) + return tables async def database_download(request, datasette): @@ -210,6 +320,244 @@ async def database_download(request, datasette): ) +async def query_view( + request, + datasette, + # canned_query=None, + # _size=None, + # named_parameters=None, + # write=False, +): + db = await datasette.resolve_database(request) + database = db.name + # Flattened because of ?sql=&name1=value1&name2=value2 feature + params = {key: request.args.get(key) for key in request.args} + sql = None + if "sql" in params: + sql = params.pop("sql") + if "_shape" in params: + params.pop("_shape") + + # extras come from original request.args to avoid being flattened + extras = request.args.getlist("_extra") + + # TODO: Behave differently for canned query here: + await datasette.ensure_permissions(request.actor, [("execute-sql", database)]) + + _, private = await datasette.check_visibility( + request.actor, + permissions=[ + ("view-database", database), + "view-instance", + ], + ) + + extra_args = {} + if params.get("_timelimit"): + extra_args["custom_time_limit"] = int(params["_timelimit"]) + + format_ = request.url_vars.get("format") or "html" + query_error = None + try: + validate_sql_select(sql) + results = await datasette.execute( + database, sql, params, truncate=True, **extra_args + ) + columns = results.columns + rows = results.rows + except QueryInterrupted as ex: + raise DatasetteError( + textwrap.dedent( + """ +

SQL query took too long. The time limit is controlled by the + sql_time_limit_ms + configuration option.

+ + + """.format( + markupsafe.escape(ex.sql) + ) + ).strip(), + title="SQL Interrupted", + status=400, + message_is_html=True, + ) + except sqlite3.DatabaseError as ex: + query_error = str(ex) + results = None + rows = [] + columns = [] + except (sqlite3.OperationalError, InvalidSql) as ex: + raise DatasetteError(str(ex), title="Invalid SQL", status=400) + except sqlite3.OperationalError as ex: + raise DatasetteError(str(ex)) + except DatasetteError: + raise + + # Handle formats from plugins + if format_ == "csv": + + async def fetch_data_for_csv(request, _next=None): + results = await db.execute(sql, params, truncate=True) + data = {"rows": results.rows, "columns": results.columns} + return data, None, None + + return await stream_csv(datasette, fetch_data_for_csv, request, db.name) + elif format_ in datasette.renderers.keys(): + # Dispatch request to the correct output format renderer + # (CSV is not handled here due to streaming) + result = call_with_supported_arguments( + datasette.renderers[format_][0], + datasette=datasette, + columns=columns, + rows=rows, + sql=sql, + query_name=None, + database=database, + table=None, + request=request, + view_name="table", + truncated=results.truncated if results else False, + error=query_error, + # These will be deprecated in Datasette 1.0: + args=request.args, + data={"rows": rows, "columns": columns}, + ) + if asyncio.iscoroutine(result): + result = await result + if result is None: + raise NotFound("No data") + if isinstance(result, dict): + r = Response( + body=result.get("body"), + status=result.get("status_code") or 200, + content_type=result.get("content_type", "text/plain"), + headers=result.get("headers"), + ) + elif isinstance(result, Response): + r = result + # if status_code is not None: + # # Over-ride the status code + # r.status = status_code + else: + assert False, f"{result} should be dict or Response" + elif format_ == "html": + headers = {} + templates = [f"query-{to_css_class(database)}.html", "query.html"] + template = datasette.jinja_env.select_template(templates) + alternate_url_json = datasette.absolute_url( + request, + datasette.urls.path(path_with_format(request=request, format="json")), + ) + data = {} + headers.update( + { + "Link": '{}; rel="alternate"; type="application/json+datasette"'.format( + alternate_url_json + ) + } + ) + metadata = (datasette.metadata("databases") or {}).get(database, {}) + datasette.update_with_inherited_metadata(metadata) + + renderers = {} + for key, (_, can_render) in datasette.renderers.items(): + it_can_render = call_with_supported_arguments( + can_render, + datasette=datasette, + columns=data.get("columns") or [], + rows=data.get("rows") or [], + sql=data.get("query", {}).get("sql", None), + query_name=data.get("query_name"), + database=database, + table=data.get("table"), + request=request, + view_name="database", + ) + it_can_render = await await_me_maybe(it_can_render) + if it_can_render: + renderers[key] = datasette.urls.path( + path_with_format(request=request, format=key) + ) + + allow_execute_sql = await datasette.permission_allowed( + request.actor, "execute-sql", database + ) + + show_hide_hidden = "" + if metadata.get("hide_sql"): + if bool(params.get("_show_sql")): + show_hide_link = path_with_removed_args(request, {"_show_sql"}) + show_hide_text = "hide" + show_hide_hidden = '' + else: + show_hide_link = path_with_added_args(request, {"_show_sql": 1}) + show_hide_text = "show" + else: + if bool(params.get("_hide_sql")): + show_hide_link = path_with_removed_args(request, {"_hide_sql"}) + show_hide_text = "show" + show_hide_hidden = '' + else: + show_hide_link = path_with_added_args(request, {"_hide_sql": 1}) + show_hide_text = "hide" + hide_sql = show_hide_text == "show" + + r = Response.html( + await datasette.render_template( + template, + QueryContext( + database=database, + query={ + "sql": sql, + # TODO: Params? + }, + canned_query=None, + private=private, + canned_write=False, + db_is_immutable=not db.is_mutable, + error=query_error, + hide_sql=hide_sql, + show_hide_link=datasette.urls.path(show_hide_link), + show_hide_text=show_hide_text, + editable=True, # TODO + allow_execute_sql=allow_execute_sql, + tables=await get_tables(datasette, request, db), + named_parameter_values={}, # TODO + edit_sql_url="todo", + display_rows=await display_rows( + datasette, database, request, rows, columns + ), + table_columns=await _table_columns(datasette, database) + if allow_execute_sql + else {}, + columns=columns, + renderers=renderers, + url_csv=datasette.urls.path( + path_with_format( + request=request, format="csv", extra_qs={"_size": "max"} + ) + ), + show_hide_hidden=markupsafe.Markup(show_hide_hidden), + metadata=metadata, + database_color=lambda _: "#ff0000", + alternate_url_json=alternate_url_json, + ), + request=request, + view_name="database", + ), + headers=headers, + ) + else: + assert False, "Invalid format: {}".format(format_) + if datasette.cors: + add_cors_headers(r.headers) + return r + + class QueryView(DataView): async def data( self, @@ -404,7 +752,7 @@ async def extra_template(): display_value = plugin_display_value else: if value in ("", None): - display_value = Markup(" ") + display_value = markupsafe.Markup(" ") elif is_url(str(display_value).strip()): display_value = markupsafe.Markup( '{truncated_url}'.format( @@ -755,3 +1103,69 @@ async def _table_columns(datasette, database_name): for view_name in await db.view_names(): table_columns[view_name] = [] return table_columns + + +async def display_rows(datasette, database, request, rows, columns): + display_rows = [] + truncate_cells = datasette.setting("truncate_cells_html") + for row in rows: + display_row = [] + for column, value in zip(columns, row): + display_value = value + # Let the plugins have a go + # pylint: disable=no-member + plugin_display_value = None + for candidate in pm.hook.render_cell( + row=row, + value=value, + column=column, + table=None, + database=database, + datasette=datasette, + request=request, + ): + candidate = await await_me_maybe(candidate) + if candidate is not None: + plugin_display_value = candidate + break + if plugin_display_value is not None: + display_value = plugin_display_value + else: + if value in ("", None): + display_value = markupsafe.Markup(" ") + elif is_url(str(display_value).strip()): + display_value = markupsafe.Markup( + '{truncated_url}'.format( + url=markupsafe.escape(value.strip()), + truncated_url=markupsafe.escape( + truncate_url(value.strip(), truncate_cells) + ), + ) + ) + elif isinstance(display_value, bytes): + blob_url = path_with_format( + request=request, + format="blob", + extra_qs={ + "_blob_column": column, + "_blob_hash": hashlib.sha256(display_value).hexdigest(), + }, + ) + formatted = format_bytes(len(value)) + display_value = markupsafe.Markup( + '<Binary: {:,} byte{}>'.format( + blob_url, + ' title="{}"'.format(formatted) + if "bytes" not in formatted + else "", + len(value), + "" if len(value) == 1 else "s", + ) + ) + else: + display_value = str(value) + if truncate_cells and len(display_value) > truncate_cells: + display_value = display_value[:truncate_cells] + "\u2026" + display_row.append(display_value) + display_rows.append(display_row) + return display_rows diff --git a/datasette/views/table.py b/datasette/views/table.py index c102c10319..77acfd9504 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -833,6 +833,8 @@ async def fetch_data(request, _next=None): table=resolved.table, request=request, view_name="table", + truncated=False, + error=None, # These will be deprecated in Datasette 1.0: args=request.args, data=data, diff --git a/docs/plugin_hooks.rst b/docs/plugin_hooks.rst index 973065292b..9bbe6fc6fa 100644 --- a/docs/plugin_hooks.rst +++ b/docs/plugin_hooks.rst @@ -516,6 +516,12 @@ When a request is received, the ``"render"`` callback function is called with ze ``request`` - :ref:`internals_request` The current HTTP request. +``error`` - string or None + If an error occurred this string will contain the error message. + +``truncated`` - bool or None + If the query response was truncated - for example a SQL query returning more than 1,000 results where pagination is not available - this will be ``True``. + ``view_name`` - string The name of the current view being called. ``index``, ``database``, ``table``, and ``row`` are the most important ones. diff --git a/tests/test_api.py b/tests/test_api.py index 40a3e2b825..28415a0bcd 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -638,22 +638,21 @@ def test_database_page_for_database_with_dot_in_name(app_client_with_dot): @pytest.mark.asyncio async def test_custom_sql(ds_client): response = await ds_client.get( - "/fixtures.json?sql=select+content+from+simple_primary_key&_shape=objects" + "/fixtures.json?sql=select+content+from+simple_primary_key" ) data = response.json() - assert {"sql": "select content from simple_primary_key", "params": {}} == data[ - "query" - ] - assert [ - {"content": "hello"}, - {"content": "world"}, - {"content": ""}, - {"content": "RENDER_CELL_DEMO"}, - {"content": "RENDER_CELL_ASYNC"}, - ] == data["rows"] - assert ["content"] == data["columns"] - assert "fixtures" == data["database"] - assert not data["truncated"] + assert data == { + "rows": [ + {"content": "hello"}, + {"content": "world"}, + {"content": ""}, + {"content": "RENDER_CELL_DEMO"}, + {"content": "RENDER_CELL_ASYNC"}, + ], + "columns": ["content"], + "ok": True, + "truncated": False, + } def test_sql_time_limit(app_client_shorter_time_limit): diff --git a/tests/test_cli_serve_get.py b/tests/test_cli_serve_get.py index ac44e1e285..2e0390bb8c 100644 --- a/tests/test_cli_serve_get.py +++ b/tests/test_cli_serve_get.py @@ -36,7 +36,6 @@ def startup(datasette): ) assert 0 == result.exit_code, result.output assert { - "database": "_memory", "truncated": False, "columns": ["sqlite_version()"], }.items() <= json.loads(result.output).items() diff --git a/tests/test_html.py b/tests/test_html.py index eadbd720a5..6c3860d73a 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -248,6 +248,9 @@ async def test_css_classes_on_body(ds_client, path, expected_classes): assert classes == expected_classes +templates_considered_re = re.compile(r"") + + @pytest.mark.asyncio @pytest.mark.parametrize( "path,expected_considered", @@ -271,7 +274,10 @@ async def test_css_classes_on_body(ds_client, path, expected_classes): async def test_templates_considered(ds_client, path, expected_considered): response = await ds_client.get(path) assert response.status_code == 200 - assert f"" in response.text + match = templates_considered_re.search(response.text) + assert match, "No templates considered comment found" + actual_considered = match.group(1) + assert actual_considered == expected_considered @pytest.mark.asyncio diff --git a/tests/test_internals_datasette.py b/tests/test_internals_datasette.py index 3d5bb2da58..d59ff72976 100644 --- a/tests/test_internals_datasette.py +++ b/tests/test_internals_datasette.py @@ -1,10 +1,12 @@ """ Tests for the datasette.app.Datasette class """ -from datasette import Forbidden +import dataclasses +from datasette import Forbidden, Context from datasette.app import Datasette, Database from itsdangerous import BadSignature import pytest +from typing import Optional @pytest.fixture @@ -136,6 +138,22 @@ async def test_datasette_render_template_no_request(): assert "Error " in rendered +@pytest.mark.asyncio +async def test_datasette_render_template_with_dataclass(): + @dataclasses.dataclass + class ExampleContext(Context): + title: str + status: int + error: str + + context = ExampleContext(title="Hello", status=200, error="Error message") + ds = Datasette(memory=True) + await ds.invoke_startup() + rendered = await ds.render_template("error.html", context) + assert "

Hello

" in rendered + assert "Error message" in rendered + + def test_datasette_error_if_string_not_list(tmpdir): # https://github.com/simonw/datasette/issues/1985 db_path = str(tmpdir / "data.db") diff --git a/tests/test_messages.py b/tests/test_messages.py index 8417b9ae2c..a7e4d04698 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -12,7 +12,7 @@ ], ) async def test_add_message_sets_cookie(ds_client, qs, expected): - response = await ds_client.get(f"/fixtures.message?{qs}") + response = await ds_client.get(f"/fixtures.message?sql=select+1&{qs}") signed = response.cookies["ds_messages"] decoded = ds_client.ds.unsign(signed, "messages") assert expected == decoded @@ -21,7 +21,9 @@ async def test_add_message_sets_cookie(ds_client, qs, expected): @pytest.mark.asyncio async def test_messages_are_displayed_and_cleared(ds_client): # First set the message cookie - set_msg_response = await ds_client.get("/fixtures.message?add_msg=xmessagex") + set_msg_response = await ds_client.get( + "/fixtures.message?sql=select+1&add_msg=xmessagex" + ) # Now access a page that displays messages response = await ds_client.get("/", cookies=set_msg_response.cookies) # Messages should be in that HTML diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 6971bbf739..28fe720fa5 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -121,9 +121,8 @@ async def test_hook_extra_css_urls(ds_client, path, expected_decoded_object): ][0]["href"] # This link has a base64-encoded JSON blob in it encoded = special_href.split("/")[3] - assert expected_decoded_object == json.loads( - base64.b64decode(encoded).decode("utf8") - ) + actual_decoded_object = json.loads(base64.b64decode(encoded).decode("utf8")) + assert expected_decoded_object == actual_decoded_object @pytest.mark.asyncio diff --git a/tests/test_table_api.py b/tests/test_table_api.py index cd664ffbff..46d1c9b8cc 100644 --- a/tests/test_table_api.py +++ b/tests/test_table_api.py @@ -700,7 +700,6 @@ async def test_max_returned_rows(ds_client): "/fixtures.json?sql=select+content+from+no_primary_key" ) data = response.json() - assert {"sql": "select content from no_primary_key", "params": {}} == data["query"] assert data["truncated"] assert 100 == len(data["rows"])