Skip to content

Commit

Permalink
Fix git execute with upper letter in directory (#1494)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas authored Sep 4, 2024
1 parent bb0cafe commit a3fc58f
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 25 deletions.
3 changes: 1 addition & 2 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

## Fixes and improvements
* Duplicated keys in `snowflake.yml` are now detected and reported.

* Fixed git execute not working with upper case in directory name.

# v3.0.0
## Backward incompatibility
Expand Down Expand Up @@ -60,7 +60,6 @@
* Improved error message for incompatible parameters.
* Fixed SQL error when running `snow app version create` and `snow app version drop` with a version name that isn't a valid Snowflake unquoted identifier


# v2.8.0
## Backward incompatibility

Expand Down
20 changes: 14 additions & 6 deletions src/snowflake/cli/_plugins/git/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,25 @@ def __init__(self, stage_path: str):

@property
def path(self) -> str:
return (
f"{self.stage_name}{self.directory}"
if self.stage_name.endswith("/")
else f"{self.stage_name}/{self.directory}"
)
return f"{self.stage_name.rstrip('/')}/{self.directory}"

def add_stage_prefix(self, file_path: str) -> str:
@classmethod
def get_directory(cls, stage_path: str) -> str:
return "/".join(Path(stage_path).parts[3:])

@property
def full_path(self) -> str:
return f"{self.stage.rstrip('/')}/{self.directory}"

def replace_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def add_stage_prefix(self, file_path: str) -> str:
stage = self.stage.rstrip("/")
return f"{stage}/{file_path.lstrip('/')}"

def get_directory_from_file_path(self, file_path: str) -> List[str]:
stage_path_length = len(Path(self.directory).parts)
return list(Path(file_path).parts[3 + stage_path_length : -1])
Expand Down
61 changes: 46 additions & 15 deletions src/snowflake/cli/_plugins/stage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,21 @@ class StagePathParts:
stage_name: str
is_directory: bool

@staticmethod
def get_directory(stage_path: str) -> str:
@classmethod
def get_directory(cls, stage_path: str) -> str:
return "/".join(Path(stage_path).parts[1:])

@property
def path(self) -> str:
raise NotImplementedError

@property
def full_path(self) -> str:
raise NotImplementedError

def replace_stage_prefix(self, file_path: str) -> str:
raise NotImplementedError

def add_stage_prefix(self, file_path: str) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -112,24 +119,27 @@ def __init__(self, stage_path: str):
self.directory = self.get_directory(stage_path)
self.stage = StageManager.get_stage_from_path(stage_path)
stage_name = self.stage.split(".")[-1]
if stage_name.startswith("@"):
stage_name = stage_name[1:]
stage_name = stage_name[1:] if stage_name.startswith("@") else stage_name
self.stage_name = stage_name
self.is_directory = True if stage_path.endswith("/") else False

@property
def path(self) -> str:
return (
f"{self.stage_name}{self.directory}"
if self.stage_name.endswith("/")
else f"{self.stage_name}/{self.directory}"
)
return f"{self.stage_name.rstrip('/')}/{self.directory}"

def add_stage_prefix(self, file_path: str) -> str:
@property
def full_path(self) -> str:
return f"{self.stage.rstrip('/')}/{self.directory}"

def replace_stage_prefix(self, file_path: str) -> str:
stage = Path(self.stage).parts[0]
file_path_without_prefix = Path(file_path).parts[1:]
return f"{stage}/{'/'.join(file_path_without_prefix)}"

def add_stage_prefix(self, file_path: str) -> str:
stage = self.stage.rstrip("/")
return f"{stage}/{file_path.lstrip('/')}"

def get_directory_from_file_path(self, file_path: str) -> List[str]:
stage_path_length = len(Path(self.directory).parts)
return list(Path(file_path).parts[1 + stage_path_length : -1])
Expand All @@ -146,14 +156,29 @@ class UserStagePathParts(StagePathParts):

def __init__(self, stage_path: str):
self.directory = self.get_directory(stage_path)
self.stage = "@~"
self.stage_name = "@~"
self.stage = USER_STAGE_PREFIX
self.stage_name = USER_STAGE_PREFIX
self.is_directory = True if stage_path.endswith("/") else False

@classmethod
def get_directory(cls, stage_path: str) -> str:
if Path(stage_path).parts[0] == USER_STAGE_PREFIX:
return super().get_directory(stage_path)
return stage_path

@property
def path(self) -> str:
return f"{self.directory}"

@property
def full_path(self) -> str:
return f"{self.stage}/{self.directory}"

def replace_stage_prefix(self, file_path: str) -> str:
if Path(file_path).parts[0] == self.stage_name:
return file_path
return f"{self.stage}/{file_path}"

def add_stage_prefix(self, file_path: str) -> str:
return f"{self.stage}/{file_path}"

Expand Down Expand Up @@ -241,7 +266,7 @@ def get_recursive(
self._assure_is_existing_directory(dest_directory)

result = self._execute_query(
f"get {self.quote_stage_name(stage_path_parts.add_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
f"get {self.quote_stage_name(stage_path_parts.replace_stage_prefix(file_path))} {self._to_uri(f'{dest_directory}/')} parallel={parallel}"
)
results.append(result)

Expand Down Expand Up @@ -321,8 +346,14 @@ def execute(
stage_path_parts = self._stage_path_part_factory(stage_path)
all_files_list = self._get_files_list_from_stage(stage_path_parts)

all_files_with_stage_name_prefix = [
stage_path_parts.get_directory(file) for file in all_files_list
]

# filter files from stage if match stage_path pattern
filtered_file_list = self._filter_files_list(stage_path_parts, all_files_list)
filtered_file_list = self._filter_files_list(
stage_path_parts, all_files_with_stage_name_prefix
)

if not filtered_file_list:
raise ClickException(f"No files matched pattern '{stage_path}'")
Expand Down Expand Up @@ -378,7 +409,7 @@ def _filter_files_list(
if not stage_path_parts.directory:
return self._filter_supported_files(files_on_stage)

stage_path = stage_path_parts.path.lower()
stage_path = stage_path_parts.directory

# Exact file path was provided if stage_path in file list
if stage_path in files_on_stage:
Expand Down
9 changes: 9 additions & 0 deletions tests_integration/__snapshots__/test_git.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@
}),
])
# ---
# name: test_execute_with_name_in_pascal_case
list([
dict({
'Error': None,
'File': '@SNOWCLI_TESTING_REPO/branches/main/tests_integration/test_data/projects/stage_execute/ScriptInPascalCase.sql',
'Status': 'SUCCESS',
}),
])
# ---
1 change: 0 additions & 1 deletion tests_integration/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def test_copy_error(runner, sf_git_repository):
)


@pytest.mark.skip(reason="This will be enabled in following PR")
@pytest.mark.integration
def test_execute_with_name_in_pascal_case(
runner, test_database, sf_git_repository, snapshot
Expand Down
4 changes: 3 additions & 1 deletion tests_integration/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def test_create_error_schema_not_exist(runner, test_database):
@mock.patch.dict(os.environ, os.environ, clear=True)
def test_create_error_undefined_database(runner):
# undefined database
del os.environ["SNOWFLAKE_CONNECTIONS_INTEGRATION_DATABASE"]
database_environment_variable = "SNOWFLAKE_CONNECTIONS_INTEGRATION_DATABASE"
if database_environment_variable in os.environ:
del os.environ[database_environment_variable]

result = runner.invoke_with_connection(
["object", "create", "schema", f"name=test_schema"]
Expand Down

0 comments on commit a3fc58f

Please sign in to comment.