Skip to content

Commit

Permalink
SNOW-1011771: Adding tests for OverrideableOption
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-davwang committed Mar 4, 2024
1 parent d9cdef8 commit e8f7f98
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 110 deletions.
79 changes: 39 additions & 40 deletions src/snowflake/cli/api/commands/flags.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from inspect import signature
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple

import click
import typer
from click import ClickException
from snowflake.cli.api.cli_global_context import cli_context_manager
from snowflake.cli.api.output.formats import OutputFormat

Expand All @@ -30,7 +31,7 @@ def __init__(
self,
default: Any,
*param_decls: str,
mutually_exclusive: Optional[List[str]] = None,
mutually_exclusive: Optional[List[str] | Tuple[str]] = None,
**kwargs,
):
self.default = default
Expand Down Expand Up @@ -59,14 +60,32 @@ def __call__(self, **kwargs) -> typer.models.OptionInfo:
passed_kwargs.pop(non_kwarg, None)
return typer.Option(default, *param_decls, **passed_kwargs)

def _callback_factory(self, callback, mutually_exclusive: List[str]):
mutually_exclusive_names = (
tuple(mutually_exclusive) if mutually_exclusive else None
)
class InvalidCallbackSignature(ClickException):
def __init__(self, callback):
super().__init__(
f"Signature {signature(callback)} is not valid for an OverrideableOption callback function. Must have at most one parameter with each of the following types: (typer.Context, typer.CallbackParam, Any Other Type)"
)

def _callback_factory(
self, callback, mutually_exclusive: Optional[List[str] | Tuple[str]]
):
callback = callback if callback else lambda x: x

# inspect existing_callback to make sure signature is valid
existing_params = signature(callback).parameters
# at most one parameter with each type in [typer.Context, typer.CallbackParam, any other type]
limits = [
lambda x: x == typer.Context,
lambda x: x == typer.CallbackParam,
lambda x: x != typer.Context and x != typer.CallbackParam,
]
for limit in limits:
if len([v for v in existing_params.values() if limit(v.annotation)]) > 1:
raise self.InvalidCallbackSignature(callback)

def generated_callback(ctx: typer.Context, param: typer.CallbackParam, value):
if mutually_exclusive_names:
for name in mutually_exclusive_names:
if mutually_exclusive:
for name in mutually_exclusive:
if value and ctx.params.get(
name, False
): # if the current parameter is set to True and a previous parameter is also Truthy
Expand All @@ -77,38 +96,18 @@ def generated_callback(ctx: typer.Context, param: typer.CallbackParam, value):
raise click.ClickException(
f"Options '{curr_opt}' and '{other_opt}' are incompatible."
)
if callback:
# inspect existing_callback to make sure signature is valid
existing_params = signature(callback).parameters
# at most one parameter with each type in [typer.Context, typer.CallbackParam, any other type]
limits = [
lambda x: x == typer.Context,
lambda x: x == typer.CallbackParam,
lambda x: x != typer.Context and x != typer.CallbackParam,
]
for limit in limits:
if (
len(
[v for v in existing_params.values() if limit(v.annotation)]
)
> 1
):
raise click.ClickException(
f"Signature {signature(callback)} is not valid for an OverrideableOption callback function. Must have at most one parameter with each of the following types: (typer.Context, typer.CallbackParam, Any Other Type)"
)
# pass args to existing_callback based on its signature (this is how Typer infers callback args)
passed_params = {}
for existing_param in existing_params:
annotation = existing_params[existing_param].annotation
if annotation == typer.Context:
passed_params[existing_param] = ctx
elif annotation == typer.CallbackParam:
passed_params[existing_param] = param
else:
passed_params[existing_param] = value
return callback(**passed_params)
else:
return value

# pass args to existing callback based on its signature (this is how Typer infers callback args)
passed_params = {}
for existing_param in existing_params:
annotation = existing_params[existing_param].annotation
if annotation == typer.Context:
passed_params[existing_param] = ctx
elif annotation == typer.CallbackParam:
passed_params[existing_param] = param
else:
passed_params[existing_param] = value
return callback(**passed_params)

