Skip to content

Commit

Permalink
SNOW-901157 support session recreation when token expires (#1139)
Browse files Browse the repository at this point in the history
* adding new feature

* PR feedback

* adding changelog
  • Loading branch information
sfc-gh-mkeller authored Jan 5, 2024
1 parent 4149ac7 commit 098da14
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- `chunk_size`: The number of bytes to hash per chunk of the uploaded files.
- `whole_file_hash`: By default only the first chunk of the uploaded import is hashed to save time. When this is set to True each uploaded file is fully hashed instead.
- Added parameters `external_access_integrations` and `secrets` when creating a UDAF from Snowpark Python to allow integration with external access.
- `SessionBuilder.getOrCreate` will now attempt to replace the singleton it returns when token expiration has been detected.

### Bug Fixes

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
THIS_DIR = os.path.dirname(os.path.realpath(__file__))
SRC_DIR = os.path.join(THIS_DIR, "src")
SNOWPARK_SRC_DIR = os.path.join(SRC_DIR, "snowflake", "snowpark")
CONNECTOR_DEPENDENCY_VERSION = ">=3.4.0, <4.0.0"
CONNECTOR_DEPENDENCY_VERSION = ">=3.6.0, <4.0.0"
INSTALL_REQ_LIST = [
"setuptools>=40.6.0",
"wheel",
Expand Down
13 changes: 9 additions & 4 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@
WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None


def _get_active_session() -> Optional["Session"]:
def _get_active_session() -> "Session":
with _session_management_lock:
if len(_active_sessions) == 1:
return next(iter(_active_sessions))
Expand All @@ -195,6 +195,8 @@ def _get_active_session() -> Optional["Session"]:
def _get_active_sessions() -> Set["Session"]:
with _session_management_lock:
if len(_active_sessions) >= 1:
# TODO: This function is allowing unsafe access to a mutex protected data
# structure, we should ONLY use it in tests
return _active_sessions
else:
raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION()
Expand Down Expand Up @@ -335,12 +337,15 @@ def create(self) -> "Session":
def getOrCreate(self) -> "Session":
"""Gets the last created session or creates a new one if needed."""
try:
return _get_active_session()
session = _get_active_session()
if session._conn._conn.expired:
_remove_session(session)
return self.create()
return session
except SnowparkClientException as ex:
if ex.error_code == "1403": # No session, ok lets create one
return self.create()
else: # Any other reason...
raise ex
raise

def _create_internal(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def test_sql_select_with_params(session):

def test_active_session(session):
assert session == _get_active_session()
assert not session._conn._conn.expired


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,32 @@ def test_connection():

assert session.connection == session._conn._conn
assert session.connection == fake_snowflake_connection


def test_connection_expiry():
session = Session(
ServerConnection(
{"": ""},
mock.Mock(
spec=SnowflakeConnection,
_telemetry=mock.Mock(),
_session_parameters=mock.Mock(),
is_closed=mock.Mock(return_value=False),
expired=False,
),
),
)
with mock.patch(
"snowflake.snowpark.session._active_sessions",
{session},
):
assert Session.builder.getOrCreate() is session
session._conn._conn.expired = True
builder = Session.builder
with mock.patch.object(
builder,
"create",
return_value=None,
) as m:
assert builder.getOrCreate() is None
m.assert_called_once()

0 comments on commit 098da14

Please sign in to comment.