From 2a871fb306cb4b4ebe442da80700dae697a0fa8d Mon Sep 17 00:00:00 2001 From: Fantix King Date: Thu, 25 May 2023 18:30:54 -0400 Subject: [PATCH 1/3] Add savepoint API --- edgedb/asyncio_client.py | 3 +++ edgedb/blocking_client.py | 13 +++++++++ edgedb/transaction.py | 51 ++++++++++++++++++++++++++++++++++++ tests/test_async_tx.py | 55 +++++++++++++++++++++++++++++++++++++++ tests/test_sync_tx.py | 55 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 177 insertions(+) diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index c03c6423..3c2a8ae8 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -322,6 +322,9 @@ def _exclusive(self): finally: self._locked = False + async def declare_savepoint(self, savepoint: str) -> transaction.Savepoint: + return await self._declare_savepoint(savepoint) + class AsyncIORetry(transaction.BaseRetry): diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 7eb761b9..16bafc47 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -270,6 +270,14 @@ async def close(self, timeout=None): self._closing = False +class Savepoint(transaction.Savepoint): + def release(self): + self._tx._client._iter_coroutine(super().release()) + + def rollback(self): + self._tx._client._iter_coroutine(super().rollback()) + + class Iteration(transaction.BaseTransaction, abstract.Executor): __slots__ = ("_managed", "_lock") @@ -320,6 +328,11 @@ def _exclusive(self): finally: self._lock.release() + def declare_savepoint(self, savepoint: str) -> Savepoint: + return self._client._iter_coroutine( + self._declare_savepoint(savepoint, cls=Savepoint) + ) + class Retry(transaction.BaseRetry): diff --git a/edgedb/transaction.py b/edgedb/transaction.py index 511b8f42..ba7aceaf 100644 --- a/edgedb/transaction.py +++ b/edgedb/transaction.py @@ -17,6 +17,8 @@ # +from __future__ import annotations + import enum from . import abstract @@ -32,12 +34,47 @@ class TransactionState(enum.Enum): FAILED = 4 +class Savepoint: + __slots__ = ('_name', '_tx', '_active') + + def __init__(self, name: str, transaction: BaseTransaction): + self._name = name + self._tx = transaction + self._active = True + + @property + def active(self): + return self._active + + def _ensure_active(self): + if not self._active: + raise errors.InterfaceError( + f"savepoint {self._name!r} is no longer active" + ) + + async def release(self): + self._ensure_active() + await self._tx._privileged_execute(f"release savepoint {self._name}") + del self._tx._savepoints[self._name] + self._active = False + + async def rollback(self): + self._ensure_active() + await self._tx._privileged_execute( + f"rollback to savepoint {self._name}" + ) + names = list(self._tx._savepoints) + for name in names[names.index(self._name):]: + self._tx._savepoints.pop(name)._active = False + + class BaseTransaction: __slots__ = ( '_client', '_connection', '_options', + '_savepoints', '_state', '__retry', '__iteration', @@ -48,6 +85,7 @@ def __init__(self, retry, client, iteration): self._client = client self._connection = None self._options = retry._options.transaction_options + self._savepoints = {} self._state = TransactionState.NEW self.__retry = retry self.__iteration = iteration @@ -128,6 +166,9 @@ async def _exit(self, extype, ex): if not self.__started: return False + for sp in self._savepoints.values(): + sp._active = False + try: if extype is None: query = self._make_commit_query() @@ -200,6 +241,16 @@ async def _privileged_execute(self, query: str) -> None: state=self._get_state(), )) + async def _declare_savepoint(self, savepoint: str, cls=Savepoint): + if savepoint in self._savepoints: + raise errors.InterfaceError( + f"savepoint {savepoint!r} already exists" + ) + await self._ensure_transaction() + await self._privileged_execute(f"declare savepoint {savepoint}") + self._savepoints[savepoint] = rv = cls(savepoint, self) + return rv + class BaseRetry: diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 8ceeb239..4de176c0 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -34,6 +34,10 @@ class TestAsyncTx(tb.AsyncQueryTestCase): }; ''' + TEARDOWN_METHOD = ''' + DELETE test::TransactionTest; + ''' + TEARDOWN = ''' DROP TYPE test::TransactionTest; ''' @@ -104,3 +108,54 @@ async def test_async_transaction_exclusive(self): ): await asyncio.wait_for(f1, timeout=5) await asyncio.wait_for(f2, timeout=5) + + async def test_async_transaction_savepoint_1(self): + async for tx in self.client.transaction(): + async with tx: + sp1 = await tx.declare_savepoint("sp1") + sp2 = await tx.declare_savepoint("sp2") + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*already exists" + ): + await tx.declare_savepoint("sp1") + await tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + await sp2.release() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp2.release() + await sp1.release() + + result = await self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) + + async def test_async_transaction_savepoint_2(self): + async for tx in self.client.transaction(): + async with tx: + await tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp1 = await tx.declare_savepoint("sp1") + await tx.execute(''' + INSERT test::TransactionTest { name := '2' }; + ''') + sp2 = await tx.declare_savepoint("sp2") + await tx.execute(''' + INSERT test::TransactionTest { name := '3' }; + ''') + await sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + await sp2.rollback() + + result = await self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 3ed2fc55..77a386ee 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -33,6 +33,10 @@ class TestSyncTx(tb.SyncQueryTestCase): }; ''' + TEARDOWN_METHOD = ''' + DELETE test::TransactionTest; + ''' + TEARDOWN = ''' DROP TYPE test::TransactionTest; ''' @@ -113,3 +117,54 @@ def test_sync_transaction_exclusive(self): ): f1.result(timeout=5) f2.result(timeout=5) + + def test_sync_transaction_savepoint_1(self): + for tx in self.client.transaction(): + with tx: + sp1 = tx.declare_savepoint("sp1") + sp2 = tx.declare_savepoint("sp2") + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*already exists" + ): + tx.declare_savepoint("sp1") + tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp2.release() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp2.release() + sp1.release() + + result = self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) + + def test_sync_transaction_savepoint_2(self): + for tx in self.client.transaction(): + with tx: + tx.execute(''' + INSERT test::TransactionTest { name := '1' }; + ''') + sp1 = tx.declare_savepoint("sp1") + tx.execute(''' + INSERT test::TransactionTest { name := '2' }; + ''') + sp2 = tx.declare_savepoint("sp2") + tx.execute(''' + INSERT test::TransactionTest { name := '3' }; + ''') + sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp1.rollback() + with self.assertRaisesRegex( + edgedb.InterfaceError, "savepoint.*is no longer active" + ): + sp2.rollback() + + result = self.client.query('SELECT test::TransactionTest.name') + + self.assertEqual(result, ["1"]) From a4e80791e61d889ad09fc3ff9e8c494319b53909 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 28 Jun 2023 01:09:59 +0900 Subject: [PATCH 2/3] CRF: rename to savepoint() and use random sp name --- edgedb/asyncio_client.py | 6 ++++-- edgedb/blocking_client.py | 6 ++++-- tests/test_async_tx.py | 12 ++++-------- tests/test_sync_tx.py | 12 ++++-------- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index 3c2a8ae8..ce30dd10 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -23,6 +23,7 @@ import socket import ssl import typing +import uuid from . import abstract from . import base_client @@ -322,8 +323,9 @@ def _exclusive(self): finally: self._locked = False - async def declare_savepoint(self, savepoint: str) -> transaction.Savepoint: - return await self._declare_savepoint(savepoint) + async def savepoint(self) -> transaction.Savepoint: + name = uuid.uuid4().hex + return await self._declare_savepoint(name) class AsyncIORetry(transaction.BaseRetry): diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 16bafc47..af10b8a2 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -25,6 +25,7 @@ import threading import time import typing +import uuid from . import abstract from . import base_client @@ -328,9 +329,10 @@ def _exclusive(self): finally: self._lock.release() - def declare_savepoint(self, savepoint: str) -> Savepoint: + def savepoint(self) -> Savepoint: + name = uuid.uuid4().hex return self._client._iter_coroutine( - self._declare_savepoint(savepoint, cls=Savepoint) + self._declare_savepoint(name, cls=Savepoint) ) diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 4de176c0..b174c18f 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -112,12 +112,8 @@ async def test_async_transaction_exclusive(self): async def test_async_transaction_savepoint_1(self): async for tx in self.client.transaction(): async with tx: - sp1 = await tx.declare_savepoint("sp1") - sp2 = await tx.declare_savepoint("sp2") - with self.assertRaisesRegex( - edgedb.InterfaceError, "savepoint.*already exists" - ): - await tx.declare_savepoint("sp1") + sp1 = await tx.savepoint() + sp2 = await tx.savepoint() await tx.execute(''' INSERT test::TransactionTest { name := '1' }; ''') @@ -138,11 +134,11 @@ async def test_async_transaction_savepoint_2(self): await tx.execute(''' INSERT test::TransactionTest { name := '1' }; ''') - sp1 = await tx.declare_savepoint("sp1") + sp1 = await tx.savepoint() await tx.execute(''' INSERT test::TransactionTest { name := '2' }; ''') - sp2 = await tx.declare_savepoint("sp2") + sp2 = await tx.savepoint() await tx.execute(''' INSERT test::TransactionTest { name := '3' }; ''') diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index 77a386ee..d8735049 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -121,12 +121,8 @@ def test_sync_transaction_exclusive(self): def test_sync_transaction_savepoint_1(self): for tx in self.client.transaction(): with tx: - sp1 = tx.declare_savepoint("sp1") - sp2 = tx.declare_savepoint("sp2") - with self.assertRaisesRegex( - edgedb.InterfaceError, "savepoint.*already exists" - ): - tx.declare_savepoint("sp1") + sp1 = tx.savepoint() + sp2 = tx.savepoint() tx.execute(''' INSERT test::TransactionTest { name := '1' }; ''') @@ -147,11 +143,11 @@ def test_sync_transaction_savepoint_2(self): tx.execute(''' INSERT test::TransactionTest { name := '1' }; ''') - sp1 = tx.declare_savepoint("sp1") + sp1 = tx.savepoint() tx.execute(''' INSERT test::TransactionTest { name := '2' }; ''') - sp2 = tx.declare_savepoint("sp2") + sp2 = tx.savepoint() tx.execute(''' INSERT test::TransactionTest { name := '3' }; ''') From 6024465f41756ad8c4a8e4db9e3c23711ad26808 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 28 Jun 2023 01:35:55 +0900 Subject: [PATCH 3/3] Prefix sp name with "s" to avoid leading digits --- edgedb/asyncio_client.py | 2 +- edgedb/blocking_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index ce30dd10..336bc636 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -324,7 +324,7 @@ def _exclusive(self): self._locked = False async def savepoint(self) -> transaction.Savepoint: - name = uuid.uuid4().hex + name = "s" + uuid.uuid4().hex return await self._declare_savepoint(name) diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index af10b8a2..53719a6c 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -330,7 +330,7 @@ def _exclusive(self): self._lock.release() def savepoint(self) -> Savepoint: - name = uuid.uuid4().hex + name = "s" + uuid.uuid4().hex return self._client._iter_coroutine( self._declare_savepoint(name, cls=Savepoint) )