Skip to content

Commit

Permalink
Command groups
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cgorrie committed Sep 9, 2024
1 parent e25d4c0 commit a243164
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 52 deletions.
162 changes: 121 additions & 41 deletions src/snowflake/cli/_plugins/workspace/entity_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,69 +12,134 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Callable

from click import ClickException
from snowflake.cli._plugins.workspace.manager import WorkspaceManager
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.commands.decorators import with_project_definition
from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
from snowflake.cli.api.commands.snow_typer import (
SnowTyper,
SnowTyperCommandData,
SnowTyperFactory,
)
from snowflake.cli.api.entities.common import EntityActions
from snowflake.cli.api.output.types import CommandResult, MessageResult
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.api.project.schemas.entities.entities import (
Entity,
EntityModel,
v2_entity_model_to_entity_map,
v2_entity_model_types_map,
)

logger = logging.getLogger(__name__)


class EntityCommandGroup(SnowTyperFactory):
help: str # noqa: A003
target_id: str
model_type: EntityModel
entity_type: Entity
_tree_path: list[str]
_command_map: dict[str, SnowTyperCommandData]
_subtree_map: dict[str, "EntityCommandGroup"]

def __init__(self, target_id: str, model_type_str: str):
super().__init__(
name=f"@{target_id}",
help=f"Commands to interact with the {target_id} entity defined in {DefinitionManager.BASE_DEFINITION_FILENAME}.",
)
def __init__(
self,
name: str,
target_id: str,
help_text: str | None = None,
tree_path: list[str] = [],
):
super().__init__(name=name, help=help_text)
self.target_id = target_id
self.model_type_str = model_type_str
self.model_type = v2_entity_model_types_map[model_type_str]
self.entity_type = v2_entity_model_to_entity_map[self.model_type]

@property
def supported_actions(self):
return sorted(
[action for action in EntityActions if self.entity_type.supports(action)]
)
self._tree_path = tree_path
self._command_map = {}
self._subtree_map = {}

def command(self, name: str, *args, **kwargs):
"""Assume the first arg is the command name, unlike superclass"""

def decorator(command):
cmd_data = SnowTyperCommandData(command, args=[name, *args], kwargs=kwargs)
self.commands_to_register.append(cmd_data)
self._command_map[name] = cmd_data
return command

return decorator

def create_instance(self) -> SnowTyper:
"""Provides a default help value generated based on sub-groups and commands."""
if not self.help:
subcommands = sorted(
[
*[f"`{x}`" for x in self._subtree_map.keys()],
*self._command_map.keys(),
]
)
self.help = "-> " + ", ".join(subcommands)

return super().create_instance()

def new_subtree(self, atom: str) -> "EntityCommandGroup":
if atom in self._subtree_map:
logger.error("Duplicate subtree attempted to be created: %s", atom)
else:
subtree = EntityCommandGroup(
atom, target_id=self.target_id, tree_path=[*self._tree_path, atom]
)
self._subtree_map[atom] = subtree
self.add_typer(subtree)

return self._subtree_map[atom]

def _get_subtree(self, group_path: list[str]) -> "EntityCommandGroup":
"""
Gets a group subtree factory for a sub-tree of this command group.
Creates groups on-the-fly (i.e. mkdir -p semantics).
"""
subtree = self
for atom in group_path:
if atom in self._subtree_map:
subtree = self._subtree_map[atom]
else:
subtree = subtree.new_subtree(atom)
return subtree

def register_commands(self):
for action in self.supported_actions:
verb = action.value.split("action_")[1]
action_callable = getattr(self.entity_type, action)
def register_command_leaf(
self, name: str, action: EntityActions, action_callable: Callable
):
"""Registers the provided action at the given name"""

@self.command(name)
@with_project_definition()
def _action_executor(**options) -> CommandResult:
# TODO: get args for action and turn into typer options
# TODO: what message result are we returning? do we throw them away for multi-step actions (i.e deps?)
# TODO: how do we know if a command needs connection?

@self.command(verb)
@with_project_definition()
def _action_executor(**options) -> CommandResult:
cli_context = get_cli_context()
ws = WorkspaceManager(
project_definition=cli_context.project_definition,
project_root=cli_context.project_root,
)
# entity = ws.get_entity(self.target_id)
ws.perform_action(self.target_id, action)
return MessageResult(
f"Successfully performed {verb} on {self.target_id}."
)
cli_context = get_cli_context()
ws = WorkspaceManager(
project_definition=cli_context.project_definition,
project_root=cli_context.project_root,
)
# entity = ws.get_entity(self.target_id)
ws.perform_action(self.target_id, action)
return MessageResult(
f"Successfully performed {action.verb} on {self.target_id}."
)

_action_executor.__doc__ = action_callable.__doc__

