Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aling committed Apr 29, 2024
1 parent 8d97559 commit 7a5f559
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
20 changes: 5 additions & 15 deletions src/snowflake/snowpark/mock/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import json
import logging
import os
import re
import sys
import time
import uuid
Expand Down Expand Up @@ -491,20 +490,11 @@ def run_query(
] = None, # this argument is currently only used by AsyncJob
**kwargs,
) -> Union[Dict[str, Any], AsyncJob]:
use_ddl_pattern = r"^\s*use\s+(warehouse|database|schema|role)\s+(.+)\s*$"
if match := re.match(use_ddl_pattern, query):
# if the query is "use xxx", then the object name is already verified by the upper stream
# we do not validate here
object_type = match.group(1)
object_name = match.group(2)
setattr(self, f"_active_{object_type}", object_name)
return {"data": [("Statement executed successfully.",)], "sfqid": None}
else:
self.log_not_supported_error(
external_feature_name="Running SQL queries",
internal_feature_name="MockServerConnection.run_query",
raise_error=NotImplementedError,
)
self.log_not_supported_error(
external_feature_name="Running SQL queries",
internal_feature_name="MockServerConnection.run_query",
raise_error=NotImplementedError,
)

def _to_data_or_iter(
self,
Expand Down
17 changes: 16 additions & 1 deletion src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,7 +2761,22 @@ def use_secondary_roles(self, roles: Optional[Literal["all", "none"]]) -> None:
def _use_object(self, object_name: str, object_type: str) -> None:
if object_name:
validate_object_name(object_name)
self._run_query(f"use {object_type} {object_name}")
query = f"use {object_type} {object_name}"
if isinstance(self._conn, MockServerConnection):
use_ddl_pattern = (
r"^\s*use\s+(warehouse|database|schema|role)\s+(.+)\s*$"
)

if match := re.match(use_ddl_pattern, query):
# if the query is "use xxx", then the object name is already verified by the upper stream
# we do not validate here
object_type = match.group(1)
object_name = match.group(2)
setattr(self, f"_active_{object_type}", object_name)
else:
self._run_query(query)
else:
self._run_query(query)
else:
raise ValueError(f"'{object_type}' must not be empty or None.")

Expand Down

0 comments on commit 7a5f559

Please sign in to comment.