Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SCT-7895 Add app_name method to SessionBuilder to register in query_tag. #1203

Merged
merged 8 commits into from
Feb 7, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features
- Added support for an optional `date_part` argument in function `last_day`
- `SessionBuilder.app_name` will set the query_tag after the session is created.

### Bug Fixes

Expand Down
6 changes: 6 additions & 0 deletions docs/source/_templates/autosummary/accessor_method.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{{ fullname }}
{{ underline }}

.. currentmodule:: {{ module + "." + objname.split(".")[0] }}

.. automethod:: {{ ".".join(objname.split(".")[1:]) }}
13 changes: 13 additions & 0 deletions docs/source/session.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ Snowpark Session

Session


.. rubric:: SessionBuilder

.. autosummary::
:toctree: api/
:template: autosummary/accessor_method.rst

Session.SessionBuilder.app_name
Session.SessionBuilder.config
Session.SessionBuilder.configs
Session.SessionBuilder.create
Session.SessionBuilder.getOrCreate

.. rubric:: Methods

..
Expand Down
15 changes: 14 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,20 @@ class SessionBuilder:

def __init__(self) -> None:
self._options = {}
self._app_name = None

def _remove_config(self, key: str) -> "Session.SessionBuilder":
"""Only used in test."""
self._options.pop(key, None)
return self

def app_name(self, app_name: str) -> "Session.SessionBuilder":
"""
Adds the app name to the :class:`SessionBuilder` to set in the query_tag after session creation
"""
self._app_name = app_name
return self

def config(self, key: str, value: Union[int, str]) -> "Session.SessionBuilder":
"""
Adds the specified connection parameter to the :class:`SessionBuilder` configuration.
Expand Down Expand Up @@ -334,6 +342,11 @@ def create(self) -> "Session":
_add_session(session)
else:
session = self._create_internal(self._options.get("connection"))

if self._app_name:
app_name_tag = f'APPNAME={self._app_name}'
session.append_query_tag(app_name_tag)

return session

def getOrCreate(self) -> "Session":
Expand Down Expand Up @@ -1627,7 +1640,7 @@ def _get_remote_query_tag(self) -> None:

def append_query_tag(self, tag: str, separator: str = ",") -> None:
"""
Appends a tag to the current query tag. The input tag is appended to the current sessions query tag with the given sperator.
Appends a tag to the current query tag. The input tag is appended to the current sessions query tag with the given separator.

Args:
tag: The tag to append to the current query tag.
Expand Down
17 changes: 16 additions & 1 deletion tests/integ/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_get_or_create(session):

@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_get_or_create_no_previous(db_parameters, session):
# Test getOrCreate error. In this case we wan to make sure that
# Test getOrCreate error. In this case we want to make sure that
# if there was not a session the session gets created
sessions_backup = list(_active_sessions)
_active_sessions.clear()
Expand Down Expand Up @@ -331,6 +331,21 @@ def test_create_session_from_connection_with_noise_parameters(
new_session.close()


@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
def test_session_builder_app_name(session, db_parameters):
builder = session.builder
app_name = 'my_app'
expected_query_tag = f'APPNAME={app_name}'
same_session = builder.app_name(app_name).getOrCreate()
new_session = builder.app_name(app_name).configs(db_parameters).create()
sfc-gh-kjimenezmorales marked this conversation as resolved.
Show resolved Hide resolved
try:
assert session == same_session
assert same_session.query_tag is None
assert new_session.query_tag == expected_query_tag
finally:
new_session.close()


@pytest.mark.skipif(
IS_IN_STORED_PROC,
reason="The test creates temporary tables of which the names do not follow the rules of temp object on purposes.",
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,63 @@ def test_connection_expiry():
) as m:
assert builder.getOrCreate() is None
m.assert_called_once()


def test_session_builder_app_name_no_existing_query_tag():
mocked_session = Session(
ServerConnection(
{"": ""},
mock.Mock(
spec=SnowflakeConnection,
_telemetry=mock.Mock(),
_session_parameters=mock.Mock(),
is_closed=mock.Mock(return_value=False),
expired=False,
),
),
)

mocked_session._get_remote_query_tag = MagicMock(return_value=None)

builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'APPNAME={app_name}'


def test_session_builder_app_name_existing_query_tag():
mocked_session = Session(
ServerConnection(
{"": ""},
mock.Mock(
spec=SnowflakeConnection,
_telemetry=mock.Mock(),
_session_parameters=mock.Mock(),
is_closed=mock.Mock(return_value=False),
expired=False,
),
),
)

existing_query_tag = 'tag'

mocked_session._get_remote_query_tag = MagicMock(return_value=existing_query_tag)

builder = Session.builder

with mock.patch.object(
builder,
"_create_internal",
return_value=mocked_session) as m:
app_name = 'my_app_name'
assert builder.app_name(app_name) is builder
created_session = builder.getOrCreate()
m.assert_called_once()
assert created_session.query_tag == f'tag,APPNAME={app_name}'
Loading