_action_executor.__doc__ = action_callable.__doc__
def register_command_in_tree(
self, action: EntityActions, action_callable: Callable
):
"""
Recurses into subtrees created on-demand to register
an action based on its command path and implementation.
"""
[*group_path, verb] = action.command_path
subtree = self._get_subtree(group_path)
subtree.register_command_leaf(verb, action, action_callable)


def generate_entity_commands(
Expand All @@ -83,7 +148,8 @@ def generate_entity_commands(
"""
Introspect the current snowflake.yml file, generating @<id> command groups
for each found entity. Throws a fatal ClickException if templating is used
in the basic definition of entities (i.e. the type discriminator field).
in the basic definition of entities (i.e. the type discriminator field) or
the type is not found in the entity types map.
"""
dm = DefinitionManager(str(project_root) if project_root is not None else None)

Expand All @@ -97,6 +163,20 @@ def generate_entity_commands(
f'Cannot parse {DefinitionManager.BASE_DEFINITION_FILENAME}: entity "{target_id}" has unknown type: {model.type}'
)

subgroup = EntityCommandGroup(target_id, model.type)
subgroup.register_commands()
ws.add_typer(subgroup)
tree_group = EntityCommandGroup(
f"@{target_id}",
target_id=target_id,
help_text=f"Commands to interact with the {target_id} entity defined in {DefinitionManager.BASE_DEFINITION_FILENAME}.",
)

model_type = v2_entity_model_types_map[model.type]
entity_type = v2_entity_model_to_entity_map[model_type]
supported_actions = sorted(
[action for action in EntityActions if entity_type.supports(action)]
)
for action in supported_actions:
tree_group.register_command_in_tree(
action, entity_type.get_action_callable(action)
)

ws.add_typer(tree_group)
17 changes: 17 additions & 0 deletions src/snowflake/cli/api/entities/application_package_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ class ApplicationPackageEntity(EntityBase[ApplicationPackageEntityModel]):
A Native App application package.
"""

def action_version_create(self, ctx: ActionContext):
"""
Adds a new patch to the provided version defined in your application package. If the version does not exist, creates a version with patch 0.
"""
pass

def action_version_drop(self, ctx: ActionContext):
"""
Drops a version defined in your application package. Versions can either be passed in as an argument to the command or read from the `manifest.yml` file.
Dropping patches is not allowed.
"""
pass

def action_version_list(self, ctx: ActionContext):
"""Lists all versions defined in an application package."""
pass

def action_bundle(self, ctx: ActionContext):
"""
Prepares a local folder with configured app artifacts.
Expand Down
46 changes: 35 additions & 11 deletions src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
from enum import Enum
from typing import Generic, Type, TypeVar, get_args
from typing import Callable, Generic, Type, TypeVar, get_args

from snowflake.cli._plugins.workspace.action_context import ActionContext
from snowflake.cli.api.sql_execution import SqlExecutor


class EntityActions(str, Enum):
BUNDLE = "action_bundle"
DEPLOY = "action_deploy"
DROP = "action_drop"
VALIDATE = "action_validate"
BUNDLE = "bundle"
DEPLOY = "deploy"
DROP = "drop"
VALIDATE = "validate"
VERSION_CREATE = "version_create"
VERSION_DROP = "version_drop"
VERSION_LIST = "version_list"

@property
def verb(self) -> str:
return self.value.replace("_", " ")

@property
def attr_name(self) -> str:
return f"action_{self.value}"

@property
def command_path(self) -> list[str]:
return self.value.split("_")


T = TypeVar("T")
Expand All @@ -35,17 +50,26 @@ def get_entity_model_type(cls) -> Type[T]:
@classmethod
def supports(cls, action: EntityActions) -> bool:
"""
Checks whether this entity supports the given action. An entity is considered to support an action if it implements a method with the action name.
Checks whether this entity supports the given action.
An entity is considered to support an action if it implements a method with the action name.
"""
return callable(getattr(cls, action, None))
return callable(getattr(cls, action.attr_name, None))

@classmethod
def get_action_callable(cls, action: EntityActions) -> Callable:
"""
Returns a generic action callable that is _not_ bound to a particular entity.
"""
attr = getattr(cls, action.attr_name) # raises KeyError
if not callable(attr):
raise ValueError(f"{action} method exists but is not callable")
return attr

def perform(
self, action: EntityActions, action_ctx: ActionContext, *args, **kwargs
):
"""
Performs the requested action.
"""
return getattr(self, action)(action_ctx, *args, **kwargs)
"""Performs the requested action."""
return getattr(self, action.attr_name)(action_ctx, *args, **kwargs)


def get_sql_executor() -> SqlExecutor:
Expand Down

0 comments on commit a243164

Please sign in to comment.