Skip to content

Commit

Permalink
named_parameters(sql) sync function, refs simonw#2354
Browse files Browse the repository at this point in the history
Also refs simonw#2353 and simonw#2352
  • Loading branch information
simonw committed Jun 12, 2024
1 parent c514897 commit 5e653fd
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 20 deletions.
33 changes: 24 additions & 9 deletions datasette/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,23 +1131,38 @@ class StartupError(Exception):
pass


_re_named_parameter = re.compile(":([a-zA-Z0-9_]+)")
_single_line_comment_re = re.compile(r"--.*")
_multi_line_comment_re = re.compile(r"/\*.*?\*/", re.DOTALL)
_single_quote_re = re.compile(r"'(?:''|[^'])*'")
_double_quote_re = re.compile(r'"(?:\"\"|[^"])*"')
_named_param_re = re.compile(r":(\w+)")


@documented
async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
def named_parameters(sql: str) -> List[str]:
"""
Given a SQL statement, return a list of named parameters that are used in the statement
e.g. for ``select * from foo where id=:id`` this would return ``["id"]``
"""
explain = "explain {}".format(sql.strip().rstrip(";"))
possible_params = _re_named_parameter.findall(sql)
try:
results = await db.execute(explain, {p: None for p in possible_params})
return [row["p4"].lstrip(":") for row in results if row["opcode"] == "Variable"]
except (sqlite3.DatabaseError, AttributeError):
return possible_params
# Remove single-line comments
sql = _single_line_comment_re.sub("", sql)
# Remove multi-line comments
sql = _multi_line_comment_re.sub("", sql)
# Remove single-quoted strings
sql = _single_quote_re.sub("", sql)
# Remove double-quoted strings
sql = _double_quote_re.sub("", sql)
# Extract parameters from what is left
return _named_param_re.findall(sql)


async def derive_named_parameters(db: "Database", sql: str) -> List[str]:
"""
This undocumented but stable method exists for backwards compatibility
with plugins that were using it before it switched to named_parameters()
"""
return named_parameters(sql)


def add_cors_headers(headers):
Expand Down
6 changes: 2 additions & 4 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
add_cors_headers,
await_me_maybe,
call_with_supported_arguments,
derive_named_parameters,
named_parameters as derive_named_parameters,
format_bytes,
make_slot_function,
tilde_decode,
Expand Down Expand Up @@ -484,9 +484,7 @@ async def get(self, request, datasette):
if canned_query and canned_query.get("params"):
named_parameters = canned_query["params"]
if not named_parameters:
named_parameters = await derive_named_parameters(
datasette.get_database(database), sql
)
named_parameters = derive_named_parameters(sql)
named_parameter_values = {
named_parameter: params.get(named_parameter) or ""
for named_parameter in named_parameters
Expand Down
10 changes: 5 additions & 5 deletions docs/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1256,14 +1256,14 @@ Utility function for calling ``await`` on a return value if it is awaitable, oth

.. autofunction:: datasette.utils.await_me_maybe

.. _internals_utils_derive_named_parameters:
.. _internals_utils_named_parameters:

derive_named_parameters(db, sql)
--------------------------------
named_parameters(sql)
---------------------

Derive the list of named parameters referenced in a SQL query, using an ``explain`` query executed against the provided database.
Derive the list of ``:named`` parameters referenced in a SQL query.

.. autofunction:: datasette.utils.derive_named_parameters
.. autofunction:: datasette.utils.named_parameters

.. _internals_tilde_encoding:

Expand Down
8 changes: 6 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,10 +612,14 @@ def test_parse_metadata(content, expected):
("select this is invalid :one, :two, :three", ["one", "two", "three"]),
),
)
async def test_derive_named_parameters(sql, expected):
@pytest.mark.parametrize("use_async_version", (False, True))
async def test_named_parameters(sql, expected, use_async_version):
ds = Datasette([], memory=True)
db = ds.get_database("_memory")
params = await utils.derive_named_parameters(db, sql)
if use_async_version:
params = await utils.derive_named_parameters(db, sql)
else:
params = utils.named_parameters(sql)
assert params == expected


Expand Down

0 comments on commit 5e653fd

Please sign in to comment.