Skip to content

Commit

Permalink
1017 Improve array serialisation in get_sql_value (#1018)
Browse files Browse the repository at this point in the history
* fix arrays

* make multidimensional arrays work, and SQLite edgecases

* add tests

* fix sqlite tests on certain Python versions
  • Loading branch information
dantownsend authored Jun 13, 2024
1 parent fb7a5ed commit 70dac99
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 29 deletions.
84 changes: 55 additions & 29 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,11 @@ def get_where_string(self, engine_type: str) -> QueryString:
engine_type=engine_type, with_alias=False
)

def get_sql_value(self, value: t.Any) -> t.Any:
def get_sql_value(
self,
value: t.Any,
delimiter: str = "'",
) -> str:
"""
When using DDL statements, we can't parameterise the values. An example
is when setting the default for a column. So we have to convert from
Expand All @@ -839,49 +843,71 @@ def get_sql_value(self, value: t.Any) -> t.Any:
:param value:
The Python value to convert to a string usable in a DDL statement
e.g. 1.
e.g. ``1``.
:param delimiter:
The string returned by this function is wrapped in delimiters,
ready to be added to a DDL statement. For example:
``'hello world'``.
:returns:
The string usable in the DDL statement e.g. '1'.
The string usable in the DDL statement e.g. ``'1'``.
"""
from piccolo.engine.sqlite import ADAPTERS as sqlite_adapters

# Common across all DB engines
if isinstance(value, Default):
return getattr(value, self._meta.engine_type)
elif value is None:
return "null"
elif isinstance(value, (float, decimal.Decimal)):
return str(value)
elif isinstance(value, str):
return f"'{value}'"
return f"{delimiter}{value}{delimiter}"
elif isinstance(value, bool):
return str(value).lower()
elif isinstance(value, datetime.datetime):
return f"'{value.isoformat().replace('T', ' ')}'"
elif isinstance(value, datetime.date):
return f"'{value.isoformat()}'"
elif isinstance(value, datetime.time):
return f"'{value.isoformat()}'"
elif isinstance(value, datetime.timedelta):
interval = IntervalCustom.from_timedelta(value)
return getattr(interval, self._meta.engine_type)
elif isinstance(value, bytes):
return f"'{value.hex()}'"
elif isinstance(value, uuid.UUID):
return f"'{value}'"
elif isinstance(value, list):
# Convert to the array syntax.
return (
"'{"
+ ", ".join(
(
f'"{i}"'
if isinstance(i, str)
else str(self.get_sql_value(i))
return f"{delimiter}{value.hex()}{delimiter}"

# SQLite specific
if self._meta.engine_type == "sqlite":
if adapter := sqlite_adapters.get(type(value)):
sqlite_value = adapter(value)
return (
f"{delimiter}{sqlite_value}{delimiter}"
if isinstance(sqlite_value, str)
else sqlite_value
)

# Postgres and Cockroach
if self._meta.engine_type in ["postgres", "cockroach"]:
if isinstance(value, datetime.datetime):
return f"{delimiter}{value.isoformat().replace('T', ' ')}{delimiter}" # noqa: E501
elif isinstance(value, datetime.date):
return f"{delimiter}{value.isoformat()}{delimiter}"
elif isinstance(value, datetime.time):
return f"{delimiter}{value.isoformat()}{delimiter}"
elif isinstance(value, datetime.timedelta):
interval = IntervalCustom.from_timedelta(value)
return getattr(interval, self._meta.engine_type)
elif isinstance(value, uuid.UUID):
return f"{delimiter}{value}{delimiter}"
elif isinstance(value, list):
# Convert to the array syntax.
return (
delimiter
+ "{"
+ ",".join(
self.get_sql_value(
i,
delimiter="" if isinstance(i, list) else '"',
)
for i in value
)
for i in value
+ "}"
+ delimiter
)
) + "}'"
else:
return value

return str(value)

@property
def column_type(self):
Expand Down
66 changes: 66 additions & 0 deletions tests/columns/test_get_sql_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import datetime
from unittest import TestCase

from tests.base import engines_only
from tests.example_apps.music.tables import Band


@engines_only("postgres", "cockroach")
class TestArrayPostgres(TestCase):

def test_string(self):
self.assertEqual(
Band.name.get_sql_value(["a", "b", "c"]),
'\'{"a","b","c"}\'',
)

def test_int(self):
self.assertEqual(
Band.name.get_sql_value([1, 2, 3]),
"'{1,2,3}'",
)

def test_nested(self):
self.assertEqual(
Band.name.get_sql_value([1, 2, 3, [4, 5, 6]]),
"'{1,2,3,{4,5,6}}'",
)

def test_time(self):
self.assertEqual(
Band.name.get_sql_value([datetime.time(hour=8, minute=0)]),
"'{\"08:00:00\"}'",
)


@engines_only("sqlite")
class TestArraySQLite(TestCase):
"""
Note, we use ``.replace(" ", "")`` because we serialise arrays using
Python's json library, and there is inconsistency between Python versions
(some output ``["a", "b", "c"]``, and others ``["a","b","c"]``).
"""

def test_string(self):
self.assertEqual(
Band.name.get_sql_value(["a", "b", "c"]).replace(" ", ""),
'\'["a","b","c"]\'',
)

def test_int(self):
self.assertEqual(
Band.name.get_sql_value([1, 2, 3]).replace(" ", ""),
"'[1,2,3]'",
)

def test_nested(self):
self.assertEqual(
Band.name.get_sql_value([1, 2, 3, [4, 5, 6]]).replace(" ", ""),
"'[1,2,3,[4,5,6]]'",
)

def test_time(self):
self.assertEqual(
Band.name.get_sql_value([datetime.time(hour=8, minute=0)]),
"'[\"08:00:00\"]'",
)

0 comments on commit 70dac99

Please sign in to comment.