Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow UDF/Sproc Decorators to execute in a Local Sandbox #1341

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def __init__(
self,
options: Dict[str, Union[int, str]],
conn: Optional[SnowflakeConnection] = None,
_is_in_sandbox: bool = False,
) -> None:
self._is_in_sandbox = _is_in_sandbox
self._lower_case_parameters = {k.lower(): v for k, v in options.items()}
self._add_application_parameters()
self._conn = conn if conn else connect(**self._lower_case_parameters)
Expand Down Expand Up @@ -244,6 +246,9 @@ def upload_file(
overwrite: bool = False,
skip_upload_on_content_match: bool = False,
) -> Optional[Dict[str, Any]]:
if self._is_in_sandbox:
return {}

if is_in_stored_procedure(): # pragma: no cover
file_name = os.path.basename(path)
target_path = _build_target_path(stage_location, dest_prefix)
Expand Down Expand Up @@ -292,6 +297,9 @@ def upload_stream(
skip_upload_on_content_match: bool = False,
statement_params: Optional[Dict[str, str]] = None,
) -> Optional[Dict[str, Any]]:
if self._is_in_sandbox:
return {}

uri = normalize_local_file(f"/tmp/placeholder/{dest_filename}")
try:
if is_in_stored_procedure(): # pragma: no cover
Expand Down Expand Up @@ -351,6 +359,7 @@ def notify_query_listeners(self, query_record: QueryRecord) -> None:
def execute_and_notify_query_listener(
self, query: str, **kwargs: Any
) -> SnowflakeCursor:
# TODO: ? for all _cursor.execute calls
results_cursor = self._cursor.execute(query, **kwargs)
self.notify_query_listeners(
QueryRecord(results_cursor.sfqid, results_cursor.query)
Expand Down Expand Up @@ -390,8 +399,12 @@ def run_query(
case_sensitive: bool = True,
params: Optional[Sequence[Any]] = None,
num_statements: Optional[int] = None,
is_in_sandbox: bool = False,
**kwargs,
) -> Union[Dict[str, Any], AsyncJob]:
if self._is_in_sandbox or is_in_sandbox:
return {"data": ""}

try:
# Set SNOWPARK_SKIP_TXN_COMMIT_IN_DDL to True to avoid DDL commands to commit the open transaction
if is_ddl_on_temp_object:
Expand Down Expand Up @@ -507,6 +520,10 @@ def execute(
raise NotImplementedError(
"Async query is not supported in stored procedure yet"
)

if self._is_in_sandbox:
return []

result_set, result_meta = self.get_result_set(
plan,
to_pandas,
Expand Down Expand Up @@ -555,6 +572,9 @@ def get_result_set(
],
Union[List[ResultMetadata], List["ResultMetadataV2"]],
]:
if self._is_in_sandbox:
return {}, [] # But cannot determine which keys to put into this empty dict

action_id = plan.session._generate_new_action_id()
# potentially optimize the query using CTEs
plan = plan.replace_repeated_subquery_with_cte()
Expand Down Expand Up @@ -649,6 +669,8 @@ def get_result_set(
def get_result_and_metadata(
self, plan: SnowflakePlan, **kwargs
) -> Tuple[List[Row], List[Attribute]]:
if self._is_in_sandbox:
return [], []
result_set, result_meta = self.get_result_set(plan, **kwargs)
result = result_set_to_rows(result_set["data"])
attributes = convert_result_meta_to_attribute(result_meta)
Expand All @@ -657,7 +679,9 @@ def get_result_and_metadata(
def get_result_query_id(self, plan: SnowflakePlan, **kwargs) -> str:
# get the iterator such that the data is not fetched
result_set, _ = self.get_result_set(plan, to_iter=True, **kwargs)
return result_set["sfqid"]
return result_set[
"sfqid"
] # This will throw an error since key does not exist in empty dict

@_Decorator.wrap_exception
def run_batch_insert(self, query: str, rows: List[Row], **kwargs) -> None:
Expand Down
Loading
Loading