From 39654bd1e687bab8e23d8ed4be7b9923b0a9e7f0 Mon Sep 17 00:00:00 2001 From: Andong Zhan Date: Thu, 15 Feb 2024 11:43:40 -0800 Subject: [PATCH] SNOW-1049757 Make sure cleanup active sessions at interpreter shutdown (#1250) * SNOW-1049757 Make sure cleanup active sessions at interpreter shutdown * resolve comments --- CHANGELOG.md | 4 +++ src/snowflake/snowpark/session.py | 18 ++++++++++++ tests/integ/test_session.py | 4 +-- tests/unit/test_session.py | 49 +++++++++++++++++++++++-------- 4 files changed, 61 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 925e5435e35..7968f2c797d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ - Added support for an optional `date_part` argument in function `last_day` - `SessionBuilder.app_name` will set the query_tag after the session is created. +### Improvements + +- Added cleanup logic at interpreter shutdown to close all active sessions. + ### Bug Fixes - Fixed a bug in `DataFrame.to_local_iterator` where the iterator could yield wrong results if another query is executed before the iterator finishes due to wrong isolation level. For details, please see #945. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a435d4b888b..79a95dcf12b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3,6 +3,7 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import atexit import datetime import decimal import inspect @@ -209,6 +210,23 @@ def _add_session(session: "Session") -> None: _active_sessions.add(session) +def _close_session_atexit(): + """ + This is the helper function to close all active sessions at interpreter shutdown. For example, when a jupyter + notebook is shutting down, this will also close all active sessions and make sure send all telemetry to the server. + """ + with _session_management_lock: + for session in _active_sessions.copy(): + try: + session.close() + except Exception: + pass + + +# Register _close_session_atexit so it will be called at interpreter shutdown +atexit.register(_close_session_atexit) + + def _remove_session(session: "Session") -> None: with _session_management_lock: try: diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 858a24508a3..f38e0c688ea 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -334,8 +334,8 @@ def test_create_session_from_connection_with_noise_parameters( @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_session_builder_app_name(session, db_parameters): builder = session.builder - app_name = 'my_app' - expected_query_tag = f'APPNAME={app_name}' + app_name = "my_app" + expected_query_tag = f"APPNAME={app_name}" same_session = builder.app_name(app_name).getOrCreate() new_session = builder.app_name(app_name).configs(db_parameters).create() try: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index cfe9fe273cd..890cd7b2d27 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -26,7 +26,10 @@ SnowparkInvalidObjectNameException, SnowparkSessionException, ) -from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING +from snowflake.snowpark.session import ( + _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, + _close_session_atexit, +) from snowflake.snowpark.types import StructField, StructType @@ -378,6 +381,30 @@ def test_session_id(): assert session.session_id == 123456 +def test_session_close_atexit(): + mocked_session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + + with mock.patch( + "snowflake.snowpark.session._active_sessions", + {mocked_session}, + ): + with mock.patch.object(snowflake.snowpark.session.Session, "close") as m: + # _close_session_atexit will be called when the interpreter is shutting down + _close_session_atexit() + m.assert_called_once() + + def test_connection(): fake_snowflake_connection = mock.create_autospec(SnowflakeConnection) fake_snowflake_connection._telemetry = mock.Mock() @@ -439,14 +466,13 @@ def test_session_builder_app_name_no_existing_query_tag(): builder = Session.builder with mock.patch.object( - builder, - "_create_internal", - return_value=mocked_session) as m: - app_name = 'my_app_name' + builder, "_create_internal", return_value=mocked_session + ) as m: + app_name = "my_app_name" assert builder.app_name(app_name) is builder created_session = builder.getOrCreate() m.assert_called_once() - assert created_session.query_tag == f'APPNAME={app_name}' + assert created_session.query_tag == f"APPNAME={app_name}" def test_session_builder_app_name_existing_query_tag(): @@ -463,18 +489,17 @@ def test_session_builder_app_name_existing_query_tag(): ), ) - existing_query_tag = 'tag' + existing_query_tag = "tag" mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag) builder = Session.builder with mock.patch.object( - builder, - "_create_internal", - return_value=mocked_session) as m: - app_name = 'my_app_name' + builder, "_create_internal", return_value=mocked_session + ) as m: + app_name = "my_app_name" assert builder.app_name(app_name) is builder created_session = builder.getOrCreate() m.assert_called_once() - assert created_session.query_tag == f'tag,APPNAME={app_name}' + assert created_session.query_tag == f"tag,APPNAME={app_name}"