Skip to content

Commit

Permalink
SNOW-1049757 Make sure cleanup active sessions at interpreter shutdown (
Browse files Browse the repository at this point in the history
#1250)

* SNOW-1049757 Make sure cleanup active sessions at interpreter shutdown

* resolve comments
  • Loading branch information
sfc-gh-azhan authored Feb 15, 2024
1 parent 2516702 commit 39654bd
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
- Added support for an optional `date_part` argument in function `last_day`
- `SessionBuilder.app_name` will set the query_tag after the session is created.

### Improvements

- Added cleanup logic at interpreter shutdown to close all active sessions.

### Bug Fixes

- Fixed a bug in `DataFrame.to_local_iterator` where the iterator could yield wrong results if another query is executed before the iterator finishes due to wrong isolation level. For details, please see #945.
Expand Down
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import atexit
import datetime
import decimal
import inspect
Expand Down Expand Up @@ -209,6 +210,23 @@ def _add_session(session: "Session") -> None:
_active_sessions.add(session)


def _close_session_atexit():
"""
This is the helper function to close all active sessions at interpreter shutdown. For example, when a jupyter
notebook is shutting down, this will also close all active sessions and make sure send all telemetry to the server.
"""
with _session_management_lock:
for session in _active_sessions.copy():
try:
session.close()
except Exception:
pass


# Register _close_session_atexit so it will be called at interpreter shutdown
atexit.register(_close_session_atexit)


def _remove_session(session: "Session") -> None:
with _session_management_lock:
try:
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ def test_create_session_from_connection_with_noise_parameters(
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_session_builder_app_name(session, db_parameters):
builder = session.builder
app_name = 'my_app'
expected_query_tag = f'APPNAME={app_name}'
app_name = "my_app"
expected_query_tag = f"APPNAME={app_name}"
same_session = builder.app_name(app_name).getOrCreate()
new_session = builder.app_name(app_name).configs(db_parameters).create()
try:
Expand Down
49 changes: 37 additions & 12 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
SnowparkInvalidObjectNameException,
SnowparkSessionException,
)
from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
_close_session_atexit,
)
from snowflake.snowpark.types import StructField, StructType


Expand Down Expand Up @@ -378,6 +381,30 @@ def test_session_id():
assert session.session_id == 123456


def test_session_close_atexit():
mocked_session = Session(
ServerConnection(
{"": ""},
mock.Mock(
spec=SnowflakeConnection,
_telemetry=mock.Mock(),
_session_parameters=mock.Mock(),
is_closed=mock.Mock(return_value=False),
expired=False,
),
),
)

with mock.patch(
"snowflake.snowpark.session._active_sessions",
{mocked_session},
):
with mock.patch.object(snowflake.snowpark.session.Session, "close") as m:
# _close_session_atexit will be called when the interpreter is shutting down
_close_session_atexit()
m.assert_called_once()


def test_connection():
fake_snowflake_connection = mock.create_autospec(SnowflakeConnection)
fake_snowflake_connection._telemetry = mock.Mock()
Expand Down Expand Up @@ -439,14 +466,13 @@ def test_session_builder_app_name_no_existing_query_tag():
builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
builder, "_create_internal", return_value=mocked_session
) as m:
app_name = "my_app_name"
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'APPNAME={app_name}'
assert created_session.query_tag == f"APPNAME={app_name}"


def test_session_builder_app_name_existing_query_tag():
Expand All @@ -463,18 +489,17 @@ def test_session_builder_app_name_existing_query_tag():
),
)

existing_query_tag = 'tag'
existing_query_tag = "tag"

mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag)

builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
builder, "_create_internal", return_value=mocked_session
) as m:
app_name = "my_app_name"
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'tag,APPNAME={app_name}'
assert created_session.query_tag == f"tag,APPNAME={app_name}"

0 comments on commit 39654bd

Please sign in to comment.