Skip to content

Commit

Permalink
SNOW-1049757 Make sure cleanup active sessions at interpreter shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-azhan committed Feb 14, 2024
1 parent 26f561a commit 1d07608
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
- Added support for an optional `date_part` argument in function `last_day`
- `SessionBuilder.app_name` will set the query_tag after the session is created.

### Improvements

- Make sure cleanup active sessions at interpreter shutdown.

### Bug Fixes

- Fixed a bug in `DataFrame.to_local_iterator` where the iterator could yield wrong results if another query is executed before the iterator finishes due to wrong isolation level. For details, please see #945.
Expand Down
21 changes: 20 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import atexit
import datetime
import decimal
import inspect
Expand Down Expand Up @@ -182,6 +183,7 @@
"PYTHON_SNOWPARK_USE_LOGICAL_TYPE_FOR_CREATE_DATAFRAME"
)
WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None
_register_atexit: bool = False


def _get_active_session() -> "Session":
Expand All @@ -207,6 +209,23 @@ def _get_active_sessions() -> Set["Session"]:
def _add_session(session: "Session") -> None:
with _session_management_lock:
_active_sessions.add(session)
global _register_atexit
if not _register_atexit:
atexit.register(_close_session_atexit)
_register_atexit = True


def _close_session_atexit():
"""
This is the helper function to close all active sessions at interpreter shutdown. For example, when a jupyter
notebook is shutting down, this will also close all active sessions and make sure send all telemetry to the server.
"""
with _session_management_lock:
for session in _active_sessions.copy():
try:
session.close()
except Exception:
pass


def _remove_session(session: "Session") -> None:
Expand Down Expand Up @@ -344,7 +363,7 @@ def create(self) -> "Session":
session = self._create_internal(self._options.get("connection"))

if self._app_name:
app_name_tag = f'APPNAME={self._app_name}'
app_name_tag = f"APPNAME={self._app_name}"
session.append_query_tag(app_name_tag)

return session
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ def test_create_session_from_connection_with_noise_parameters(
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_session_builder_app_name(session, db_parameters):
builder = session.builder
app_name = 'my_app'
expected_query_tag = f'APPNAME={app_name}'
app_name = "my_app"
expected_query_tag = f"APPNAME={app_name}"
same_session = builder.app_name(app_name).getOrCreate()
new_session = builder.app_name(app_name).configs(db_parameters).create()
try:
Expand Down
49 changes: 37 additions & 12 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
SnowparkInvalidObjectNameException,
SnowparkSessionException,
)
from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING
from snowflake.snowpark.session import (
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
_close_session_atexit,
)
from snowflake.snowpark.types import StructField, StructType


Expand Down Expand Up @@ -378,6 +381,30 @@ def test_session_id():
assert session.session_id == 123456


def test_session_close_atexit():
mocked_session = Session(
ServerConnection(
{"": ""},
mock.Mock(
spec=SnowflakeConnection,
_telemetry=mock.Mock(),
_session_parameters=mock.Mock(),
is_closed=mock.Mock(return_value=False),
expired=False,
),
),
)

with mock.patch(
"snowflake.snowpark.session._active_sessions",
{mocked_session},
):
with mock.patch.object(snowflake.snowpark.session.Session, "close") as m:
# _close_session_atexit will be called when the interpreter is shutting down
_close_session_atexit()
m.assert_called_once()


def test_connection():
fake_snowflake_connection = mock.create_autospec(SnowflakeConnection)
fake_snowflake_connection._telemetry = mock.Mock()
Expand Down Expand Up @@ -439,14 +466,13 @@ def test_session_builder_app_name_no_existing_query_tag():
builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
builder, "_create_internal", return_value=mocked_session
) as m:
app_name = "my_app_name"
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'APPNAME={app_name}'
assert created_session.query_tag == f"APPNAME={app_name}"


def test_session_builder_app_name_existing_query_tag():
Expand All @@ -463,18 +489,17 @@ def test_session_builder_app_name_existing_query_tag():
),
)

existing_query_tag = 'tag'
existing_query_tag = "tag"

mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag)

builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
builder, "_create_internal", return_value=mocked_session
) as m:
app_name = "my_app_name"
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'tag,APPNAME={app_name}'
assert created_session.query_tag == f"tag,APPNAME={app_name}"

0 comments on commit 1d07608

Please sign in to comment.