Skip to content

Commit

Permalink
V2.0 Definition for snowpark (#1402)
Browse files Browse the repository at this point in the history
* Support project definition V2 in streamlit deploy command (#1369)

* Init

* Solution outline

* Fixing paths

* build fix

* build fix

* Update src/snowflake/cli/api/project/schemas/entities/snowpark_entity.py

Co-authored-by: Tomasz Urbaszek <[email protected]>

* Update src/snowflake/cli/api/project/schemas/entities/snowpark_entity.py

Co-authored-by: Tomasz Urbaszek <[email protected]>

* query problem

* Fix for zipper

* Test fix

* Test fix

* Update src/snowflake/cli/plugins/snowpark/commands.py

Co-authored-by: Tomasz Urbaszek <[email protected]>

* Update src/snowflake/cli/plugins/snowpark/commands.py

Co-authored-by: Patryk Czajka <[email protected]>

* Fixes

* Test fix

* Merge fixes

* Test fix

* Reformat

* Merge cleanup

* Fixup

* Changed to artifacts

* Fixup

---------

Co-authored-by: Tomasz Urbaszek <[email protected]>
Co-authored-by: Patryk Czajka <[email protected]>
  • Loading branch information
3 people authored Aug 10, 2024
1 parent 4ef4173 commit bb4ee2b
Show file tree
Hide file tree
Showing 14 changed files with 414 additions and 168 deletions.
143 changes: 101 additions & 42 deletions src/snowflake/cli/_plugins/snowpark/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from snowflake.cli._plugins.object.manager import ObjectManager
from snowflake.cli._plugins.snowpark import package_utils
from snowflake.cli._plugins.snowpark.common import (
FunctionOrProcedure,
UdfSprocIdentifier,
check_if_replace_is_required,
)
Expand All @@ -54,7 +53,10 @@
)
from snowflake.cli._plugins.snowpark.zipper import zip_dir
from snowflake.cli._plugins.stage.manager import StageManager
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.cli_global_context import (
_CliGlobalContextAccess,
get_cli_context,
)
from snowflake.cli.api.commands.decorators import (
with_project_definition,
)
Expand All @@ -70,15 +72,22 @@
DEPLOYMENT_STAGE,
ObjectType,
)
from snowflake.cli.api.exceptions import SecretsWithoutExternalAccessIntegrationError
from snowflake.cli.api.entities.snowpark_entity import SnowparkEntity
from snowflake.cli.api.exceptions import (
NoProjectDefinitionError,
SecretsWithoutExternalAccessIntegrationError,
)
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.output.types import (
CollectionResult,
CommandResult,
MessageResult,
SingleQueryResult,
)
from snowflake.cli.api.project.project_verification import assert_project_type
from snowflake.cli.api.project.schemas.project_definition import (
ProjectDefinition,
ProjectDefinitionV2,
)
from snowflake.cli.api.project.schemas.snowpark.callable import (
FunctionSchema,
ProcedureSchema,
Expand Down Expand Up @@ -121,18 +130,15 @@ def deploy(
By default, if any of the objects exist already the commands will fail unless `--replace` flag is provided.
All deployed objects use the same artifact which is deployed only once.
"""
cli_context, pd = _get_v2_context_and_project_definition()

assert_project_type("snowpark")

cli_context = get_cli_context()
snowpark = cli_context.project_definition.snowpark
paths = SnowparkPackagePaths.for_snowpark_project(
project_root=SecurePath(cli_context.project_root),
snowpark_project_definition=snowpark,
project_definition=pd,
)

procedures = snowpark.procedures
functions = snowpark.functions
procedures = pd.get_entities_by_type("procedure")
functions = pd.get_entities_by_type("function")

if not procedures and not functions:
raise ClickException(
Expand Down Expand Up @@ -164,29 +170,33 @@ def deploy(
raise ClickException(msg)

# Create stage
stage_name = snowpark.stage_name
stage_manager = StageManager()
stage_name = FQN.from_string(stage_name).using_context()
stage_manager.create(fqn=stage_name, comment="deployments managed by Snowflake CLI")

snowflake_dependencies = _read_snowflake_requrements_file(
paths.snowflake_requirements_file
)
stage_names = {
entity.stage for entity in [*functions.values(), *procedures.values()]
}
stage_manager = StageManager()

artifact_stage_directory = get_app_stage_path(stage_name, snowpark.project_name)
artifact_stage_target = (
f"{artifact_stage_directory}/{paths.artifact_file.path.name}"
)
# TODO: Raise error if stage name is not provided

stage_manager.put(
local_path=paths.artifact_file.path,
stage_path=artifact_stage_directory,
overwrite=True,
)
for stage in stage_names:
stage = FQN.from_string(stage).using_context()
stage_manager.create(fqn=stage, comment="deployments managed by Snowflake CLI")
artifact_stage_directory = get_app_stage_path(stage, pd.defaults.project_name)
artifact_stage_target = (
f"{artifact_stage_directory}/{paths.artifact_file.path.name}"
)

stage_manager.put(
local_path=paths.artifact_file.path,
stage_path=artifact_stage_directory,
overwrite=True,
)

deploy_status = []
# Procedures
for procedure in procedures:
for procedure in procedures.values():
operation_result = _deploy_single_object(
manager=pm,
object_type=ObjectType.PROCEDURE,
Expand All @@ -198,7 +208,7 @@ def deploy(
deploy_status.append(operation_result)

# Functions
for function in functions:
for function in functions.values():
operation_result = _deploy_single_object(
manager=fm,
object_type=ObjectType.FUNCTION,
Expand All @@ -213,9 +223,9 @@ def deploy(


def _assert_object_definitions_are_correct(
object_type, object_definitions: List[FunctionOrProcedure]
object_type, object_definitions: Dict[str, SnowparkEntity]
):
for definition in object_definitions:
for name, definition in object_definitions.items():
database = definition.database
schema = definition.schema_name
name = definition.name
Expand All @@ -232,11 +242,11 @@ def _assert_object_definitions_are_correct(

def _find_existing_objects(
object_type: ObjectType,
objects: List[FunctionOrProcedure],
objects: Dict[str, SnowparkEntity],
om: ObjectManager,
):
existing_objects = {}
for object_definition in objects:
for object_name, object_definition in objects.items():
identifier = UdfSprocIdentifier.from_definition(
object_definition
).identifier_with_arg_types
Expand All @@ -253,16 +263,16 @@ def _find_existing_objects(

def _check_if_all_defined_integrations_exists(
om: ObjectManager,
functions: List[FunctionSchema],
procedures: List[ProcedureSchema],
functions: Dict[str, FunctionSchema],
procedures: Dict[str, ProcedureSchema],
):
existing_integrations = {
i["name"].lower()
for i in om.show(object_type="integration", cursor_class=DictCursor, like=None)
if i["type"] == "EXTERNAL_ACCESS"
}
declared_integration: Set[str] = set()
for object_definition in [*functions, *procedures]:
for object_definition in [*functions.values(), *procedures.values()]:
external_access_integrations = {
s.lower() for s in object_definition.external_access_integrations
}
Expand All @@ -280,15 +290,15 @@ def _check_if_all_defined_integrations_exists(
)


def get_app_stage_path(stage_name: Optional[str], project_name: str) -> str:
def get_app_stage_path(stage_name: Optional[str | FQN], project_name: str) -> str:
artifact_stage_directory = f"@{(stage_name or DEPLOYMENT_STAGE)}/{project_name}"
return artifact_stage_directory


def _deploy_single_object(
manager: FunctionManager | ProcedureManager,
object_type: ObjectType,
object_definition: FunctionOrProcedure,
object_definition: SnowparkEntity,
existing_objects: Dict[str, Dict],
snowflake_dependencies: List[str],
stage_artifact_path: str,
Expand Down Expand Up @@ -374,16 +384,16 @@ def build(
) -> CommandResult:
"""
Builds the Snowpark project as a `.zip` archive that can be used by `deploy` command.
The archive is built using only the `src` directory specified in the project file.
The archive is built using only the `artifacts` directory specified in the project file.
"""
cli_context, pd = _get_v2_context_and_project_definition()

assert_project_type("snowpark")
cli_context = get_cli_context()
snowpark_paths = SnowparkPackagePaths.for_snowpark_project(
project_root=SecurePath(cli_context.project_root),
snowpark_project_definition=cli_context.project_definition.snowpark,
project_definition=pd,
)
log.info("Building package using sources from: %s", snowpark_paths.source.path)
log.info("Building package using sources from:")
log.info(",".join(str(s) for s in snowpark_paths.sources))

anaconda_packages_manager = AnacondaPackagesManager()

Expand Down Expand Up @@ -424,7 +434,7 @@ def build(
)

zip_dir(
source=snowpark_paths.source.path,
source=snowpark_paths.sources_paths,
dest_zip=snowpark_paths.artifact_file.path,
)
if any(packages_dir.iterdir()):
Expand Down Expand Up @@ -510,3 +520,52 @@ def describe(
):
"""Provides description of a procedure or function."""
object_describe(object_type=object_type.value, object_name=identifier, **options)


def _migrate_v1_snowpark_to_v2(pd: ProjectDefinition) -> ProjectDefinitionV2:
    """Build a V2 project definition from the ``snowpark`` section of a V1 one.

    Each V1 procedure and function becomes one V2 entity. The project-wide
    ``stage_name`` and ``src`` are copied onto every entity as ``stage`` and
    ``artifacts``; ``stage_name`` and ``project_name`` also become V2 defaults.

    Raises:
        NoProjectDefinitionError: if ``pd`` has no ``snowpark`` section.
    """
    snowpark = pd.snowpark
    if not snowpark:
        raise NoProjectDefinitionError(
            project_type="snowpark", project_file=get_cli_context().project_root
        )

    entities: dict = {}
    data: dict = {
        "definition_version": "2",
        "defaults": {
            "stage": snowpark.stage_name,
            "project_name": snowpark.project_name,
        },
        "entities": entities,
    }

    for entity in [*snowpark.procedures, *snowpark.functions]:
        v2_entity = {
            "type": "function" if isinstance(entity, FunctionSchema) else "procedure",
            "stage": snowpark.stage_name,
            "artifacts": snowpark.src,
            "handler": entity.handler,
            "returns": entity.returns,
            "signature": entity.signature,
            "runtime": entity.runtime,
            "external_access_integrations": entity.external_access_integrations,
            "secrets": entity.secrets,
            "imports": entity.imports,
            "name": entity.name,
            "database": entity.database,
            "schema": entity.schema_name,
        }
        if isinstance(entity, ProcedureSchema):
            v2_entity["execute_as_caller"] = entity.execute_as_caller

        # V1 keeps callables in lists, so overloads (same name, different
        # signature) or a procedure and a function sharing a name are legal.
        # Keying by the bare name would silently drop all but the last one,
        # so a numeric suffix is appended to the key on collision. The
        # entity's own "name" field is left untouched.
        entity_key = entity.name
        duplicate_no = 1
        while entity_key in entities:
            duplicate_no += 1
            entity_key = f"{entity.name}_{duplicate_no}"
        entities[entity_key] = v2_entity

    return ProjectDefinitionV2(**data)


def _get_v2_context_and_project_definition() -> Tuple[
    _CliGlobalContextAccess, ProjectDefinitionV2
]:
    """Return the CLI context together with a V2 project definition.

    The definition is taken from the current CLI context. If it does not meet
    version requirement "2", it is first converted with
    ``_migrate_v1_snowpark_to_v2``.
    """
    context = get_cli_context()
    definition = context.project_definition
    if definition.meets_version_requirement("2"):
        # Already V2: hand it back unchanged.
        return context, definition
    return context, _migrate_v1_snowpark_to_v2(definition)
10 changes: 3 additions & 7 deletions src/snowflake/cli/_plugins/snowpark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,19 @@
from __future__ import annotations

import re
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set

from snowflake.cli._plugins.snowpark.models import Requirement
from snowflake.cli._plugins.snowpark.package_utils import (
generate_deploy_stage_name,
)
from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.entities.snowpark_entity import SnowparkEntity
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.schemas.snowpark.callable import (
FunctionSchema,
ProcedureSchema,
)
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.connector.cursor import SnowflakeCursor

DEFAULT_RUNTIME = "3.10"
FunctionOrProcedure = Union[FunctionSchema, ProcedureSchema]


def check_if_replace_is_required(
Expand Down Expand Up @@ -271,7 +267,7 @@ def identifier_for_sql(self):
return self._identifier_from_signature(self._full_signature(), for_sql=True)

@classmethod
def from_definition(cls, udf_sproc: FunctionOrProcedure):
def from_definition(cls, udf_sproc: SnowparkEntity):
names = []
types = []
defaults = []
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/cli/_plugins/snowpark/package_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_requirements(
).splitlines():
line = re.sub(r"\s*#.*", "", line).strip()
if line:
reqs.append(Requirement.parse(line))
reqs.append(Requirement.parse_line(line))
return reqs


Expand Down
39 changes: 25 additions & 14 deletions src/snowflake/cli/_plugins/snowpark/snowpark_package_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import List

from snowflake.cli.api.project.schemas.snowpark.snowpark import Snowpark
from snowflake.cli.api.project.schemas.project_definition import DefinitionV20
from snowflake.cli.api.secure_path import SecurePath

_DEFINED_REQUIREMENTS = "requirements.txt"
Expand All @@ -23,24 +25,31 @@

@dataclass
class SnowparkPackagePaths:
source: SecurePath
sources: List[SecurePath]
artifact_file: SecurePath
defined_requirements_file: SecurePath = SecurePath(_DEFINED_REQUIREMENTS)
snowflake_requirements_file: SecurePath = SecurePath(_REQUIREMENTS_SNOWFLAKE)

@classmethod
def for_snowpark_project(
cls, project_root: SecurePath, snowpark_project_definition: Snowpark
cls, project_root: SecurePath, project_definition: DefinitionV20
) -> "SnowparkPackagePaths":
defined_source_path = SecurePath(snowpark_project_definition.src)
sources = set()
entities = project_definition.get_entities_by_type(
"function"
) | project_definition.get_entities_by_type("procedure")
for name, entity in entities.items():
sources.add(entity.artifacts)

return cls(
source=cls._get_snowpark_project_source_absolute_path(
project_root=project_root,
defined_source_path=defined_source_path,
),
sources=[
cls._get_snowpark_project_source_absolute_path(
project_root, SecurePath(source)
)
for source in sources
],
artifact_file=cls._get_snowpark_project_artifact_absolute_path(
project_root=project_root,
defined_source_path=defined_source_path,
),
defined_requirements_file=project_root / _DEFINED_REQUIREMENTS,
snowflake_requirements_file=project_root / _REQUIREMENTS_SNOWFLAKE,
Expand All @@ -56,10 +65,12 @@ def _get_snowpark_project_source_absolute_path(

@classmethod
def _get_snowpark_project_artifact_absolute_path(
cls, project_root: SecurePath, defined_source_path: SecurePath
cls, project_root: SecurePath
) -> SecurePath:
source_path = cls._get_snowpark_project_source_absolute_path(
project_root=project_root, defined_source_path=defined_source_path
)
artifact_file = project_root / (source_path.path.name + ".zip")

artifact_file = project_root / "app.zip"
return artifact_file

@property
def sources_paths(self) -> List[Path]:
    """Return the ``.path`` of every entry in ``self.sources``, in order."""
    return [secure_source.path for secure_source in self.sources]
Loading

0 comments on commit bb4ee2b

Please sign in to comment.