diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py b/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py index 4f9cffabcd..72e677ca3a 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py @@ -14,6 +14,8 @@ from __future__ import annotations +import copy +import re from typing import Dict, Optional from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext @@ -34,7 +36,7 @@ ) SNOWPARK_PROCESSOR = "snowpark" -NA_SETUP_PROCESSOR = "native-app-setup" +NA_SETUP_PROCESSOR = "native app setup" _REGISTERED_PROCESSORS_BY_NAME = { SNOWPARK_PROCESSOR: SnowparkAnnotationProcessor, @@ -110,7 +112,15 @@ def _try_create_processor( # No registered processor with the specified name return None - current_processor = processor_factory(self._bundle_ctx) + processor_ctx = copy.copy(self._bundle_ctx) + processor_subdirectory = re.sub(r"[^a-zA-Z0-9_$]", "_", processor_name) + processor_ctx.bundle_root = ( + self._bundle_ctx.bundle_root / processor_subdirectory + ) + processor_ctx.generated_root = ( + self._bundle_ctx.generated_root / processor_subdirectory + ) + current_processor = processor_factory(processor_ctx) self.cached_processors[processor_name] = current_processor return current_processor diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/setup/native_app_setup_processor.py b/src/snowflake/cli/_plugins/nativeapp/codegen/setup/native_app_setup_processor.py index 7d32127595..928cd6eb88 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/setup/native_app_setup_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/setup/native_app_setup_processor.py @@ -14,10 +14,9 @@ from __future__ import annotations -import filecmp import json +import logging import os.path -import shutil from pathlib import Path from typing import List, Optional @@ -47,6 +46,8 @@ DEFAULT_TIMEOUT = 30 DRIVER_PATH = Path(__file__).parent / "setup_driver.py.source" +log = logging.getLogger(__name__) + def safe_set(d: dict, *keys: str, **kwargs) -> None: curr = d @@ -94,50 +95,56 @@ def process( logs = result.get("logs", []) for msg in logs: - cc.message(f"LOG: {msg}") + log.debug(msg) warnings = result.get("warnings", []) for msg in warnings: cc.warning(msg) - if result.get("schema_version") == "1": - setup_script_mods = [ - mod - for mod in result.get("modifications", []) - if mod.get("target") == "native_app:setup_script" - ] + schema_version = result.get("schema_version") + if schema_version != "1": + raise ClickException( + f"Unsupported schema version returned from snowflake-app-python library: {schema_version}" + ) + + setup_script_mods = [ + mod + for mod in result.get("modifications", []) + if mod.get("target") == "native_app:setup_script" + ] + if setup_script_mods: self._edit_setup_sql(setup_script_mods) - manifest_mods = [ - mod - for mod in result.get("modifications", []) - if mod.get("target") == "native_app:manifest" - ] + manifest_mods = [ + mod + for mod in result.get("modifications", []) + if mod.get("target") == "native_app:manifest" + ] + if manifest_mods: self._edit_manifest(manifest_mods) - else: - self._generate_setup_sql_legacy(result) def _execute_in_sandbox(self, py_files: List[Path]) -> dict: file_count = len(py_files) cc.step(f"Processing {file_count} setup file{'s' if file_count > 1 else ''}") manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root) - temp_manifest_path = self.bundle_root / manifest_path.name - shutil.copyfile(manifest_path, temp_manifest_path) + + generated_root = self._bundle_ctx.generated_root + generated_root.mkdir(exist_ok=True, parents=True) env_vars = { "_SNOWFLAKE_CLI_PROJECT_PATH": str(self._bundle_ctx.project_root), "_SNOWFLAKE_CLI_SETUP_FILES": os.pathsep.join(map(str, py_files)), "_SNOWFLAKE_CLI_APP_NAME": str(self._bundle_ctx.package_name), - "_SNOWFLAKE_CLI_SQL_DEST_DIR": str(self.generated_root), - "_SNOWFLAKE_CLI_MANIFEST_PATH": str(temp_manifest_path), + "_SNOWFLAKE_CLI_SQL_DEST_DIR": str(generated_root), + "_SNOWFLAKE_CLI_MANIFEST_PATH": str(manifest_path), } try: result = execute_script_in_sandbox( script_source=DRIVER_PATH.read_text(), env_type=ExecutionEnvironmentType.VENV, - cwd=self.bundle_root, + cwd=self._bundle_ctx.bundle_root, timeout=DEFAULT_TIMEOUT, path=self.sandbox_root, env_vars=env_vars, @@ -148,26 +155,13 @@ def _execute_in_sandbox(self, py_files: List[Path]) -> dict: ) if result.returncode == 0: - processor_result = json.loads(result.stdout) - - if not filecmp.cmp(manifest_path, temp_manifest_path): - # manifest was edited, update the original in the deploy root - with self.edit_file(manifest_path) as f: - f.edited_contents = temp_manifest_path.read_text() - - return processor_result + return json.loads(result.stdout) else: raise ClickException( f"Failed to execute python setup script logic: {result.stderr}" ) def _edit_setup_sql(self, modifications: List[dict]) -> None: - generated_root = self.generated_root - generated_root.mkdir(exist_ok=True, parents=True) - - if not modifications: - return - cc.step("Patching setup script") setup_file_path = find_setup_script_file( deploy_root=self._bundle_ctx.deploy_root @@ -182,17 +176,14 @@ def _edit_setup_sql(self, modifications: List[dict]) -> None: if inst.get("type") == "insert": default_loc = inst.get("default_location") if default_loc == "end": - appended.append(self._setup_mod_inst_to_sql(inst)) + appended.append(self._setup_mod_instruction_to_sql(inst)) elif default_loc == "start": - prepended.append(self._setup_mod_inst_to_sql(inst)) + prepended.append(self._setup_mod_instruction_to_sql(inst)) if prepended or appended: f.edited_contents = "\n".join(prepended + [f.contents] + appended) def _edit_manifest(self, modifications: List[dict]) -> None: - if not modifications: - return - cc.step("Patching manifest") manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root) @@ -209,57 +200,23 @@ def _edit_manifest(self, modifications: List[dict]) -> None: safe_set(manifest, *key.split("."), value=value) f.edited_contents = yaml.safe_dump(manifest, sort_keys=False) - def _setup_mod_inst_to_sql(self, mod_inst: dict) -> str: - payload = mod_inst["payload"] - if payload["type"] == "execute immediate": + def _setup_mod_instruction_to_sql(self, mod_inst: dict) -> str: + payload = mod_inst.get("payload") + if not payload: + raise ClickException("Unsupported instruction received: no payload found") + + payload_type = payload.get("type") + if payload_type == "execute immediate": file_path = payload.get("file_path") if file_path: - sql_file_path = self.generated_root / file_path + sql_file_path = self._bundle_ctx.generated_root / file_path return f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';" - raise ClickException("Invalid instructions received") - - def _generate_setup_sql_legacy(self, result: dict): - generated_root = self.generated_root - generated_root.mkdir(exist_ok=True, parents=True) - - cc.step("Patching setup script") - setup_file_path = find_setup_script_file( - deploy_root=self._bundle_ctx.deploy_root - ) - - with self.edit_file(setup_file_path) as f: - new_contents = [f.contents] - - if result["prepend"]: - for sql_file in result["prepend"]: - sql_file_path = generated_root / sql_file - new_contents.insert( - 0, - f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';", - ) - - if result["append"]: - for sql_file in result["append"]: - sql_file_path = generated_root / sql_file - new_contents.append( - f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';", - ) - - if len(new_contents) > 1: - f.edited_contents = "\n".join(new_contents) + raise ClickException(f"Unsupported instruction type received: {payload_type}") @property def sandbox_root(self): - return self.bundle_root / "venv" - - @property - def generated_root(self): - return self._bundle_ctx.generated_root / "setup_py" - - @property - def bundle_root(self): - return self._bundle_ctx.bundle_root / "setup_py" + return self._bundle_ctx.bundle_root / "venv" def _create_or_update_sandbox(self): sandbox_root = self.sandbox_root diff --git a/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py b/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py index 7eab90dd4e..e26e1b1c26 100644 --- a/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py +++ b/src/snowflake/cli/_plugins/nativeapp/codegen/snowpark/python_processor.py @@ -226,13 +226,9 @@ def process( edit_setup_script_with_exec_imm_sql( collected_sql_files=collected_sql_files, deploy_root=bundle_map.deploy_root(), - generated_root=self._generated_root, + generated_root=self._bundle_ctx.generated_root, ) - @property - def _generated_root(self): - return self._bundle_ctx.generated_root / "snowpark" - def _normalize_imports( self, extension_fn: NativeAppExtensionFunction, @@ -366,7 +362,9 @@ def generate_new_sql_file_name(self, py_file: Path) -> Path: Generates a SQL filename for the generated root from the python file, and creates its parent directories. """ relative_py_file = py_file.relative_to(self._bundle_ctx.deploy_root) - sql_file = Path(self._generated_root, relative_py_file.with_suffix(".sql")) + sql_file = Path( + self._bundle_ctx.generated_root, relative_py_file.with_suffix(".sql") + ) if sql_file.exists(): cc.warning( f"""\ diff --git a/tests/nativeapp/codegen/snowpark/test_python_processor.py b/tests/nativeapp/codegen/snowpark/test_python_processor.py index bb704bcaed..303842665b 100644 --- a/tests/nativeapp/codegen/snowpark/test_python_processor.py +++ b/tests/nativeapp/codegen/snowpark/test_python_processor.py @@ -405,7 +405,13 @@ def test_process_with_collected_functions( project_definition=native_app_project_instance.native_app, project_root=local_path, ) - processor = SnowparkAnnotationProcessor(project.get_bundle_context()) + project_context = project.get_bundle_context() + processor_context = copy.copy(project_context) + processor_context.generated_root = ( + project_context.generated_root / "snowpark" + ) + processor_context.bundle_root = project_context.bundle_root / "snowpark" + processor = SnowparkAnnotationProcessor(processor_context) processor.process( artifact_to_process=native_app_project_instance.native_app.artifacts[0], processor_mapping=processor_mapping, @@ -465,7 +471,13 @@ def test_package_normalization( project_definition=native_app_project_instance.native_app, project_root=local_path, ) - processor = SnowparkAnnotationProcessor(project.get_bundle_context()) + project_context = project.get_bundle_context() + processor_context = copy.copy(project_context) + processor_context.generated_root = ( + project_context.generated_root / "snowpark" + ) + processor_context.bundle_root = project_context.bundle_root / "snowpark" + processor = SnowparkAnnotationProcessor(processor_context) processor.process( artifact_to_process=native_app_project_instance.native_app.artifacts[0], processor_mapping=processor_mapping,