Skip to content

Commit

Permalink
1065 Select for update (#1085)
Browse files Browse the repository at this point in the history
* Add for_update clause

* Skip testing for_update for SQLite

* Fix typos in exception messages

* Rename method to 'lock_for'

* Format 'test_select'

* wip doc changes

* rename `lock_for` to `lock_rows`

---------

Co-authored-by: Denis Kopitsa <[email protected]>
  • Loading branch information
dantownsend and dkopitsa authored Sep 23, 2024
1 parent b17e335 commit d3d0d30
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/src/piccolo/query_clauses/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ by modifying the return values.
./distinct
./freeze
./group_by
./lock_rows
./offset
./on_conflict
./output
Expand Down
93 changes: 93 additions & 0 deletions docs/src/piccolo/query_clauses/lock_rows.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
.. _lock_rows:

lock_rows
=========

You can use the ``lock_rows`` clause with the following queries:

* :ref:`Objects`
* :ref:`Select`

It returns a query that locks rows until the end of the transaction, generating a ``SELECT ... FOR UPDATE`` SQL statement or similar with other lock strengths.

.. note:: Postgres and CockroachDB only.

-------------------------------------------------------------------------------

Basic Usage
-----------

Basic usage without parameters:

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').lock_rows()
Equivalent to:

.. code-block:: sql
SELECT ... FOR UPDATE
lock_strength
-------------

The parameter ``lock_strength`` controls the strength of the row lock when performing an operation in PostgreSQL.
The value can be a predefined constant from the ``LockStrength`` enum or one of the following strings (case-insensitive):

* ``UPDATE`` (default): Acquires an exclusive lock on the selected rows, preventing other transactions from modifying or locking them until the current transaction is complete.
* ``NO KEY UPDATE`` (Postgres only): Similar to ``UPDATE``, but allows other transactions to insert or delete rows that do not affect the primary key or unique constraints.
* ``KEY SHARE`` (Postgres only): Permits other transactions to acquire key-share or share locks, allowing non-key modifications while preventing updates or deletes.
* ``SHARE``: Acquires a shared lock, allowing other transactions to read the rows but not modify or lock them.

You can specify a different lock strength:

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').lock_rows('SHARE')
Which is equivalent to:

.. code-block:: sql
SELECT ... FOR SHARE
nowait
------

If another transaction has already acquired a lock on one or more selected rows, an exception will be raised instead of
waiting for the other transaction to release the lock.

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', nowait=True)
skip_locked
-----------

Ignore locked rows.

.. code-block:: python
await Band.select(Band.name == 'Pythonistas').lock_rows('UPDATE', skip_locked=True)
of
--

By default, if there are many tables in a query (e.g. when joining), all tables will be locked.
Using ``of``, you can specify which tables should be locked.

.. code-block:: python
await Band.select().where(Band.manager.name == 'Guido').lock_rows('UPDATE', of=(Band, ))
Learn more
----------

* `Postgres docs <https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE>`_
* `CockroachDB docs <https://www.cockroachlabs.com/docs/stable/select-for-update#lock-strengths>`_
5 changes: 5 additions & 0 deletions docs/src/piccolo/query_types/objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ limit

See :ref:`limit`.

lock_rows
~~~~~~~~

See :ref:`lock_rows`.

offset
~~~~~~

Expand Down
6 changes: 6 additions & 0 deletions docs/src/piccolo/query_types/select.rst
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,12 @@ limit

See :ref:`limit`.


lock_rows
~~~~~~~~

See :ref:`lock_rows`.

offset
~~~~~~

Expand Down
25 changes: 25 additions & 0 deletions piccolo/query/methods/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
CallbackDelegate,
CallbackType,
LimitDelegate,
LockRowsDelegate,
LockStrength,
OffsetDelegate,
OrderByDelegate,
OrderByRaw,
Expand Down Expand Up @@ -250,6 +252,7 @@ class Objects(
"callback_delegate",
"prefetch_delegate",
"where_delegate",
"lock_rows_delegate",
)

def __init__(
Expand All @@ -269,6 +272,7 @@ def __init__(
self.prefetch_delegate = PrefetchDelegate()
self.prefetch(*prefetch)
self.where_delegate = WhereDelegate()
self.lock_rows_delegate = LockRowsDelegate()

def output(self: Self, load_json: bool = False) -> Self:
self.output_delegate.output(
Expand Down Expand Up @@ -328,6 +332,26 @@ def first(self) -> First[TableInstance]:
self.limit_delegate.limit(1)
return First[TableInstance](query=self)

def lock_rows(
self: Self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait: bool = False,
skip_locked: bool = False,
of: t.Tuple[type[Table], ...] = (),
) -> Self:
self.lock_rows_delegate.lock_rows(
lock_strength, nowait, skip_locked, of
)
return self

def get(self, where: Combinable) -> Get[TableInstance]:
self.where_delegate.where(where)
self.limit_delegate.limit(1)
Expand Down Expand Up @@ -378,6 +402,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
"offset_delegate",
"output_delegate",
"order_by_delegate",
"lock_rows_delegate",
):
setattr(select, attr, getattr(self, attr))

Expand Down
34 changes: 34 additions & 0 deletions piccolo/query/methods/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
DistinctDelegate,
GroupByDelegate,
LimitDelegate,
LockRowsDelegate,
LockStrength,
OffsetDelegate,
OrderByDelegate,
OrderByRaw,
Expand Down Expand Up @@ -150,6 +152,7 @@ class Select(Query[TableInstance, t.List[t.Dict[str, t.Any]]]):
"output_delegate",
"callback_delegate",
"where_delegate",
"lock_rows_delegate",
)

def __init__(
Expand All @@ -174,6 +177,7 @@ def __init__(
self.output_delegate = OutputDelegate()
self.callback_delegate = CallbackDelegate()
self.where_delegate = WhereDelegate()
self.lock_rows_delegate = LockRowsDelegate()

self.columns(*columns_list)

Expand Down Expand Up @@ -219,6 +223,26 @@ def offset(self: Self, number: int) -> Self:
self.offset_delegate.offset(number)
return self

def lock_rows(
self: Self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait: bool = False,
skip_locked: bool = False,
of: t.Tuple[type[Table], ...] = (),
) -> Self:
self.lock_rows_delegate.lock_rows(
lock_strength, nowait, skip_locked, of
)
return self

async def _splice_m2m_rows(
self,
response: t.List[t.Dict[str, t.Any]],
Expand Down Expand Up @@ -618,6 +642,16 @@ def default_querystrings(self) -> t.Sequence[QueryString]:
query += "{}"
args.append(self.offset_delegate._offset.querystring)

if self.lock_rows_delegate._lock_rows:
if engine_type == "sqlite":
raise NotImplementedError(
"SQLite doesn't support row locking e.g. SELECT ... FOR "
"UPDATE"
)

query += "{}"
args.append(self.lock_rows_delegate._lock_rows.querystring)

querystring = QueryString(query, *args)

return [querystring]
Expand Down
88 changes: 88 additions & 0 deletions piccolo/query/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,91 @@ def on_conflict(
target=target, action=action_, values=values, where=where
)
)


class LockStrength(str, Enum):
"""
Specify lock strength
https://www.postgresql.org/docs/current/sql-select.html#SQL-FOR-UPDATE-SHARE
"""

update = "UPDATE"
no_key_update = "NO KEY UPDATE"
share = "SHARE"
key_share = "KEY SHARE"


@dataclass
class LockRows:
__slots__ = ("lock_strength", "nowait", "skip_locked", "of")

lock_strength: LockStrength
nowait: bool
skip_locked: bool
of: t.Tuple[t.Type[Table], ...]

def __post_init__(self):
if not isinstance(self.lock_strength, LockStrength):
raise TypeError("lock_strength must be a LockStrength")
if not isinstance(self.nowait, bool):
raise TypeError("nowait must be a bool")
if not isinstance(self.skip_locked, bool):
raise TypeError("skip_locked must be a bool")
if not isinstance(self.of, tuple) or not all(
hasattr(x, "_meta") for x in self.of
):
raise TypeError("of must be a tuple of Table")
if self.nowait and self.skip_locked:
raise TypeError(
"The nowait option cannot be used with skip_locked"
)

@property
def querystring(self) -> QueryString:
sql = f" FOR {self.lock_strength.value}"
if self.of:
tables = ", ".join(
i._meta.get_formatted_tablename() for i in self.of
)
sql += " OF " + tables
if self.nowait:
sql += " NOWAIT"
if self.skip_locked:
sql += " SKIP LOCKED"

return QueryString(sql)

def __str__(self) -> str:
return self.querystring.__str__()


@dataclass
class LockRowsDelegate:

_lock_rows: t.Optional[LockRows] = None

def lock_rows(
self,
lock_strength: t.Union[
LockStrength,
t.Literal[
"UPDATE",
"NO KEY UPDATE",
"KEY SHARE",
"SHARE",
],
] = LockStrength.update,
nowait=False,
skip_locked=False,
of: t.Tuple[type[Table], ...] = (),
):
lock_strength_: LockStrength
if isinstance(lock_strength, LockStrength):
lock_strength_ = lock_strength
elif isinstance(lock_strength, str):
lock_strength_ = LockStrength(lock_strength.upper())
else:
raise ValueError("Unrecognised `lock_strength` value.")

self._lock_rows = LockRows(lock_strength_, nowait, skip_locked, of)
34 changes: 34 additions & 0 deletions tests/table/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,40 @@ def test_select_raw(self):
response, [{"name": "Pythonistas", "popularity_log": 3.0}]
)

@pytest.mark.skipif(
is_running_sqlite(),
reason="SQLite doesn't support SELECT ... FOR UPDATE.",
)
def test_lock_rows(self):
"""
Make sure the for_update clause works.
"""
self.insert_rows()

query = Band.select()
self.assertNotIn("FOR UPDATE", query.__str__())

query = query.lock_rows()
self.assertTrue(query.__str__().endswith("FOR UPDATE"))

query = query.lock_rows(lock_strength="KEY SHARE")
self.assertTrue(query.__str__().endswith("FOR KEY SHARE"))

query = query.lock_rows(skip_locked=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE SKIP LOCKED"))

query = query.lock_rows(nowait=True)
self.assertTrue(query.__str__().endswith("FOR UPDATE NOWAIT"))

query = query.lock_rows(of=(Band,))
self.assertTrue(query.__str__().endswith('FOR UPDATE OF "band"'))

with self.assertRaises(TypeError):
query = query.lock_rows(skip_locked=True, nowait=True)

response = query.run_sync()
assert response is not None


class TestSelectSecret(TestCase):
def setUp(self):
Expand Down

0 comments on commit d3d0d30

Please sign in to comment.