diff --git a/CHANGELOG.md b/CHANGELOG.md index e1d0674e0b2..ab3ec74bda9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features - Added support for an optional `date_part` argument in function `last_day` +- `SessionBuilder.app_name` will set the query_tag after the session is created. ### Bug Fixes diff --git a/docs/source/_templates/autosummary/accessor_method.rst b/docs/source/_templates/autosummary/accessor_method.rst new file mode 100644 index 00000000000..96dec5e8a99 --- /dev/null +++ b/docs/source/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module + "." + objname.split(".")[0] }} + +.. automethod:: {{ ".".join(objname.split(".")[1:]) }} \ No newline at end of file diff --git a/docs/source/session.rst b/docs/source/session.rst index 468aa7a4a34..9ee2521dcf7 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -11,6 +11,19 @@ Snowpark Session Session + +.. rubric:: SessionBuilder + +.. autosummary:: + :toctree: api/ + :template: autosummary/accessor_method.rst + + Session.SessionBuilder.app_name + Session.SessionBuilder.config + Session.SessionBuilder.configs + Session.SessionBuilder.create + Session.SessionBuilder.getOrCreate + .. rubric:: Methods .. diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index db90f0b0b6a..5d4e9a90051 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -300,12 +300,20 @@ class SessionBuilder: def __init__(self) -> None: self._options = {} + self._app_name = None def _remove_config(self, key: str) -> "Session.SessionBuilder": """Only used in test.""" self._options.pop(key, None) return self + def app_name(self, app_name: str) -> "Session.SessionBuilder": + """ + Adds the app name to the :class:`SessionBuilder` to set in the query_tag after session creation + """ + self._app_name = app_name + return self + def config(self, key: str, value: Union[int, str]) -> "Session.SessionBuilder": """ Adds the specified connection parameter to the :class:`SessionBuilder` configuration. @@ -334,6 +342,11 @@ def create(self) -> "Session": _add_session(session) else: session = self._create_internal(self._options.get("connection")) + + if self._app_name: + app_name_tag = f'APPNAME={self._app_name}' + session.append_query_tag(app_name_tag) + return session def getOrCreate(self) -> "Session": @@ -1627,7 +1640,7 @@ def _get_remote_query_tag(self) -> None: def append_query_tag(self, tag: str, separator: str = ",") -> None: """ - Appends a tag to the current query tag. The input tag is appended to the current sessions query tag with the given sperator. + Appends a tag to the current query tag. The input tag is appended to the current sessions query tag with the given separator. Args: tag: The tag to append to the current query tag. diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 6c50d0a1e5b..858a24508a3 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -112,7 +112,7 @@ def test_get_or_create(session): @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") def test_get_or_create_no_previous(db_parameters, session): - # Test getOrCreate error. In this case we wan to make sure that + # Test getOrCreate error. In this case we want to make sure that # if there was not a session the session gets created sessions_backup = list(_active_sessions) _active_sessions.clear() @@ -331,6 +331,21 @@ def test_create_session_from_connection_with_noise_parameters( new_session.close() +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP") +def test_session_builder_app_name(session, db_parameters): + builder = session.builder + app_name = 'my_app' + expected_query_tag = f'APPNAME={app_name}' + same_session = builder.app_name(app_name).getOrCreate() + new_session = builder.app_name(app_name).configs(db_parameters).create() + try: + assert session == same_session + assert same_session.query_tag is None + assert new_session.query_tag == expected_query_tag + finally: + new_session.close() + + @pytest.mark.skipif( IS_IN_STORED_PROC, reason="The test creates temporary tables of which the names do not follow the rules of temp object on purposes.", diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0f0e53d96fe..cfe9fe273cd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -418,3 +418,63 @@ def test_connection_expiry(): ) as m: assert builder.getOrCreate() is None m.assert_called_once() + + +def test_session_builder_app_name_no_existing_query_tag(): + mocked_session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + + mocked_session._get_remote_query_tag = MagicMock(return_value=None) + + builder = Session.builder + + with mock.patch.object( + builder, + "_create_internal", + return_value=mocked_session) as m: + app_name = 'my_app_name' + assert builder.app_name(app_name) is builder + created_session = builder.getOrCreate() + m.assert_called_once() + assert created_session.query_tag == f'APPNAME={app_name}' + + +def test_session_builder_app_name_existing_query_tag(): + mocked_session = Session( + ServerConnection( + {"": ""}, + mock.Mock( + spec=SnowflakeConnection, + _telemetry=mock.Mock(), + _session_parameters=mock.Mock(), + is_closed=mock.Mock(return_value=False), + expired=False, + ), + ), + ) + + existing_query_tag = 'tag' + + mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag) + + builder = Session.builder + + with mock.patch.object( + builder, + "_create_internal", + return_value=mocked_session) as m: + app_name = 'my_app_name' + assert builder.app_name(app_name) is builder + created_session = builder.getOrCreate() + m.assert_called_once() + assert created_session.query_tag == f'tag,APPNAME={app_name}'