diff --git a/examples/basic_example_v2/basic_example.py b/examples/basic_example_v2/basic_example.py index 40aef03d..7bb61f69 100644 --- a/examples/basic_example_v2/basic_example.py +++ b/examples/basic_example_v2/basic_example.py @@ -140,7 +140,7 @@ def select_with_parameters(pool: ydb.QuerySessionPool, series_id, season_id, epi # calls instead to avoid additional hops to YDB cluster and allow more efficient # execution of queries. def explicit_transaction_control(pool: ydb.QuerySessionPool, series_id, season_id, episode_id): - def callee(session: ydb.QuerySessionSync): + def callee(session: ydb.QuerySession): query = """ DECLARE $seriesId AS Int64; DECLARE $seasonId AS Int64; @@ -175,7 +175,7 @@ def callee(session: ydb.QuerySessionSync): def huge_select(pool: ydb.QuerySessionPool): - def callee(session: ydb.QuerySessionSync): + def callee(session: ydb.QuerySession): query = """SELECT * from episodes;""" with session.transaction().execute( diff --git a/examples/basic_example_v2/basic_example_async.py b/examples/basic_example_v2/basic_example_async.py index f57ec491..ad66462a 100644 --- a/examples/basic_example_v2/basic_example_async.py +++ b/examples/basic_example_v2/basic_example_async.py @@ -58,7 +58,7 @@ """ -async def fill_tables_with_data(pool: ydb.aio.QuerySessionPoolAsync): +async def fill_tables_with_data(pool: ydb.aio.QuerySessionPool): print("\nFilling tables with data...") await pool.execute_with_retries( FillDataQuery, @@ -70,7 +70,7 @@ async def fill_tables_with_data(pool: ydb.aio.QuerySessionPoolAsync): ) -async def select_simple(pool: ydb.aio.QuerySessionPoolAsync): +async def select_simple(pool: ydb.aio.QuerySessionPool): print("\nCheck series table...") result_sets = await pool.execute_with_retries( """ @@ -96,7 +96,7 @@ async def select_simple(pool: ydb.aio.QuerySessionPoolAsync): return first_set -async def upsert_simple(pool: ydb.aio.QuerySessionPoolAsync): +async def upsert_simple(pool: ydb.aio.QuerySessionPool): print("\nPerforming UPSERT into episodes...") await pool.execute_with_retries( @@ -106,7 +106,7 @@ async def upsert_simple(pool: ydb.aio.QuerySessionPoolAsync): ) -async def select_with_parameters(pool: ydb.aio.QuerySessionPoolAsync, series_id, season_id, episode_id): +async def select_with_parameters(pool: ydb.aio.QuerySessionPool, series_id, season_id, episode_id): result_sets = await pool.execute_with_retries( """ DECLARE $seriesId AS Int64; @@ -138,8 +138,8 @@ async def select_with_parameters(pool: ydb.aio.QuerySessionPoolAsync, series_id, # In most cases it's better to use transaction control settings in session.transaction # calls instead to avoid additional hops to YDB cluster and allow more efficient # execution of queries. -async def explicit_transaction_control(pool: ydb.aio.QuerySessionPoolAsync, series_id, season_id, episode_id): - async def callee(session: ydb.aio.QuerySessionAsync): +async def explicit_transaction_control(pool: ydb.aio.QuerySessionPool, series_id, season_id, episode_id): + async def callee(session: ydb.aio.QuerySession): query = """ DECLARE $seriesId AS Int64; DECLARE $seasonId AS Int64; @@ -173,8 +173,8 @@ async def callee(session: ydb.aio.QuerySessionAsync): return await pool.retry_operation_async(callee) -async def huge_select(pool: ydb.aio.QuerySessionPoolAsync): - async def callee(session: ydb.aio.QuerySessionAsync): +async def huge_select(pool: ydb.aio.QuerySessionPool): + async def callee(session: ydb.aio.QuerySession): query = """SELECT * from episodes;""" async with await session.transaction().execute( @@ -189,12 +189,12 @@ async def callee(session: ydb.aio.QuerySessionAsync): return await pool.retry_operation_async(callee) -async def drop_tables(pool: ydb.aio.QuerySessionPoolAsync): +async def drop_tables(pool: ydb.aio.QuerySessionPool): print("\nCleaning up existing tables...") await pool.execute_with_retries(DropTablesQuery) -async def create_tables(pool: ydb.aio.QuerySessionPoolAsync): +async def create_tables(pool: ydb.aio.QuerySessionPool): print("\nCreating table series...") await pool.execute_with_retries( """ @@ -246,7 +246,7 @@ async def run(endpoint, database): ) as driver: await driver.wait(timeout=5, fail_fast=True) - async with ydb.aio.QuerySessionPoolAsync(driver) as pool: + async with ydb.aio.QuerySessionPool(driver) as pool: await drop_tables(pool) await create_tables(pool) diff --git a/examples/query-service/basic_example.py b/examples/query-service/basic_example.py index cfbb3042..854c2dfe 100644 --- a/examples/query-service/basic_example.py +++ b/examples/query-service/basic_example.py @@ -82,7 +82,7 @@ def callee(session): pool.retry_operation_sync(callee) - def callee(session: ydb.QuerySessionSync): + def callee(session: ydb.QuerySession): query_print = """select $a""" print("=" * 50) diff --git a/examples/query-service/basic_example_asyncio.py b/examples/query-service/basic_example_asyncio.py index cd7a4919..c26db535 100644 --- a/examples/query-service/basic_example_asyncio.py +++ b/examples/query-service/basic_example_asyncio.py @@ -15,7 +15,7 @@ async def main(): except TimeoutError: raise RuntimeError("Connect failed to YDB") - pool = ydb.aio.QuerySessionPoolAsync(driver) + pool = ydb.aio.QuerySessionPool(driver) print("=" * 50) print("DELETE TABLE IF EXISTS") @@ -83,7 +83,7 @@ async def callee(session): await pool.retry_operation_async(callee) - async def callee(session: ydb.aio.QuerySessionAsync): + async def callee(session: ydb.aio.QuerySession): query_print = """select $a""" print("=" * 50) diff --git a/test-requirements.txt b/test-requirements.txt index 705bf22f..fcf5779b 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -34,7 +34,7 @@ pyjwt==2.0.0 requests==2.31.0 texttable==1.6.4 toml==0.10.2 -typing-extensions==3.10.0.0 +typing-extensions==4.12.2 urllib3==1.26.6 websocket-client==0.59.0 zipp==3.19.1 diff --git a/tests/aio/query/conftest.py b/tests/aio/query/conftest.py index 0fbdbd38..27d96343 100644 --- a/tests/aio/query/conftest.py +++ b/tests/aio/query/conftest.py @@ -1,11 +1,11 @@ import pytest -from ydb.aio.query.session import QuerySessionAsync -from ydb.aio.query.pool import QuerySessionPoolAsync +from ydb.aio.query.session import QuerySession +from ydb.aio.query.pool import QuerySessionPool @pytest.fixture async def session(driver): - session = QuerySessionAsync(driver) + session = QuerySession(driver) yield session @@ -29,6 +29,6 @@ async def tx(session): @pytest.fixture -def pool(driver): - pool = QuerySessionPoolAsync(driver) - yield pool +async def pool(driver): + async with QuerySessionPool(driver) as pool: + yield pool diff --git a/tests/aio/query/test_query_session.py b/tests/aio/query/test_query_session.py index 117e39af..0bd06fba 100644 --- a/tests/aio/query/test_query_session.py +++ b/tests/aio/query/test_query_session.py @@ -1,14 +1,14 @@ import pytest -from ydb.aio.query.session import QuerySessionAsync +from ydb.aio.query.session import QuerySession -def _check_session_state_empty(session: QuerySessionAsync): +def _check_session_state_empty(session: QuerySession): assert session._state.session_id is None assert session._state.node_id is None assert not session._state.attached -def _check_session_state_full(session: QuerySessionAsync): +def _check_session_state_full(session: QuerySession): assert session._state.session_id is not None assert session._state.node_id is not None assert session._state.attached @@ -16,7 +16,7 @@ def _check_session_state_full(session: QuerySessionAsync): class TestAsyncQuerySession: @pytest.mark.asyncio - async def test_session_normal_lifecycle(self, session: QuerySessionAsync): + async def test_session_normal_lifecycle(self, session: QuerySession): _check_session_state_empty(session) await session.create() @@ -26,7 +26,7 @@ async def test_session_normal_lifecycle(self, session: QuerySessionAsync): _check_session_state_empty(session) @pytest.mark.asyncio - async def test_second_create_do_nothing(self, session: QuerySessionAsync): + async def test_second_create_do_nothing(self, session: QuerySession): await session.create() _check_session_state_full(session) @@ -40,30 +40,30 @@ async def test_second_create_do_nothing(self, session: QuerySessionAsync): assert session._state.node_id == node_id_before @pytest.mark.asyncio - async def test_second_delete_do_nothing(self, session: QuerySessionAsync): + async def test_second_delete_do_nothing(self, session: QuerySession): await session.create() await session.delete() await session.delete() @pytest.mark.asyncio - async def test_delete_before_create_not_possible(self, session: QuerySessionAsync): + async def test_delete_before_create_not_possible(self, session: QuerySession): with pytest.raises(RuntimeError): await session.delete() @pytest.mark.asyncio - async def test_create_after_delete_not_possible(self, session: QuerySessionAsync): + async def test_create_after_delete_not_possible(self, session: QuerySession): await session.create() await session.delete() with pytest.raises(RuntimeError): await session.create() - def test_transaction_before_create_raises(self, session: QuerySessionAsync): + def test_transaction_before_create_raises(self, session: QuerySession): with pytest.raises(RuntimeError): session.transaction() @pytest.mark.asyncio - async def test_transaction_after_delete_raises(self, session: QuerySessionAsync): + async def test_transaction_after_delete_raises(self, session: QuerySession): await session.create() await session.delete() @@ -72,24 +72,24 @@ async def test_transaction_after_delete_raises(self, session: QuerySessionAsync) session.transaction() @pytest.mark.asyncio - async def test_transaction_after_create_not_raises(self, session: QuerySessionAsync): + async def test_transaction_after_create_not_raises(self, session: QuerySession): await session.create() session.transaction() @pytest.mark.asyncio - async def test_execute_before_create_raises(self, session: QuerySessionAsync): + async def test_execute_before_create_raises(self, session: QuerySession): with pytest.raises(RuntimeError): await session.execute("select 1;") @pytest.mark.asyncio - async def test_execute_after_delete_raises(self, session: QuerySessionAsync): + async def test_execute_after_delete_raises(self, session: QuerySession): await session.create() await session.delete() with pytest.raises(RuntimeError): await session.execute("select 1;") @pytest.mark.asyncio - async def test_basic_execute(self, session: QuerySessionAsync): + async def test_basic_execute(self, session: QuerySession): await session.create() it = await session.execute("select 1;") result_sets = [result_set async for result_set in it] @@ -100,7 +100,7 @@ async def test_basic_execute(self, session: QuerySessionAsync): assert list(result_sets[0].rows[0].values()) == [1] @pytest.mark.asyncio - async def test_two_results(self, session: QuerySessionAsync): + async def test_two_results(self, session: QuerySession): await session.create() res = [] diff --git a/tests/aio/query/test_query_session_pool.py b/tests/aio/query/test_query_session_pool.py index e544f7b6..26b12082 100644 --- a/tests/aio/query/test_query_session_pool.py +++ b/tests/aio/query/test_query_session_pool.py @@ -1,42 +1,42 @@ +import asyncio import pytest import ydb -from ydb.aio.query.pool import QuerySessionPoolAsync -from ydb.aio.query.session import QuerySessionAsync, QuerySessionStateEnum +from ydb.aio.query.pool import QuerySessionPool +from ydb.aio.query.session import QuerySession, QuerySessionStateEnum -class TestQuerySessionPoolAsync: +class TestQuerySessionPool: @pytest.mark.asyncio - async def test_checkout_provides_created_session(self, pool: QuerySessionPoolAsync): + async def test_checkout_provides_created_session(self, pool: QuerySessionPool): async with pool.checkout() as session: assert session._state._state == QuerySessionStateEnum.CREATED - assert session._state._state == QuerySessionStateEnum.CLOSED - @pytest.mark.asyncio - async def test_oneshot_query_normal(self, pool: QuerySessionPoolAsync): + async def test_oneshot_query_normal(self, pool: QuerySessionPool): res = await pool.execute_with_retries("select 1;") assert len(res) == 1 @pytest.mark.asyncio - async def test_oneshot_ddl_query(self, pool: QuerySessionPoolAsync): + async def test_oneshot_ddl_query(self, pool: QuerySessionPool): + await pool.execute_with_retries("drop table if exists Queen;") await pool.execute_with_retries("create table Queen(key UInt64, PRIMARY KEY (key));") await pool.execute_with_retries("drop table Queen;") @pytest.mark.asyncio - async def test_oneshot_query_raises(self, pool: QuerySessionPoolAsync): + async def test_oneshot_query_raises(self, pool: QuerySessionPool): with pytest.raises(ydb.GenericError): await pool.execute_with_retries("Is this the real life? Is this just fantasy?") @pytest.mark.asyncio - async def test_retry_op_uses_created_session(self, pool: QuerySessionPoolAsync): - async def callee(session: QuerySessionAsync): + async def test_retry_op_uses_created_session(self, pool: QuerySessionPool): + async def callee(session: QuerySession): assert session._state._state == QuerySessionStateEnum.CREATED await pool.retry_operation_async(callee) @pytest.mark.asyncio - async def test_retry_op_normal(self, pool: QuerySessionPoolAsync): - async def callee(session: QuerySessionAsync): + async def test_retry_op_normal(self, pool: QuerySessionPool): + async def callee(session: QuerySession): async with session.transaction() as tx: iterator = await tx.execute("select 1;", commit_tx=True) return [result_set async for result_set in iterator] @@ -45,12 +45,79 @@ async def callee(session: QuerySessionAsync): assert len(res) == 1 @pytest.mark.asyncio - async def test_retry_op_raises(self, pool: QuerySessionPoolAsync): + async def test_retry_op_raises(self, pool: QuerySessionPool): class CustomException(Exception): pass - async def callee(session: QuerySessionAsync): + async def callee(session: QuerySession): raise CustomException() with pytest.raises(CustomException): await pool.retry_operation_async(callee) + + @pytest.mark.asyncio + async def test_pool_size_limit_logic(self, pool: QuerySessionPool): + target_size = 5 + pool._size = target_size + ids = set() + + for i in range(1, target_size + 1): + session = await pool.acquire() + assert pool._current_size == i + assert session._state.session_id not in ids + ids.add(session._state.session_id) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(pool.acquire(), timeout=0.1) + + last_id = session._state.session_id + await pool.release(session) + + session = await pool.acquire() + assert session._state.session_id == last_id + assert pool._current_size == target_size + + @pytest.mark.asyncio + async def test_checkout_do_not_increase_size(self, pool: QuerySessionPool): + session_id = None + for _ in range(10): + async with pool.checkout() as session: + if session_id is None: + session_id = session._state.session_id + assert pool._current_size == 1 + assert session_id == session._state.session_id + + @pytest.mark.asyncio + async def test_pool_recreates_bad_sessions(self, pool: QuerySessionPool): + async with pool.checkout() as session: + session_id = session._state.session_id + await session.delete() + + async with pool.checkout() as session: + assert session_id != session._state.session_id + assert pool._current_size == 1 + + @pytest.mark.asyncio + async def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPool): + await pool.stop() + with pytest.raises(RuntimeError): + await pool.acquire() + + @pytest.mark.asyncio + async def test_acquire_with_timeout_from_closed_pool_raises(self, pool: QuerySessionPool): + await pool.stop() + with pytest.raises(RuntimeError): + await asyncio.wait_for(pool.acquire(), timeout=0.1) + + @pytest.mark.asyncio + async def test_no_session_leak(self, driver, docker_project): + pool = ydb.aio.QuerySessionPool(driver, 1) + docker_project.stop() + try: + await asyncio.wait_for(pool.acquire(), timeout=0.1) + except ydb.Error: + pass + assert pool._current_size == 0 + + docker_project.start() + await pool.stop() diff --git a/tests/aio/query/test_query_transaction.py b/tests/aio/query/test_query_transaction.py index e332b086..47222d0b 100644 --- a/tests/aio/query/test_query_transaction.py +++ b/tests/aio/query/test_query_transaction.py @@ -1,73 +1,73 @@ import pytest -from ydb.aio.query.transaction import QueryTxContextAsync +from ydb.aio.query.transaction import QueryTxContext from ydb.query.transaction import QueryTxStateEnum class TestAsyncQueryTransaction: @pytest.mark.asyncio - async def test_tx_begin(self, tx: QueryTxContextAsync): + async def test_tx_begin(self, tx: QueryTxContext): assert tx.tx_id is None await tx.begin() assert tx.tx_id is not None @pytest.mark.asyncio - async def test_tx_allow_double_commit(self, tx: QueryTxContextAsync): + async def test_tx_allow_double_commit(self, tx: QueryTxContext): await tx.begin() await tx.commit() await tx.commit() @pytest.mark.asyncio - async def test_tx_allow_double_rollback(self, tx: QueryTxContextAsync): + async def test_tx_allow_double_rollback(self, tx: QueryTxContext): await tx.begin() await tx.rollback() await tx.rollback() @pytest.mark.asyncio - async def test_tx_commit_before_begin(self, tx: QueryTxContextAsync): + async def test_tx_commit_before_begin(self, tx: QueryTxContext): await tx.commit() assert tx._tx_state._state == QueryTxStateEnum.COMMITTED @pytest.mark.asyncio - async def test_tx_rollback_before_begin(self, tx: QueryTxContextAsync): + async def test_tx_rollback_before_begin(self, tx: QueryTxContext): await tx.rollback() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED @pytest.mark.asyncio - async def test_tx_first_execute_begins_tx(self, tx: QueryTxContextAsync): + async def test_tx_first_execute_begins_tx(self, tx: QueryTxContext): await tx.execute("select 1;") await tx.commit() @pytest.mark.asyncio - async def test_interactive_tx_commit(self, tx: QueryTxContextAsync): + async def test_interactive_tx_commit(self, tx: QueryTxContext): await tx.execute("select 1;", commit_tx=True) with pytest.raises(RuntimeError): await tx.execute("select 1;") @pytest.mark.asyncio - async def test_tx_execute_raises_after_commit(self, tx: QueryTxContextAsync): + async def test_tx_execute_raises_after_commit(self, tx: QueryTxContext): await tx.begin() await tx.commit() with pytest.raises(RuntimeError): await tx.execute("select 1;") @pytest.mark.asyncio - async def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextAsync): + async def test_tx_execute_raises_after_rollback(self, tx: QueryTxContext): await tx.begin() await tx.rollback() with pytest.raises(RuntimeError): await tx.execute("select 1;") @pytest.mark.asyncio - async def test_context_manager_rollbacks_tx(self, tx: QueryTxContextAsync): + async def test_context_manager_rollbacks_tx(self, tx: QueryTxContext): async with tx: await tx.begin() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED @pytest.mark.asyncio - async def test_context_manager_normal_flow(self, tx: QueryTxContextAsync): + async def test_context_manager_normal_flow(self, tx: QueryTxContext): async with tx: await tx.begin() await tx.execute("select 1;") @@ -76,7 +76,7 @@ async def test_context_manager_normal_flow(self, tx: QueryTxContextAsync): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED @pytest.mark.asyncio - async def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextAsync): + async def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContext): class CustomException(Exception): pass @@ -85,7 +85,7 @@ class CustomException(Exception): raise CustomException() @pytest.mark.asyncio - async def test_execute_as_context_manager(self, tx: QueryTxContextAsync): + async def test_execute_as_context_manager(self, tx: QueryTxContext): await tx.begin() async with await tx.execute("select 1;") as results: diff --git a/tests/query/conftest.py b/tests/query/conftest.py index 277aaeba..fa37b82e 100644 --- a/tests/query/conftest.py +++ b/tests/query/conftest.py @@ -1,11 +1,11 @@ import pytest -from ydb.query.session import QuerySessionSync +from ydb.query.session import QuerySession from ydb.query.pool import QuerySessionPool @pytest.fixture def session(driver_sync): - session = QuerySessionSync(driver_sync) + session = QuerySession(driver_sync) yield session diff --git a/tests/query/test_query_session.py b/tests/query/test_query_session.py index 89b899bd..a3f49cc4 100644 --- a/tests/query/test_query_session.py +++ b/tests/query/test_query_session.py @@ -1,22 +1,22 @@ import pytest -from ydb.query.session import QuerySessionSync +from ydb.query.session import QuerySession -def _check_session_state_empty(session: QuerySessionSync): +def _check_session_state_empty(session: QuerySession): assert session._state.session_id is None assert session._state.node_id is None assert not session._state.attached -def _check_session_state_full(session: QuerySessionSync): +def _check_session_state_full(session: QuerySession): assert session._state.session_id is not None assert session._state.node_id is not None assert session._state.attached class TestQuerySession: - def test_session_normal_lifecycle(self, session: QuerySessionSync): + def test_session_normal_lifecycle(self, session: QuerySession): _check_session_state_empty(session) session.create() @@ -25,7 +25,7 @@ def test_session_normal_lifecycle(self, session: QuerySessionSync): session.delete() _check_session_state_empty(session) - def test_second_create_do_nothing(self, session: QuerySessionSync): + def test_second_create_do_nothing(self, session: QuerySession): session.create() _check_session_state_full(session) @@ -38,27 +38,27 @@ def test_second_create_do_nothing(self, session: QuerySessionSync): assert session._state.session_id == session_id_before assert session._state.node_id == node_id_before - def test_second_delete_do_nothing(self, session: QuerySessionSync): + def test_second_delete_do_nothing(self, session: QuerySession): session.create() session.delete() session.delete() - def test_delete_before_create_not_possible(self, session: QuerySessionSync): + def test_delete_before_create_not_possible(self, session: QuerySession): with pytest.raises(RuntimeError): session.delete() - def test_create_after_delete_not_possible(self, session: QuerySessionSync): + def test_create_after_delete_not_possible(self, session: QuerySession): session.create() session.delete() with pytest.raises(RuntimeError): session.create() - def test_transaction_before_create_raises(self, session: QuerySessionSync): + def test_transaction_before_create_raises(self, session: QuerySession): with pytest.raises(RuntimeError): session.transaction() - def test_transaction_after_delete_raises(self, session: QuerySessionSync): + def test_transaction_after_delete_raises(self, session: QuerySession): session.create() session.delete() @@ -66,21 +66,21 @@ def test_transaction_after_delete_raises(self, session: QuerySessionSync): with pytest.raises(RuntimeError): session.transaction() - def test_transaction_after_create_not_raises(self, session: QuerySessionSync): + def test_transaction_after_create_not_raises(self, session: QuerySession): session.create() session.transaction() - def test_execute_before_create_raises(self, session: QuerySessionSync): + def test_execute_before_create_raises(self, session: QuerySession): with pytest.raises(RuntimeError): session.execute("select 1;") - def test_execute_after_delete_raises(self, session: QuerySessionSync): + def test_execute_after_delete_raises(self, session: QuerySession): session.create() session.delete() with pytest.raises(RuntimeError): session.execute("select 1;") - def test_basic_execute(self, session: QuerySessionSync): + def test_basic_execute(self, session: QuerySession): session.create() it = session.execute("select 1;") result_sets = [result_set for result_set in it] @@ -90,7 +90,7 @@ def test_basic_execute(self, session: QuerySessionSync): assert len(result_sets[0].columns) == 1 assert list(result_sets[0].rows[0].values()) == [1] - def test_two_results(self, session: QuerySessionSync): + def test_two_results(self, session: QuerySession): session.create() res = [] diff --git a/tests/query/test_query_session_pool.py b/tests/query/test_query_session_pool.py index 3c66c613..cb476fa8 100644 --- a/tests/query/test_query_session_pool.py +++ b/tests/query/test_query_session_pool.py @@ -1,7 +1,7 @@ import pytest import ydb from ydb.query.pool import QuerySessionPool -from ydb.query.session import QuerySessionSync, QuerySessionStateEnum +from ydb.query.session import QuerySession, QuerySessionStateEnum class TestQuerySessionPool: @@ -9,8 +9,6 @@ def test_checkout_provides_created_session(self, pool: QuerySessionPool): with pool.checkout() as session: assert session._state._state == QuerySessionStateEnum.CREATED - assert session._state._state == QuerySessionStateEnum.CLOSED - def test_oneshot_query_normal(self, pool: QuerySessionPool): res = pool.execute_with_retries("select 1;") assert len(res) == 1 @@ -24,13 +22,13 @@ def test_oneshot_query_raises(self, pool: QuerySessionPool): pool.execute_with_retries("Is this the real life? Is this just fantasy?") def test_retry_op_uses_created_session(self, pool: QuerySessionPool): - def callee(session: QuerySessionSync): + def callee(session: QuerySession): assert session._state._state == QuerySessionStateEnum.CREATED pool.retry_operation_sync(callee) def test_retry_op_normal(self, pool: QuerySessionPool): - def callee(session: QuerySessionSync): + def callee(session: QuerySession): with session.transaction() as tx: iterator = tx.execute("select 1;", commit_tx=True) return [result_set for result_set in iterator] @@ -42,8 +40,64 @@ def test_retry_op_raises(self, pool: QuerySessionPool): class CustomException(Exception): pass - def callee(session: QuerySessionSync): + def callee(session: QuerySession): raise CustomException() with pytest.raises(CustomException): pool.retry_operation_sync(callee) + + def test_pool_size_limit_logic(self, pool: QuerySessionPool): + target_size = 5 + pool._size = target_size + ids = set() + + for i in range(1, target_size + 1): + session = pool.acquire(timeout=0.1) + assert pool._current_size == i + assert session._state.session_id not in ids + ids.add(session._state.session_id) + + with pytest.raises(ydb.SessionPoolEmpty): + pool.acquire(timeout=0.1) + + last_id = session._state.session_id + pool.release(session) + + session = pool.acquire(timeout=0.1) + assert session._state.session_id == last_id + assert pool._current_size == target_size + + def test_checkout_do_not_increase_size(self, pool: QuerySessionPool): + session_id = None + for _ in range(10): + with pool.checkout() as session: + if session_id is None: + session_id = session._state.session_id + assert pool._current_size == 1 + assert session_id == session._state.session_id + + def test_pool_recreates_bad_sessions(self, pool: QuerySessionPool): + with pool.checkout() as session: + session_id = session._state.session_id + session.delete() + + with pool.checkout() as session: + assert session_id != session._state.session_id + assert pool._current_size == 1 + + def test_acquire_from_closed_pool_raises(self, pool: QuerySessionPool): + pool.stop() + with pytest.raises(RuntimeError): + pool.acquire(1) + + def test_no_session_leak(self, driver_sync, docker_project): + pool = ydb.QuerySessionPool(driver_sync, 1) + docker_project.stop() + try: + pool.acquire(timeout=0.1) + except ydb.Error: + pass + assert pool._current_size == 0 + + docker_project.start() + pool.stop() diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index 07a43fa6..9e78988a 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -1,62 +1,62 @@ import pytest -from ydb.query.transaction import QueryTxContextSync +from ydb.query.transaction import QueryTxContext from ydb.query.transaction import QueryTxStateEnum class TestQueryTransaction: - def test_tx_begin(self, tx: QueryTxContextSync): + def test_tx_begin(self, tx: QueryTxContext): assert tx.tx_id is None tx.begin() assert tx.tx_id is not None - def test_tx_allow_double_commit(self, tx: QueryTxContextSync): + def test_tx_allow_double_commit(self, tx: QueryTxContext): tx.begin() tx.commit() tx.commit() - def test_tx_allow_double_rollback(self, tx: QueryTxContextSync): + def test_tx_allow_double_rollback(self, tx: QueryTxContext): tx.begin() tx.rollback() tx.rollback() - def test_tx_commit_before_begin(self, tx: QueryTxContextSync): + def test_tx_commit_before_begin(self, tx: QueryTxContext): tx.commit() assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_tx_rollback_before_begin(self, tx: QueryTxContextSync): + def test_tx_rollback_before_begin(self, tx: QueryTxContext): tx.rollback() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_tx_first_execute_begins_tx(self, tx: QueryTxContextSync): + def test_tx_first_execute_begins_tx(self, tx: QueryTxContext): tx.execute("select 1;") tx.commit() - def test_interactive_tx_commit(self, tx: QueryTxContextSync): + def test_interactive_tx_commit(self, tx: QueryTxContext): tx.execute("select 1;", commit_tx=True) with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_commit(self, tx: QueryTxContextSync): + def test_tx_execute_raises_after_commit(self, tx: QueryTxContext): tx.begin() tx.commit() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_tx_execute_raises_after_rollback(self, tx: QueryTxContextSync): + def test_tx_execute_raises_after_rollback(self, tx: QueryTxContext): tx.begin() tx.rollback() with pytest.raises(RuntimeError): tx.execute("select 1;") - def test_context_manager_rollbacks_tx(self, tx: QueryTxContextSync): + def test_context_manager_rollbacks_tx(self, tx: QueryTxContext): with tx: tx.begin() assert tx._tx_state._state == QueryTxStateEnum.ROLLBACKED - def test_context_manager_normal_flow(self, tx: QueryTxContextSync): + def test_context_manager_normal_flow(self, tx: QueryTxContext): with tx: tx.begin() tx.execute("select 1;") @@ -64,7 +64,7 @@ def test_context_manager_normal_flow(self, tx: QueryTxContextSync): assert tx._tx_state._state == QueryTxStateEnum.COMMITTED - def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContextSync): + def test_context_manager_does_not_hide_exceptions(self, tx: QueryTxContext): class CustomException(Exception): pass @@ -72,7 +72,7 @@ class CustomException(Exception): with tx: raise CustomException() - def test_execute_as_context_manager(self, tx: QueryTxContextSync): + def test_execute_as_context_manager(self, tx: QueryTxContext): tx.begin() with tx.execute("select 1;") as results: diff --git a/tests/slo/src/jobs.py b/tests/slo/src/jobs.py index 3fb1833a..4fe0cd37 100644 --- a/tests/slo/src/jobs.py +++ b/tests/slo/src/jobs.py @@ -155,8 +155,8 @@ def run_reads_query(driver, query, max_id, metrics, limiter, runtime, timeout): with limiter: def check_result(result): - res = next(result) - assert res.rows[0] + with result: + pass params = RequestParams( pool=pool, @@ -182,7 +182,7 @@ def run_read_jobs_query(args, driver, tb_name, max_id, metrics): futures = [] for _ in range(args.read_threads): future = threading.Thread( - name="slo_run_read", + name="slo_run_read_query", target=run_reads_query, args=(driver, read_q, max_id, metrics, read_limiter, args.time, args.read_timeout / 1000), ) @@ -306,7 +306,7 @@ def run_write_jobs_query(args, driver, tb_name, max_id, metrics): futures = [] for _ in range(args.write_threads): future = threading.Thread( - name="slo_run_write", + name="slo_run_write_query", target=run_writes_query, args=(driver, write_q, row_generator, metrics, write_limiter, args.time, args.write_timeout / 1000), ) diff --git a/ydb/aio/__init__.py b/ydb/aio/__init__.py index 0e7d4e74..a755713d 100644 --- a/ydb/aio/__init__.py +++ b/ydb/aio/__init__.py @@ -1,3 +1,3 @@ from .driver import Driver # noqa from .table import SessionPool, retry_operation # noqa -from .query import QuerySessionPoolAsync, QuerySessionAsync # noqa +from .query import QuerySessionPool, QuerySession # noqa diff --git a/ydb/aio/query/__init__.py b/ydb/aio/query/__init__.py index 829d7b54..ea5273d7 100644 --- a/ydb/aio/query/__init__.py +++ b/ydb/aio/query/__init__.py @@ -1,7 +1,7 @@ __all__ = [ - "QuerySessionPoolAsync", - "QuerySessionAsync", + "QuerySessionPool", + "QuerySession", ] -from .pool import QuerySessionPoolAsync -from .session import QuerySessionAsync +from .pool import QuerySessionPool +from .session import QuerySession diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index f91f7465..f0b962c3 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import ( Callable, @@ -6,7 +7,7 @@ ) from .session import ( - QuerySessionAsync, + QuerySession, ) from ...retries import ( RetrySettings, @@ -18,16 +19,68 @@ logger = logging.getLogger(__name__) -class QuerySessionPoolAsync: - """QuerySessionPoolAsync is an object to simplify operations with sessions of Query Service.""" +class QuerySessionPool: + """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" - def __init__(self, driver: common_utils.SupportedDriverType): + def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): """ :param driver: A driver instance + :param size: Size of session pool """ - logger.warning("QuerySessionPoolAsync is an experimental API, which could be changed.") + logger.warning("QuerySessionPool is an experimental API, which could be changed.") self._driver = driver + self._size = size + self._should_stop = asyncio.Event() + self._queue = asyncio.Queue() + self._current_size = 0 + self._waiters = 0 + self._loop = asyncio.get_running_loop() + + async def _create_new_session(self): + session = QuerySession(self._driver) + await session.create() + logger.debug(f"New session was created for pool. Session id: {session._state.session_id}") + return session + + async def acquire(self) -> QuerySession: + if self._should_stop.is_set(): + logger.error("An attempt to take session from closed session pool.") + raise RuntimeError("An attempt to take session from closed session pool.") + + session = None + try: + session = self._queue.get_nowait() + except asyncio.QueueEmpty: + pass + + if session is None and self._current_size == self._size: + queue_get = asyncio.ensure_future(self._queue.get()) + task_stop = asyncio.ensure_future(asyncio.ensure_future(self._should_stop.wait())) + done, _ = await asyncio.wait((queue_get, task_stop), return_when=asyncio.FIRST_COMPLETED) + if task_stop in done: + queue_get.cancel() + raise RuntimeError("An attempt to take session from closed session pool.") + + task_stop.cancel() + session = queue_get.result() + + if session is not None: + if session._state.attached: + logger.debug(f"Acquired active session from queue: {session._state.session_id}") + return session + else: + self._current_size -= 1 + logger.debug(f"Acquired dead session from queue: {session._state.session_id}") + + logger.debug(f"Session pool is not large enough: {self._current_size} < {self._size}, will create new one.") + session = await self._create_new_session() + self._current_size += 1 + return session + + async def release(self, session: QuerySession) -> None: + self._queue.put_nowait(session) + logger.debug("Session returned to queue: %s", session._state.session_id) def checkout(self) -> "SimpleQuerySessionCheckoutAsync": """WARNING: This API is experimental and could be changed. @@ -85,8 +138,20 @@ async def wrapped_callee(): return await retry_operation_async(wrapped_callee, retry_settings) - async def stop(self, timeout=None): - pass # TODO: implement + async def stop(self): + self._should_stop.set() + + tasks = [] + while True: + try: + session = self._queue.get_nowait() + tasks.append(session.delete()) + except asyncio.QueueEmpty: + break + + await asyncio.gather(*tasks) + + logger.debug("All session were deleted.") async def __aenter__(self): return self @@ -94,15 +159,21 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop() + def __del__(self): + if self._should_stop.is_set() or self._loop.is_closed(): + return + + self._loop.call_soon(self.stop) + class SimpleQuerySessionCheckoutAsync: - def __init__(self, pool: QuerySessionPoolAsync): + def __init__(self, pool: QuerySessionPool): self._pool = pool - self._session = QuerySessionAsync(pool._driver) + self._session = None - async def __aenter__(self) -> QuerySessionAsync: - await self._session.create() + async def __aenter__(self) -> QuerySession: + self._session = await self._pool.acquire() return self._session async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._session.delete() + await self._pool.release(self._session) diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py index 627a41d8..5f51b671 100644 --- a/ydb/aio/query/session.py +++ b/ydb/aio/query/session.py @@ -5,9 +5,10 @@ ) from .base import AsyncResponseContextIterator -from .transaction import QueryTxContextAsync +from .transaction import QueryTxContext from .. import _utilities from ... import issues +from ...settings import BaseRequestSettings from ..._grpc.grpcwrapper import common_utils from ..._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public @@ -18,7 +19,7 @@ ) -class QuerySessionAsync(BaseQuerySession): +class QuerySession(BaseQuerySession): """Session object for Query Service. It is not recommended to control session's lifecycle manually - use a QuerySessionPool is always a better choise. """ @@ -32,7 +33,7 @@ def __init__( settings: Optional[base.QueryClientSettings] = None, loop: asyncio.AbstractEventLoop = None, ): - super(QuerySessionAsync, self).__init__(driver, settings) + super(QuerySession, self).__init__(driver, settings) self._loop = loop if loop is not None else asyncio.get_running_loop() async def _attach(self) -> None: @@ -62,7 +63,7 @@ async def _check_session_status_loop(self) -> None: self._state.reset() self._state._change_state(QuerySessionStateEnum.CLOSED) - async def delete(self) -> None: + async def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: """WARNING: This API is experimental and could be changed. Deletes a Session of Query Service on server side and releases resources. @@ -73,30 +74,30 @@ async def delete(self) -> None: return self._state._check_invalid_transition(QuerySessionStateEnum.CLOSED) - await self._delete_call() + await self._delete_call(settings=settings) self._stream.cancel() - async def create(self) -> "QuerySessionAsync": + async def create(self, settings: Optional[BaseRequestSettings] = None) -> "QuerySession": """WARNING: This API is experimental and could be changed. Creates a Session of Query Service on server side and attaches it. - :return: QuerySessionSync object. + :return: QuerySession object. """ if self._state._already_in(QuerySessionStateEnum.CREATED): return self._state._check_invalid_transition(QuerySessionStateEnum.CREATED) - await self._create_call() + await self._create_call(settings=settings) await self._attach() return self - def transaction(self, tx_mode=None) -> QueryTxContextAsync: + def transaction(self, tx_mode=None) -> QueryTxContext: self._state._check_session_ready_to_use() tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() - return QueryTxContextAsync( + return QueryTxContext( self._driver, self._state, self, @@ -110,6 +111,7 @@ async def execute( syntax: base.QuerySyntax = None, exec_mode: base.QueryExecMode = None, concurrent_result_sets: bool = False, + settings: Optional[BaseRequestSettings] = None, ) -> AsyncResponseContextIterator: """WARNING: This API is experimental and could be changed. @@ -132,6 +134,7 @@ async def execute( exec_mode=exec_mode, parameters=parameters, concurrent_result_sets=concurrent_result_sets, + settings=settings, ) return AsyncResponseContextIterator( @@ -139,6 +142,7 @@ async def execute( lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, + session_state=self._state, settings=self._settings, ), ) diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index 429ba125..0e3ab602 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -15,8 +15,8 @@ logger = logging.getLogger(__name__) -class QueryTxContextAsync(BaseQueryTxContext): - async def __aenter__(self) -> "QueryTxContextAsync": +class QueryTxContext(BaseQueryTxContext): + async def __aenter__(self) -> "QueryTxContext": """ Enters a context manager and returns a transaction @@ -47,7 +47,7 @@ async def _ensure_prev_stream_finished(self) -> None: pass self._prev_stream = None - async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxContextAsync": + async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxContext": """WARNING: This API is experimental and could be changed. Explicitly begins a transaction @@ -146,6 +146,7 @@ async def execute( lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, + session_state=self._session_state, tx=self, commit_tx=commit_tx, settings=self.session._settings, diff --git a/ydb/query/__init__.py b/ydb/query/__init__.py index 40e512cd..1e950bb7 100644 --- a/ydb/query/__init__.py +++ b/ydb/query/__init__.py @@ -5,7 +5,7 @@ "QueryStaleReadOnly", "QuerySessionPool", "QueryClientSync", - "QuerySessionSync", + "QuerySession", ] import logging @@ -14,7 +14,7 @@ QueryClientSettings, ) -from .session import QuerySessionSync +from .session import QuerySession from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper.ydb_query_public_types import ( @@ -35,5 +35,5 @@ def __init__(self, driver: common_utils.SupportedDriverType, query_client_settin self._driver = driver self._settings = query_client_settings - def session(self) -> QuerySessionSync: - return QuerySessionSync(self._driver, self._settings) + def session(self) -> QuerySession: + return QuerySession(self._driver, self._settings) diff --git a/ydb/query/base.py b/ydb/query/base.py index 55087d0c..9372cbcf 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -165,28 +165,31 @@ def create_execute_query_request( ) +def bad_session_handler(func): + @functools.wraps(func) + def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs): + try: + return func(rpc_state, response_pb, session_state, *args, **kwargs) + except issues.BadSession: + session_state.reset() + raise + + return decorator + + +@bad_session_handler def wrap_execute_query_response( rpc_state: RpcState, response_pb: _apis.ydb_query.ExecuteQueryResponsePart, + session_state: IQuerySessionState, tx: Optional["BaseQueryTxContext"] = None, commit_tx: Optional[bool] = False, settings: Optional[QueryClientSettings] = None, ) -> convert.ResultSet: issues._process_response(response_pb) - if tx and response_pb.tx_meta and not tx.tx_id: - tx._move_to_beginned(response_pb.tx_meta.id) if tx and commit_tx: tx._move_to_commited() - return convert.ResultSet.from_message(response_pb.result_set, settings) - - -def bad_session_handler(func): - @functools.wraps(func) - def decorator(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs): - try: - return func(rpc_state, response_pb, session_state, *args, **kwargs) - except issues.BadSession: - session_state.reset() - raise + elif tx and response_pb.tx_meta and not tx.tx_id: + tx._move_to_beginned(response_pb.tx_meta.id) - return decorator + return convert.ResultSet.from_message(response_pb.result_set, settings) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index afe39f06..1ee9ea83 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -4,15 +4,20 @@ Optional, List, ) +import time +import threading +import queue from .session import ( - QuerySessionSync, + QuerySession, ) from ..retries import ( RetrySettings, retry_operation_sync, ) +from .. import issues from .. import convert +from ..settings import BaseRequestSettings from .._grpc.grpcwrapper import common_utils @@ -22,20 +27,80 @@ class QuerySessionPool: """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" - def __init__(self, driver: common_utils.SupportedDriverType): + def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): """ :param driver: A driver instance """ logger.warning("QuerySessionPool is an experimental API, which could be changed.") self._driver = driver - - def checkout(self) -> "SimpleQuerySessionCheckout": + self._queue = queue.Queue() + self._current_size = 0 + self._size = size + self._should_stop = threading.Event() + self._lock = threading.RLock() + + def _create_new_session(self, timeout: Optional[float]): + session = QuerySession(self._driver) + session.create(settings=BaseRequestSettings().with_timeout(timeout)) + logger.debug(f"New session was created for pool. Session id: {session._state.session_id}") + return session + + def acquire(self, timeout: Optional[float] = None) -> QuerySession: + start = time.monotonic() + + lock_acquire_timeout = timeout if timeout is not None else -1 + acquired = self._lock.acquire(timeout=lock_acquire_timeout) + try: + if self._should_stop.is_set(): + logger.error("An attempt to take session from closed session pool.") + raise RuntimeError("An attempt to take session from closed session pool.") + + session = None + try: + session = self._queue.get_nowait() + except queue.Empty: + pass + + finish = time.monotonic() + timeout = timeout - (finish - start) if timeout is not None else None + + start = time.monotonic() + if session is None and self._current_size == self._size: + try: + session = self._queue.get(block=True, timeout=timeout) + except queue.Empty: + raise issues.SessionPoolEmpty("Timeout on acquire session") + + if session is not None: + if session._state.attached: + logger.debug(f"Acquired active session from queue: {session._state.session_id}") + return session + else: + self._current_size -= 1 + logger.debug(f"Acquired dead session from queue: {session._state.session_id}") + + logger.debug(f"Session pool is not large enough: {self._current_size} < {self._size}, will create new one.") + finish = time.monotonic() + time_left = timeout - (finish - start) if timeout is not None else None + session = self._create_new_session(time_left) + + self._current_size += 1 + return session + finally: + if acquired: + self._lock.release() + + def release(self, session: QuerySession) -> None: + self._queue.put_nowait(session) + logger.debug("Session returned to queue: %s", session._state.session_id) + + def checkout(self, timeout: Optional[float] = None) -> "SimpleQuerySessionCheckout": """WARNING: This API is experimental and could be changed. Return a Session context manager, that opens session on enter and closes session on exit. """ - return SimpleQuerySessionCheckout(self) + return SimpleQuerySessionCheckout(self, timeout) def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetrySettings] = None, *args, **kwargs): """WARNING: This API is experimental and could be changed. @@ -50,7 +115,7 @@ def retry_operation_sync(self, callee: Callable, retry_settings: Optional[RetryS retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): - with self.checkout() as session: + with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session: return callee(session, *args, **kwargs) return retry_operation_sync(wrapped_callee, retry_settings) @@ -78,14 +143,28 @@ def execute_with_retries( retry_settings = RetrySettings() if retry_settings is None else retry_settings def wrapped_callee(): - with self.checkout() as session: + with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session: it = session.execute(query, parameters, *args, **kwargs) return [result_set for result_set in it] return retry_operation_sync(wrapped_callee, retry_settings) def stop(self, timeout=None): - pass # TODO: implement + acquire_timeout = timeout if timeout is not None else -1 + acquired = self._lock.acquire(timeout=acquire_timeout) + try: + self._should_stop.set() + while True: + try: + session = self._queue.get_nowait() + session.delete() + except queue.Empty: + break + + logger.debug("All session were deleted.") + finally: + if acquired: + self._lock.release() def __enter__(self): return self @@ -93,15 +172,19 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() + def __del__(self): + self.stop() + class SimpleQuerySessionCheckout: - def __init__(self, pool: QuerySessionPool): + def __init__(self, pool: QuerySessionPool, timeout: Optional[float]): self._pool = pool - self._session = QuerySessionSync(pool._driver) + self._timeout = timeout + self._session = None - def __enter__(self) -> QuerySessionSync: - self._session.create() + def __enter__(self) -> QuerySession: + self._session = self._pool.acquire(self._timeout) return self._session def __exit__(self, exc_type, exc_val, exc_tb): - self._session.delete() + self._pool.release(self._session) diff --git a/ydb/query/session.py b/ydb/query/session.py index 4b051dc1..66e86501 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -10,12 +10,13 @@ from . import base from .. import _apis, issues, _utilities +from ..settings import BaseRequestSettings from ..connection import _RpcState as RpcState from .._grpc.grpcwrapper import common_utils from .._grpc.grpcwrapper import ydb_query as _ydb_query from .._grpc.grpcwrapper import ydb_query_public_types as _ydb_query_public -from .transaction import QueryTxContextSync +from .transaction import QueryTxContext logger = logging.getLogger(__name__) @@ -136,29 +137,32 @@ def __init__(self, driver: common_utils.SupportedDriverType, settings: Optional[ self._settings = settings if settings is not None else base.QueryClientSettings() self._state = QuerySessionState(settings) - def _create_call(self) -> "BaseQuerySession": + def _create_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQuerySession": return self._driver( _apis.ydb_query.CreateSessionRequest(), _apis.QueryService.Stub, _apis.QueryService.CreateSession, wrap_result=wrapper_create_session, wrap_args=(self._state, self), + settings=settings, ) - def _delete_call(self) -> "BaseQuerySession": + def _delete_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQuerySession": return self._driver( _apis.ydb_query.DeleteSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, _apis.QueryService.DeleteSession, wrap_result=wrapper_delete_session, wrap_args=(self._state, self), + settings=settings, ) - def _attach_call(self) -> Iterable[_apis.ydb_query.SessionState]: + def _attach_call(self, settings: Optional[BaseRequestSettings] = None) -> Iterable[_apis.ydb_query.SessionState]: return self._driver( _apis.ydb_query.AttachSessionRequest(session_id=self._state.session_id), _apis.QueryService.Stub, _apis.QueryService.AttachSession, + settings=settings, ) def _execute_call( @@ -169,6 +173,7 @@ def _execute_call( exec_mode: base.QueryExecMode = None, parameters: dict = None, concurrent_result_sets: bool = False, + settings: Optional[BaseRequestSettings] = None, ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: request = base.create_execute_query_request( query=query, @@ -186,18 +191,19 @@ def _execute_call( request.to_proto(), _apis.QueryService.Stub, _apis.QueryService.ExecuteQuery, + settings=settings, ) -class QuerySessionSync(BaseQuerySession): +class QuerySession(BaseQuerySession): """Session object for Query Service. It is not recommended to control session's lifecycle manually - use a QuerySessionPool is always a better choise. """ _stream = None - def _attach(self) -> None: - self._stream = self._attach_call() + def _attach(self, settings: Optional[BaseRequestSettings] = None) -> None: + self._stream = self._attach_call(settings=settings) status_stream = _utilities.SyncResponseIterator( self._stream, lambda response: common_utils.ServerStatus.from_proto(response), @@ -228,7 +234,7 @@ def _check_session_status_loop(self, status_stream: _utilities.SyncResponseItera self._state.reset() self._state._change_state(QuerySessionStateEnum.CLOSED) - def delete(self) -> None: + def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: """WARNING: This API is experimental and could be changed. Deletes a Session of Query Service on server side and releases resources. @@ -239,26 +245,27 @@ def delete(self) -> None: return self._state._check_invalid_transition(QuerySessionStateEnum.CLOSED) - self._delete_call() + self._delete_call(settings=settings) self._stream.cancel() - def create(self) -> "QuerySessionSync": + def create(self, settings: Optional[BaseRequestSettings] = None) -> "QuerySession": """WARNING: This API is experimental and could be changed. Creates a Session of Query Service on server side and attaches it. - :return: QuerySessionSync object. + :return: QuerySession object. """ if self._state._already_in(QuerySessionStateEnum.CREATED): return self._state._check_invalid_transition(QuerySessionStateEnum.CREATED) - self._create_call() + + self._create_call(settings=settings) self._attach() return self - def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> QueryTxContextSync: + def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> QueryTxContext: """WARNING: This API is experimental and could be changed. Creates a transaction context manager with specified transaction mode. @@ -275,7 +282,7 @@ def transaction(self, tx_mode: Optional[base.BaseQueryTxMode] = None) -> QueryTx tx_mode = tx_mode if tx_mode else _ydb_query_public.QuerySerializableReadWrite() - return QueryTxContextSync( + return QueryTxContext( self._driver, self._state, self, @@ -289,6 +296,7 @@ def execute( syntax: base.QuerySyntax = None, exec_mode: base.QueryExecMode = None, concurrent_result_sets: bool = False, + settings: Optional[BaseRequestSettings] = None, ) -> base.SyncResponseContextIterator: """WARNING: This API is experimental and could be changed. @@ -311,6 +319,7 @@ def execute( exec_mode=exec_mode, parameters=parameters, concurrent_result_sets=concurrent_result_sets, + settings=settings, ) return base.SyncResponseContextIterator( @@ -318,6 +327,7 @@ def execute( lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, + session_state=self._state, settings=self._settings, ), ) diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index be7396b1..9ad3552f 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -294,7 +294,7 @@ def _move_to_commited(self) -> None: self._tx_state._change_state(QueryTxStateEnum.COMMITTED) -class QueryTxContextSync(BaseQueryTxContext): +class QueryTxContext(BaseQueryTxContext): def __enter__(self) -> "BaseQueryTxContext": """ Enters a context manager and returns a transaction @@ -326,7 +326,7 @@ def _ensure_prev_stream_finished(self) -> None: pass self._prev_stream = None - def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxContextSync": + def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryTxContext": """WARNING: This API is experimental and could be changed. Explicitly begins a transaction @@ -427,6 +427,7 @@ def execute( lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, + session_state=self._session_state, tx=self, commit_tx=commit_tx, settings=self.session._settings,