return generated_callback

Expand Down
22 changes: 0 additions & 22 deletions tests/api/commands/__snapshots__/test_flags.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,5 @@
│ Options '--option2' and '--option1' are incompatible. │
╰──────────────────────────────────────────────────────────────────────────────╯

'''
# ---
# name: test_overrideable_option_is_overrideable
'''

Usage: - [OPTIONS]

╭─ Options ────────────────────────────────────────────────────────────────────╮
│ --option INTEGER [default: (dynamic)] │
│ --install-completion [bash|zsh|fish|powershe Install completion for │
│ ll|pwsh] the specified shell. │
│ [default: None] │
│ --show-completion [bash|zsh|fish|powershe Show completion for the │
│ ll|pwsh] specified shell, to │
│ copy it or customize │
│ the installation. │
│ [default: None] │
│ --help Show this message and │
│ exit. │
╰──────────────────────────────────────────────────────────────────────────────╯


'''
# ---
112 changes: 97 additions & 15 deletions tests/api/commands/test_flags.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from unittest.mock import Mock, patch

import click.core
import pytest
import typer
from snowflake.cli.api.commands.flags import (
PLAIN_PASSWORD_MSG,
OverrideableOption,
PasswordOption,
)
from typer import Typer
from typer.core import TyperOption
from typer.testing import CliRunner


Expand All @@ -29,18 +34,30 @@ def _(password: str = PasswordOption):
assert PLAIN_PASSWORD_MSG in result.output


def test_overrideable_option_is_overrideable(snapshot):
original = OverrideableOption(1, "--option", help="original help")
app = Typer()
@patch("snowflake.cli.api.commands.flags.typer.Option")
def test_overrideable_option_returns_typer_option(mock_option):
mock_option_info = Mock(spec=typer.models.OptionInfo)
mock_option.return_value = mock_option_info
default = 1
param_decls = ["--option"]
help_message = "help message"

@app.command()
def _(option: int = OverrideableOption(default=2, help="new help")):
return "ok"
option = OverrideableOption(default, *param_decls, help=help_message)()
mock_option.assert_called_once_with(default, *param_decls, help=help_message)
assert option == mock_option_info

runner = CliRunner()
result = runner.invoke(app, ["--help"], catch_exceptions=False)
assert result.exit_code == 0
assert result.output == snapshot

def test_overrideable_option_is_overrideable():
original_param_decls = ("--option",)
original = OverrideableOption(1, *original_param_decls, help="original help")

new_default = 2
new_help = "new help"
modified = original(default=new_default, help=new_help)

assert modified.default == new_default
assert modified.help == new_help
assert modified.param_decls == original_param_decls


@pytest.mark.parametrize("set1, set2", [(False, False), (False, True), (True, False)])
Expand All @@ -55,7 +72,7 @@ def test_mutually_exclusive_options_no_error(set1, set2):

@app.command()
def _(option_1: bool = option1(), option_2: bool = option2()):
return "ok"
pass

command = []
if set1:
Expand All @@ -64,7 +81,6 @@ def _(option_1: bool = option1(), option_2: bool = option2()):
command.append("--option2")
runner = CliRunner()
result = runner.invoke(app, command)
print(result.output)
assert result.exit_code == 0


Expand All @@ -79,7 +95,7 @@ def test_mutually_exclusive_options_error(snapshot):

@app.command()
def _(option_1: bool = option1(), option_2: bool = option2()):
return "ok"
pass

command = ["--option1", "--option2"]
runner = CliRunner()
Expand All @@ -88,5 +104,71 @@ def _(option_1: bool = option1(), option_2: bool = option2()):
assert result.output == snapshot


def test_overrideable_option_invalid_callback_signature():
pass
def test_overrideable_option_callback_passthrough():
def callback(value):
return value + 1

app = Typer()

@app.command()
def _(option: int = OverrideableOption(..., "--option", callback=callback)()):
print(option)

runner = CliRunner()
result = runner.invoke(app, ["--option", "0"])
assert result.exit_code == 0
assert result.output.strip() == "1"


def test_overrideable_option_callback_with_context():
# tests that generated_callback will correctly map ctx and param arguments to the original callback
def callback(value, param: typer.CallbackParam, ctx: typer.Context):
assert isinstance(value, int)
assert isinstance(param, TyperOption)
assert isinstance(ctx, click.core.Context)
return value

app = Typer()

@app.command()
def _(option: int = OverrideableOption(..., "--option", callback=callback)()):
pass

runner = CliRunner()
result = runner.invoke(app, ["--option", "0"])
assert result.exit_code == 0


class _InvalidCallbackSignatureNamespace:
# dummy functions for test_overrideable_option_invalid_callback_signature

# too many parameters
@staticmethod
def callback1(
ctx: typer.Context, param: typer.CallbackParam, value1: int, value2: float
):
pass

# untyped Context and CallbackParam
@staticmethod
def callback2(ctx, param, value):
pass

# multiple untyped values
@staticmethod
def callback3(ctx: typer.Context, value1, value2):
pass


@pytest.mark.parametrize(
"callback",
[
_InvalidCallbackSignatureNamespace.callback1,
_InvalidCallbackSignatureNamespace.callback2,
_InvalidCallbackSignatureNamespace.callback3,
],
)
def test_overrideable_option_invalid_callback_signature(callback):
invalid_callback_option = OverrideableOption(None, "--option", callback=callback)
with pytest.raises(OverrideableOption.InvalidCallbackSignature):
invalid_callback_option()
24 changes: 2 additions & 22 deletions tests/spcs/__snapshots__/test_image_repository.ambr
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# serializer version: 1
# name: test_create_cli[False-False]
# name: test_create_cli
'''
+-----------------------------------------------------------+
| key | value |
Expand All @@ -9,27 +9,7 @@

