Skip to content

Commit

Permalink
Added support for FQN in stage and git execute (#1023)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Tomasz Urbaszek <[email protected]>
  • Loading branch information
sfc-gh-astus and sfc-gh-turbaszek authored May 7, 2024
1 parent bf568ed commit 95d639b
Show file tree
Hide file tree
Showing 12 changed files with 545 additions and 81 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* Fixed a bug in `snow app` that caused files to be re-uploaded unnecessarily.
* Optimize snowpark dependency search to lower the size of .zip artifacts and
the number of anaconda dependencies for snowpark projects.
* Added support for fully qualified stage names in stage and git execute commands.

# v2.2.0

Expand Down
24 changes: 22 additions & 2 deletions src/snowflake/cli/plugins/git/manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from textwrap import dedent

from snowflake.cli.plugins.stage.manager import StageManager
from snowflake.cli.plugins.stage.manager import StageManager, StagePathParts
from snowflake.connector.cursor import SnowflakeCursor


Expand Down Expand Up @@ -30,9 +30,29 @@ def create(
return self._execute_query(query)

@staticmethod
def get_stage_name_from_path(path: str):
def get_stage_from_path(path: str):
"""
Returns stage name from potential path on stage. For example
repo/branches/main/foo/bar -> repo/branches/main/
"""
return f"{'/'.join(Path(path).parts[0:3])}/"

def _split_stage_path(self, stage_path: str) -> StagePathParts:
"""
Splits Git repository path `@repo/branch/main/dir`
stage -> @repo/branch/main/
stage_name -> repo/branch/main/
directory -> dir
For Git repository with fully qualified name `@db.schema.repo/branch/main/dir`
stage -> @db.schema.repo/branch/main/
stage_name -> repo/branch/main/
directory -> dir
"""
stage = self.get_stage_from_path(stage_path)
stage_path_parts = Path(stage_path).parts
git_repo_name = stage_path_parts[0].split(".")[-1]
if git_repo_name.startswith("@"):
git_repo_name = git_repo_name[1:]
stage_name = "/".join([git_repo_name, *stage_path_parts[1:3], ""])
directory = "/".join(stage_path_parts[3:])
return StagePathParts(stage, stage_name, directory)
2 changes: 1 addition & 1 deletion src/snowflake/cli/plugins/snowpark/package/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def upload(file: Path, stage: str, overwrite: bool):
temp_app_zip_path = prepare_app_zip(SecurePath(file), temp_dir)
sm = StageManager()

sm.create(sm.get_stage_name_from_path(stage))
sm.create(sm.get_stage_from_path(stage))
put_response = sm.put(
temp_app_zip_path.path, stage, overwrite=overwrite
).fetchone()
Expand Down
95 changes: 74 additions & 21 deletions src/snowflake/cli/plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import re
from contextlib import nullcontext
from dataclasses import dataclass
from os import path
from pathlib import Path
from typing import Dict, List, Optional, Union
Expand All @@ -26,6 +27,25 @@
EXECUTE_SUPPORTED_FILES_FORMATS = {".sql"}


@dataclass
class StagePathParts:
# For path like @db.schema.stage/dir the values will be:
# stage = @db.schema.stage
stage: str
# stage_name = stage/dir
stage_name: str
# directory = dir
directory: str

@property
def path(self) -> str:
return (
f"{self.stage_name}{self.directory}".lower()
if self.stage_name.endswith("/")
else f"{self.stage_name}/{self.directory}".lower()
)


class StageManager(SqlExecutionMixin):
@staticmethod
def get_standard_stage_prefix(name: str) -> str:
Expand All @@ -45,7 +65,7 @@ def get_standard_stage_directory_path(path):
return StageManager.get_standard_stage_prefix(path)

@staticmethod
def get_stage_name_from_path(path: str):
def get_stage_from_path(path: str):
"""
Returns stage name from potential path on stage. For example
db.schema.stage/foo/bar -> db.schema.stage
Expand Down Expand Up @@ -185,10 +205,12 @@ def execute(
on_error: OnErrorType,
variables: Optional[List[str]] = None,
):
stage_path = self.get_standard_stage_prefix(stage_path)
all_files_list = self._get_files_list_from_stage(stage_path)
stage_path_with_prefix = self.get_standard_stage_prefix(stage_path)
stage_path_parts = self._split_stage_path(stage_path_with_prefix)
all_files_list = self._get_files_list_from_stage(stage_path_parts)

# filter files from stage if match stage_path pattern
filtered_file_list = self._filter_files_list(stage_path, all_files_list)
filtered_file_list = self._filter_files_list(stage_path_parts, all_files_list)

if not filtered_file_list:
raise ClickException(f"No files matched pattern '{stage_path}'")
Expand All @@ -203,26 +225,48 @@ def execute(
for file in sorted_file_list:
results.append(
self._call_execute_immediate(
file=file, variables=sql_variables, on_error=on_error
stage_path_parts=stage_path_parts,
file=file,
variables=sql_variables,
on_error=on_error,
)
)

return results

def _get_files_list_from_stage(self, stage_path: str) -> List[str]:
stage_name = self.get_stage_name_from_path(stage_path)
files_list_result = self.list_files(stage_name).fetchall()
def _split_stage_path(self, stage_path: str) -> StagePathParts:
"""
Splits stage path `@stage/dir` to
stage -> @stage
stage_name -> stage
directory -> dir
For stage path with fully qualified name `@db.schema.stage/dir`
stage -> @db.schema.stage
stage_name -> stage
directory -> dir
"""
stage = self.get_stage_from_path(stage_path)
stage_name = stage.split(".")[-1]
if stage_name.startswith("@"):
stage_name = stage_name[1:]
directory = "/".join(Path(stage_path).parts[1:])
return StagePathParts(stage, stage_name, directory)

def _get_files_list_from_stage(self, stage_path_parts: StagePathParts) -> List[str]:
files_list_result = self.list_files(stage_path_parts.stage).fetchall()

if not files_list_result:
raise ClickException(f"No files found on stage '{stage_name}'")
raise ClickException(f"No files found on stage '{stage_path_parts.stage}'")

return [f["name"] for f in files_list_result]

def _filter_files_list(
self, stage_path: str, files_on_stage: List[str]
self, stage_path_parts: StagePathParts, files_on_stage: List[str]
) -> List[str]:
stage_path = self.remove_stage_prefix(stage_path)
stage_path = stage_path.lower()
if not stage_path_parts.directory:
return self._filter_supported_files(files_on_stage)

stage_path = stage_path_parts.path

# Exact file path was provided if stage_path in file list
if stage_path in files_on_stage:
Expand Down Expand Up @@ -256,20 +300,29 @@ def _parse_execute_variables(variables: Optional[List[str]]) -> Optional[str]:
return f" using ({', '.join(query_parameters)})"

def _call_execute_immediate(
self, file: str, variables: Optional[str], on_error: OnErrorType
self,
stage_path_parts: StagePathParts,
file: str,
variables: Optional[str],
on_error: OnErrorType,
) -> Dict:
file_stage_path = self._build_file_stage_path(stage_path_parts, file)
try:
stage_path_prefixed = self.get_standard_stage_prefix(file)
query = (
f"execute immediate from {self.quote_stage_name(stage_path_prefixed)}"
)
query = f"execute immediate from {file_stage_path}"
if variables:
query += variables
self._execute_query(query)
cli_console.step(f"SUCCESS - {file}")
return {"File": file, "Status": "SUCCESS", "Error": None}
cli_console.step(f"SUCCESS - {file_stage_path}")
return {"File": file_stage_path, "Status": "SUCCESS", "Error": None}
except ProgrammingError as e:
cli_console.warning(f"FAILURE - {file}")
cli_console.warning(f"FAILURE - {file_stage_path}")
if on_error == OnErrorType.BREAK:
raise e
return {"File": file, "Status": "FAILURE", "Error": e.msg}
return {"File": file_stage_path, "Status": "FAILURE", "Error": e.msg}

def _build_file_stage_path(
self, stage_path_parts: StagePathParts, file: str
) -> str:
stage = Path(stage_path_parts.stage).parts[0]
file_path = Path(file).parts[1:]
return f"{stage}/{'/'.join(file_path)}"
60 changes: 60 additions & 0 deletions tests/git/__snapshots__/test_git_commands.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# serializer version: 1
# name: test_execute[@db.schema.repo/branches/main/[email protected]/branches/main/-expected_files2]
'''
SUCCESS - @db.schema.repo/branches/main/s1.sql
SUCCESS - @db.schema.repo/branches/main/a/s3.sql
+----------------------------------------------------------+
| File | Status | Error |
|----------------------------------------+---------+-------|
| @db.schema.repo/branches/main/s1.sql | SUCCESS | None |
| @db.schema.repo/branches/main/a/s3.sql | SUCCESS | None |
+----------------------------------------------------------+

'''
# ---
# name: test_execute[@db.schema.repo/branches/main/[email protected]/branches/main/-expected_files0]
'''
SUCCESS - @db.schema.repo/branches/main/s1.sql
+--------------------------------------------------------+
| File | Status | Error |
|--------------------------------------+---------+-------|
| @db.schema.repo/branches/main/s1.sql | SUCCESS | None |
+--------------------------------------------------------+

'''
# ---
# name: test_execute[@db.schema.repo/branches/main/[email protected]/branches/main/-expected_files3]
'''
SUCCESS - @db.schema.repo/branches/main/s1.sql
+--------------------------------------------------------+
| File | Status | Error |
|--------------------------------------+---------+-------|
| @db.schema.repo/branches/main/s1.sql | SUCCESS | None |
+--------------------------------------------------------+

'''
# ---
# name: test_execute[@repo/branches/main/-@repo/branches/main/-expected_files0]
'''
SUCCESS - @repo/branches/main/s1.sql
SUCCESS - @repo/branches/main/a/s3.sql
+------------------------------------------------+
| File | Status | Error |
|------------------------------+---------+-------|
| @repo/branches/main/s1.sql | SUCCESS | None |
| @repo/branches/main/a/s3.sql | SUCCESS | None |
+------------------------------------------------+

'''
# ---
# name: test_execute[@repo/branches/main/a-@repo/branches/main/-expected_files1]
'''
SUCCESS - @repo/branches/main/a/s3.sql
+------------------------------------------------+
| File | Status | Error |
|------------------------------+---------+-------|
| @repo/branches/main/a/s3.sql | SUCCESS | None |
+------------------------------------------------+

'''
# ---
36 changes: 30 additions & 6 deletions tests/git/test_git_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,20 +400,43 @@ def test_setup_create_secret_create_api(


@pytest.mark.parametrize(
"repository_path, expected_files",
"repository_path, expected_stage, expected_files",
[
(
"@repo/branches/main/",
["repo/branches/main/s1.sql", "repo/branches/main/a/s3.sql"],
"@repo/branches/main/",
["@repo/branches/main/s1.sql", "@repo/branches/main/a/s3.sql"],
),
(
"@repo/branches/main/a",
["repo/branches/main/a/s3.sql"],
"@repo/branches/main/",
["@repo/branches/main/a/s3.sql"],
),
(
"@db.schema.repo/branches/main/",
"@db.schema.repo/branches/main/",
[
"@db.schema.repo/branches/main/s1.sql",
"@db.schema.repo/branches/main/a/s3.sql",
],
),
(
"@db.schema.repo/branches/main/s1.sql",
"@db.schema.repo/branches/main/",
["@db.schema.repo/branches/main/s1.sql"],
),
],
)
@mock.patch(f"{STAGE_MANAGER}._execute_query")
def test_execute(mock_execute, mock_cursor, runner, repository_path, expected_files):
def test_execute(
mock_execute,
mock_cursor,
runner,
repository_path,
expected_stage,
expected_files,
snapshot,
):
mock_execute.return_value = mock_cursor(
[
{"name": "repo/branches/main/a/s3.sql"},
Expand All @@ -427,10 +450,11 @@ def test_execute(mock_execute, mock_cursor, runner, repository_path, expected_fi

assert result.exit_code == 0, result.output
ls_call, *execute_calls = mock_execute.mock_calls
assert ls_call == mock.call(f"ls @repo/branches/main/", cursor_class=DictCursor)
assert ls_call == mock.call(f"ls {expected_stage}", cursor_class=DictCursor)
assert execute_calls == [
mock.call(f"execute immediate from @{p}") for p in expected_files
mock.call(f"execute immediate from {p}") for p in expected_files
]
assert result.output == snapshot


@mock.patch(f"{STAGE_MANAGER}._execute_query")
Expand Down
Loading

0 comments on commit 95d639b

Please sign in to comment.