From 098da14e5108f81e0619cc742b505c0382a83384 Mon Sep 17 00:00:00 2001 From: Mark Keller Date: Fri, 5 Jan 2024 19:50:52 +0000 Subject: [PATCH] SNOW-901157 support session recreation when token expires (#1139) * adding new feature * PR feedback * adding changelog --- CHANGELOG.md | 1 + setup.py | 2 +- src/snowflake/snowpark/session.py | 13 +++++++++---- tests/integ/test_session.py | 1 + tests/unit/test_session.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 41 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7212193c5e7..2eb1b8558a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - `chunk_size`: The number of bytes to hash per chunk of the uploaded files. - `whole_file_hash`: By default only the first chunk of the uploaded import is hashed to save time. When this is set to True each uploaded file is fully hashed instead. - Added parameters `external_access_integrations` and `secrets` when creating a UDAF from Snowpark Python to allow integration with external access. +- `SessionBuilder.getOrCreate` will now attempt to replace the singleton it returns when token expiration has been detected. ### Bug Fixes diff --git a/setup.py b/setup.py index de75942a426..3d96f7a3006 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ THIS_DIR = os.path.dirname(os.path.realpath(__file__)) SRC_DIR = os.path.join(THIS_DIR, "src") SNOWPARK_SRC_DIR = os.path.join(SRC_DIR, "snowflake", "snowpark") -CONNECTOR_DEPENDENCY_VERSION = ">=3.4.0, <4.0.0" +CONNECTOR_DEPENDENCY_VERSION = ">=3.6.0, <4.0.0" INSTALL_REQ_LIST = [ "setuptools>=40.6.0", "wheel", diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 7267e5f7ebc..94d68cd554b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -182,7 +182,7 @@ WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None -def _get_active_session() -> Optional["Session"]: +def _get_active_session() -> "Session": with _session_management_lock: if len(_active_sessions) == 1: return next(iter(_active_sessions)) @@ -195,6 +195,8 @@ def _get_active_session() -> Optional["Session"]: def _get_active_sessions() -> Set["Session"]: with _session_management_lock: if len(_active_sessions) >= 1: + # TODO: This function is allowing unsafe access to a mutex protected data + # structure, we should ONLY use it in tests return _active_sessions else: raise SnowparkClientExceptionMessages.SERVER_NO_DEFAULT_SESSION() @@ -335,12 +337,15 @@ def create(self) -> "Session": def getOrCreate(self) -> "Session": """Gets the last created session or creates a new one if needed.""" try: - return _get_active_session() + session = _get_active_session() + if session._conn._conn.expired: + _remove_session(session) + return self.create() + return session except SnowparkClientException as ex: if ex.error_code == "1403": # No session, ok lets create one return self.create() - else: # Any other reason... - raise ex + raise def _create_internal( self, diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 4d84c06d199..07a430f9745 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -80,6 +80,7 @@ def test_sql_select_with_params(session): def test_active_session(session): assert session == _get_active_session() + assert not session._conn._conn.expired @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 39ab959d608..0f0e53d96fe 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -389,3 +389,32 @@ def test_connection(): assert session.connection == session._conn._conn assert session.connection == fake_snowflake_connection + + +def test_connection_expiry(): + session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + with mock.patch( + "snowflake.snowpark.session._active_sessions", + {session}, + ): + assert Session.builder.getOrCreate() is session + session._conn._conn.expired = True + builder = Session.builder + with mock.patch.object( + builder, + "create", + return_value=None, + ) as m: + assert builder.getOrCreate() is None + m.assert_called_once()