diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index 7ad81e3ceb3..7daa455ec4c 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -7,7 +7,6 @@ import json import logging import os -import re import sys import time import uuid @@ -491,20 +490,11 @@ def run_query( ] = None, # this argument is currently only used by AsyncJob **kwargs, ) -> Union[Dict[str, Any], AsyncJob]: - use_ddl_pattern = r"^\s*use\s+(warehouse|database|schema|role)\s+(.+)\s*$" - if match := re.match(use_ddl_pattern, query): - # if the query is "use xxx", then the object name is already verified by the upper stream - # we do not validate here - object_type = match.group(1) - object_name = match.group(2) - setattr(self, f"_active_{object_type}", object_name) - return {"data": [("Statement executed successfully.",)], "sfqid": None} - else: - self.log_not_supported_error( - external_feature_name="Running SQL queries", - internal_feature_name="MockServerConnection.run_query", - raise_error=NotImplementedError, - ) + self.log_not_supported_error( + external_feature_name="Running SQL queries", + internal_feature_name="MockServerConnection.run_query", + raise_error=NotImplementedError, + ) def _to_data_or_iter( self, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 9098b33cca7..98a50caa7ce 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -2761,7 +2761,22 @@ def use_secondary_roles(self, roles: Optional[Literal["all", "none"]]) -> None: def _use_object(self, object_name: str, object_type: str) -> None: if object_name: validate_object_name(object_name) - self._run_query(f"use {object_type} {object_name}") + query = f"use {object_type} {object_name}" + if isinstance(self._conn, MockServerConnection): + use_ddl_pattern = ( + r"^\s*use\s+(warehouse|database|schema|role)\s+(.+)\s*$" + ) + + if match := re.match(use_ddl_pattern, query): + # if the query is "use xxx", then the object name is already verified by the upper stream + # we do not validate here + object_type = match.group(1) + object_name = match.group(2) + setattr(self, f"_active_{object_type}", object_name) + else: + self._run_query(query) + else: + self._run_query(query) else: raise ValueError(f"'{object_type}' must not be empty or None.")