Skip to content

Commit

Permalink
Merge branch 'main' into jsikorski/mixins
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jsikorski authored Aug 26, 2024
2 parents e03ba65 + 4b30958 commit 41bc76a
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 341 deletions.
124 changes: 15 additions & 109 deletions src/snowflake/cli/_plugins/snowpark/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import logging
from collections import defaultdict
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, Optional, Set, Tuple

import typer
from click import ClickException, UsageError
Expand All @@ -36,16 +35,18 @@
from snowflake.cli._plugins.object.manager import ObjectManager
from snowflake.cli._plugins.snowpark import package_utils
from snowflake.cli._plugins.snowpark.common import (
check_if_replace_is_required,
EntityToImportPathsMapping,
SnowparkEntities,
SnowparkObject,
SnowparkObjectManager,
StageToArtefactMapping,
)
from snowflake.cli._plugins.snowpark.manager import FunctionManager, ProcedureManager
from snowflake.cli._plugins.snowpark.package.anaconda_packages import (
AnacondaPackages,
AnacondaPackagesManager,
)
from snowflake.cli._plugins.snowpark.package.commands import app as package_app
from snowflake.cli._plugins.snowpark.snowpark_project_paths import (
Artefact,
SnowparkProjectPaths,
)
from snowflake.cli._plugins.snowpark.snowpark_shared import (
Expand All @@ -72,7 +73,6 @@
from snowflake.cli.api.console import cli_console
from snowflake.cli.api.constants import (
DEFAULT_SIZE_LIMIT_MB,
ObjectType,
)
from snowflake.cli.api.exceptions import (
NoProjectDefinitionError,
Expand All @@ -88,7 +88,6 @@
from snowflake.cli.api.project.schemas.entities.snowpark_entity import (
FunctionEntityModel,
ProcedureEntityModel,
SnowparkEntityModel,
)
from snowflake.cli.api.project.schemas.project_definition import (
ProjectDefinition,
Expand Down Expand Up @@ -124,11 +123,6 @@
)


SnowparkEntities = Dict[str, SnowparkEntityModel]
StageToArtefactMapping = Dict[str, set[Artefact]]
EntityToImportPathsMapping = Dict[str, set[str]]


@app.command("deploy", requires_connection=True)
@with_project_definition()
def deploy(
Expand Down Expand Up @@ -179,13 +173,13 @@ def deploy(

# Create snowpark entities
with cli_console.phase("Creating Snowpark entities"):
snowpark_manager = SnowparkObjectManager()
snowflake_dependencies = _read_snowflake_requirements_file(
project_paths.snowflake_requirements
)
deploy_status = []
for entity in snowpark_entities.values():
cli_console.step(f"Creating {entity.type} {entity.fqn}")
operation_result = _deploy_single_object(
operation_result = snowpark_manager.deploy_entity(
entity=entity,
existing_objects=existing_objects,
snowflake_dependencies=snowflake_dependencies,
Expand Down Expand Up @@ -280,7 +274,7 @@ def _find_existing_objects(

def _check_if_all_defined_integrations_exists(
om: ObjectManager,
snowpark_entities: Dict[str, FunctionEntityModel | ProcedureEntityModel],
snowpark_entities: SnowparkEntities,
):
existing_integrations = {
i["name"].lower()
Expand All @@ -306,72 +300,6 @@ def _check_if_all_defined_integrations_exists(
)


def _deploy_single_object(
entity: SnowparkEntityModel,
existing_objects: Dict[str, SnowflakeCursor],
snowflake_dependencies: List[str],
entities_to_artifact_map: EntityToImportPathsMapping,
):
object_type = entity.get_type()
is_procedure = isinstance(entity, ProcedureEntityModel)

handler = entity.handler
returns = entity.returns
imports = entity.imports
external_access_integrations = entity.external_access_integrations
runtime_ver = entity.runtime
execute_as_caller = None
if is_procedure:
execute_as_caller = entity.execute_as_caller
replace_object = False

object_exists = entity.entity_id in existing_objects
if object_exists:
replace_object = check_if_replace_is_required(
object_type=object_type,
current_state=existing_objects[entity.entity_id],
handler=handler,
return_type=returns,
snowflake_dependencies=snowflake_dependencies,
external_access_integrations=external_access_integrations,
imports=imports,
stage_artifact_files=entities_to_artifact_map[entity.entity_id],
runtime_ver=runtime_ver,
execute_as_caller=execute_as_caller,
)

if object_exists and not replace_object:
return {
"object": entity.udf_sproc_identifier.identifier_with_arg_names_types_defaults,
"type": str(object_type),
"status": "packages updated",
}

create_or_replace_kwargs = {
"identifier": entity.udf_sproc_identifier,
"handler": handler,
"return_type": returns,
"artifact_files": entities_to_artifact_map[entity.entity_id],
"packages": snowflake_dependencies,
"runtime": entity.runtime,
"external_access_integrations": entity.external_access_integrations,
"secrets": entity.secrets,
"imports": imports,
}
if is_procedure:
create_or_replace_kwargs["execute_as_caller"] = entity.execute_as_caller

manager = ProcedureManager() if is_procedure else FunctionManager()
manager.create_or_replace(**create_or_replace_kwargs)

status = "created" if not object_exists else "definition updated"
return {
"object": entity.udf_sproc_identifier.identifier_with_arg_names_types_defaults,
"type": str(object_type),
"status": status,
}


def _read_snowflake_requirements_file(file_path: SecurePath):
if not file_path.exists():
return []
Expand Down Expand Up @@ -470,46 +398,24 @@ def get_snowpark_entities(
return snowpark_entities


class _SnowparkObject(Enum):
"""This clas is used only for Snowpark execute where choice is limited."""

PROCEDURE = str(ObjectType.PROCEDURE)
FUNCTION = str(ObjectType.FUNCTION)


def _execute_object_method(
method_name: str,
object_type: _SnowparkObject,
**kwargs,
):
if object_type == _SnowparkObject.PROCEDURE:
manager = ProcedureManager()
elif object_type == _SnowparkObject.FUNCTION:
manager = FunctionManager()
else:
raise ClickException(f"Unknown object type: {object_type}")

return getattr(manager, method_name)(**kwargs)


@app.command("execute", requires_connection=True)
def execute(
object_type: _SnowparkObject = ObjectTypeArgument,
object_type: SnowparkObject = ObjectTypeArgument,
execution_identifier: str = execution_identifier_argument(
"procedure/function", "hello(1, 'world')"
),
**options,
) -> CommandResult:
"""Executes a procedure or function in a specified environment."""
cursor = _execute_object_method(
"execute", object_type=object_type, execution_identifier=execution_identifier
cursor = SnowparkObjectManager().execute(
execution_identifier=execution_identifier, object_type=object_type
)
return SingleQueryResult(cursor)


@app.command("list", requires_connection=True)
def list_(
object_type: _SnowparkObject = ObjectTypeArgument,
object_type: SnowparkObject = ObjectTypeArgument,
like: str = LikeOption,
scope: Tuple[str, str] = scope_option(
help_example="`list function --in database my_db`"
Expand All @@ -522,7 +428,7 @@ def list_(

@app.command("drop", requires_connection=True)
def drop(
object_type: _SnowparkObject = ObjectTypeArgument,
object_type: SnowparkObject = ObjectTypeArgument,
identifier: FQN = IdentifierArgument,
**options,
):
Expand All @@ -532,7 +438,7 @@ def drop(

@app.command("describe", requires_connection=True)
def describe(
object_type: _SnowparkObject = ObjectTypeArgument,
object_type: SnowparkObject = ObjectTypeArgument,
identifier: FQN = IdentifierArgument,
**options,
):
Expand Down
Loading

0 comments on commit 41bc76a

Please sign in to comment.