'''
# ---
# name: test_create_cli[False-True]
'''
+-----------------------------------------------------------+
| key | value |
|--------+--------------------------------------------------|
| status | Image Repository TEST_REPO successfully created. |
+-----------------------------------------------------------+

'''
# ---
# name: test_create_cli[True-False]
'''
+-----------------------------------------------------------+
| key | value |
|--------+--------------------------------------------------|
| status | Image Repository TEST_REPO successfully created. |
+-----------------------------------------------------------+

'''
# ---
# name: test_create_cli_replace_and_if_not_exists
# name: test_create_cli_replace_and_if_not_exists_fails
'''
╭─ Error ──────────────────────────────────────────────────────────────────────╮
│ Options '--if-not-exists' and '--replace' are incompatible. │
Expand Down
14 changes: 3 additions & 11 deletions tests/spcs/test_image_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,34 +75,26 @@ def test_create_replace_and_if_not_exist():
assert "mutually exclusive" in str(e.value)


@pytest.mark.parametrize(
"replace, if_not_exists",
[(False, False), (True, False), (False, True)],
)
@mock.patch(
"snowflake.cli.plugins.spcs.image_repository.manager.ImageRepositoryManager.create"
)
def test_create_cli(mock_create, mock_cursor, runner, replace, if_not_exists, snapshot):
def test_create_cli(mock_create, mock_cursor, runner, snapshot):
repo_name = "test_repo"
cursor = mock_cursor(
rows=[[f"Image Repository {repo_name.upper()} successfully created."]],
columns=["status"],
)
mock_create.return_value = cursor
command = ["spcs", "image-repository", "create", repo_name]
if replace:
command.append("--replace")
if if_not_exists:
command.append("--if-not-exists")
result = runner.invoke(command)
mock_create.assert_called_once_with(
name=repo_name, replace=replace, if_not_exists=if_not_exists
name=repo_name, replace=False, if_not_exists=False
)
assert result.exit_code == 0, result.output
assert result.output == snapshot


def test_create_cli_replace_and_if_not_exists(runner, snapshot):
def test_create_cli_replace_and_if_not_exists_fails(runner, snapshot):
command = [
"spcs",
"image-repository",
Expand Down

0 comments on commit e8f7f98

Please sign in to comment.