Skip to content

Commit

Permalink
Refactored subdir logic, cleaned up old logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bdufour committed Aug 26, 2024
1 parent 9668b2d commit 9414704
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 94 deletions.
14 changes: 12 additions & 2 deletions src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import copy
import re
from typing import Dict, Optional

from snowflake.cli._plugins.nativeapp.bundle_context import BundleContext
Expand All @@ -34,7 +36,7 @@
)

SNOWPARK_PROCESSOR = "snowpark"
NA_SETUP_PROCESSOR = "native-app-setup"
NA_SETUP_PROCESSOR = "native app setup"

_REGISTERED_PROCESSORS_BY_NAME = {
SNOWPARK_PROCESSOR: SnowparkAnnotationProcessor,
Expand Down Expand Up @@ -110,7 +112,15 @@ def _try_create_processor(
# No registered processor with the specified name
return None

current_processor = processor_factory(self._bundle_ctx)
processor_ctx = copy.copy(self._bundle_ctx)
processor_subdirectory = re.sub(r"[^a-zA-Z0-9_$]", "_", processor_name)
processor_ctx.bundle_root = (
self._bundle_ctx.bundle_root / processor_subdirectory
)
processor_ctx.generated_root = (
self._bundle_ctx.generated_root / processor_subdirectory
)
current_processor = processor_factory(processor_ctx)
self.cached_processors[processor_name] = current_processor

return current_processor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@

from __future__ import annotations

import filecmp
import json
import logging
import os.path
import shutil
from pathlib import Path
from typing import List, Optional

Expand Down Expand Up @@ -47,6 +46,8 @@
DEFAULT_TIMEOUT = 30
DRIVER_PATH = Path(__file__).parent / "setup_driver.py.source"

log = logging.getLogger(__name__)


def safe_set(d: dict, *keys: str, **kwargs) -> None:
curr = d
Expand Down Expand Up @@ -94,50 +95,56 @@ def process(

logs = result.get("logs", [])
for msg in logs:
cc.message(f"LOG: {msg}")
log.debug(msg)

warnings = result.get("warnings", [])
for msg in warnings:
cc.warning(msg)

if result.get("schema_version") == "1":
setup_script_mods = [
mod
for mod in result.get("modifications", [])
if mod.get("target") == "native_app:setup_script"
]
schema_version = result.get("schema_version")
if schema_version != "1":
raise ClickException(
f"Unsupported schema version returned from snowflake-app-python library: {schema_version}"
)

setup_script_mods = [
mod
for mod in result.get("modifications", [])
if mod.get("target") == "native_app:setup_script"
]
if setup_script_mods:
self._edit_setup_sql(setup_script_mods)

manifest_mods = [
mod
for mod in result.get("modifications", [])
if mod.get("target") == "native_app:manifest"
]
manifest_mods = [
mod
for mod in result.get("modifications", [])
if mod.get("target") == "native_app:manifest"
]
if manifest_mods:
self._edit_manifest(manifest_mods)
else:
self._generate_setup_sql_legacy(result)

def _execute_in_sandbox(self, py_files: List[Path]) -> dict:
file_count = len(py_files)
cc.step(f"Processing {file_count} setup file{'s' if file_count > 1 else ''}")

manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root)
temp_manifest_path = self.bundle_root / manifest_path.name
shutil.copyfile(manifest_path, temp_manifest_path)

generated_root = self._bundle_ctx.generated_root
generated_root.mkdir(exist_ok=True, parents=True)

env_vars = {
"_SNOWFLAKE_CLI_PROJECT_PATH": str(self._bundle_ctx.project_root),
"_SNOWFLAKE_CLI_SETUP_FILES": os.pathsep.join(map(str, py_files)),
"_SNOWFLAKE_CLI_APP_NAME": str(self._bundle_ctx.package_name),
"_SNOWFLAKE_CLI_SQL_DEST_DIR": str(self.generated_root),
"_SNOWFLAKE_CLI_MANIFEST_PATH": str(temp_manifest_path),
"_SNOWFLAKE_CLI_SQL_DEST_DIR": str(generated_root),
"_SNOWFLAKE_CLI_MANIFEST_PATH": str(manifest_path),
}

try:
result = execute_script_in_sandbox(
script_source=DRIVER_PATH.read_text(),
env_type=ExecutionEnvironmentType.VENV,
cwd=self.bundle_root,
cwd=self._bundle_ctx.bundle_root,
timeout=DEFAULT_TIMEOUT,
path=self.sandbox_root,
env_vars=env_vars,
Expand All @@ -148,26 +155,13 @@ def _execute_in_sandbox(self, py_files: List[Path]) -> dict:
)

if result.returncode == 0:
processor_result = json.loads(result.stdout)

if not filecmp.cmp(manifest_path, temp_manifest_path):
# manifest was edited, update the original in the deploy root
with self.edit_file(manifest_path) as f:
f.edited_contents = temp_manifest_path.read_text()

return processor_result
return json.loads(result.stdout)
else:
raise ClickException(
f"Failed to execute python setup script logic: {result.stderr}"
)

def _edit_setup_sql(self, modifications: List[dict]) -> None:
generated_root = self.generated_root
generated_root.mkdir(exist_ok=True, parents=True)

if not modifications:
return

cc.step("Patching setup script")
setup_file_path = find_setup_script_file(
deploy_root=self._bundle_ctx.deploy_root
Expand All @@ -182,17 +176,14 @@ def _edit_setup_sql(self, modifications: List[dict]) -> None:
if inst.get("type") == "insert":
default_loc = inst.get("default_location")
if default_loc == "end":
appended.append(self._setup_mod_inst_to_sql(inst))
appended.append(self._setup_mod_instruction_to_sql(inst))
elif default_loc == "start":
prepended.append(self._setup_mod_inst_to_sql(inst))
prepended.append(self._setup_mod_instruction_to_sql(inst))

if prepended or appended:
f.edited_contents = "\n".join(prepended + [f.contents] + appended)

def _edit_manifest(self, modifications: List[dict]) -> None:
if not modifications:
return

cc.step("Patching manifest")
manifest_path = find_manifest_file(deploy_root=self._bundle_ctx.deploy_root)

Expand All @@ -209,57 +200,23 @@ def _edit_manifest(self, modifications: List[dict]) -> None:
safe_set(manifest, *key.split("."), value=value)
f.edited_contents = yaml.safe_dump(manifest, sort_keys=False)

def _setup_mod_inst_to_sql(self, mod_inst: dict) -> str:
payload = mod_inst["payload"]
if payload["type"] == "execute immediate":
def _setup_mod_instruction_to_sql(self, mod_inst: dict) -> str:
payload = mod_inst.get("payload")
if not payload:
raise ClickException("Unsupported instruction received: no payload found")

payload_type = payload.get("type")
if payload_type == "execute immediate":
file_path = payload.get("file_path")
if file_path:
sql_file_path = self.generated_root / file_path
sql_file_path = self._bundle_ctx.generated_root / file_path
return f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';"

raise ClickException("Invalid instructions received")

def _generate_setup_sql_legacy(self, result: dict):
generated_root = self.generated_root
generated_root.mkdir(exist_ok=True, parents=True)

cc.step("Patching setup script")
setup_file_path = find_setup_script_file(
deploy_root=self._bundle_ctx.deploy_root
)

with self.edit_file(setup_file_path) as f:
new_contents = [f.contents]

if result["prepend"]:
for sql_file in result["prepend"]:
sql_file_path = generated_root / sql_file
new_contents.insert(
0,
f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';",
)

if result["append"]:
for sql_file in result["append"]:
sql_file_path = generated_root / sql_file
new_contents.append(
f"EXECUTE IMMEDIATE FROM '/{to_stage_path(sql_file_path.relative_to(self._bundle_ctx.deploy_root))}';",
)

if len(new_contents) > 1:
f.edited_contents = "\n".join(new_contents)
raise ClickException(f"Unsupported instruction type received: {payload_type}")

@property
def sandbox_root(self):
return self.bundle_root / "venv"

@property
def generated_root(self):
return self._bundle_ctx.generated_root / "setup_py"

@property
def bundle_root(self):
return self._bundle_ctx.bundle_root / "setup_py"
return self._bundle_ctx.bundle_root / "venv"

def _create_or_update_sandbox(self):
sandbox_root = self.sandbox_root
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,9 @@ def process(
edit_setup_script_with_exec_imm_sql(
collected_sql_files=collected_sql_files,
deploy_root=bundle_map.deploy_root(),
generated_root=self._generated_root,
generated_root=self._bundle_ctx.generated_root,
)

@property
def _generated_root(self):
return self._bundle_ctx.generated_root / "snowpark"

def _normalize_imports(
self,
extension_fn: NativeAppExtensionFunction,
Expand Down Expand Up @@ -366,7 +362,9 @@ def generate_new_sql_file_name(self, py_file: Path) -> Path:
Generates a SQL filename for the generated root from the python file, and creates its parent directories.
"""
relative_py_file = py_file.relative_to(self._bundle_ctx.deploy_root)
sql_file = Path(self._generated_root, relative_py_file.with_suffix(".sql"))
sql_file = Path(
self._bundle_ctx.generated_root, relative_py_file.with_suffix(".sql")
)
if sql_file.exists():
cc.warning(
f"""\
Expand Down
16 changes: 14 additions & 2 deletions tests/nativeapp/codegen/snowpark/test_python_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,13 @@ def test_process_with_collected_functions(
project_definition=native_app_project_instance.native_app,
project_root=local_path,
)
processor = SnowparkAnnotationProcessor(project.get_bundle_context())
project_context = project.get_bundle_context()
processor_context = copy.copy(project_context)
processor_context.generated_root = (
project_context.generated_root / "snowpark"
)
processor_context.bundle_root = project_context.bundle_root / "snowpark"
processor = SnowparkAnnotationProcessor(processor_context)
processor.process(
artifact_to_process=native_app_project_instance.native_app.artifacts[0],
processor_mapping=processor_mapping,
Expand Down Expand Up @@ -465,7 +471,13 @@ def test_package_normalization(
project_definition=native_app_project_instance.native_app,
project_root=local_path,
)
processor = SnowparkAnnotationProcessor(project.get_bundle_context())
project_context = project.get_bundle_context()
processor_context = copy.copy(project_context)
processor_context.generated_root = (
project_context.generated_root / "snowpark"
)
processor_context.bundle_root = project_context.bundle_root / "snowpark"
processor = SnowparkAnnotationProcessor(processor_context)
processor.process(
artifact_to_process=native_app_project_instance.native_app.artifacts[0],
processor_mapping=processor_mapping,
Expand Down

0 comments on commit 9414704

Please sign in to comment.