Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-melnacouzi committed Dec 3, 2024
1 parent 0259ede commit 00ef755
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 314 deletions.
37 changes: 12 additions & 25 deletions src/snowflake/cli/_plugins/connection/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import os
from enum import Enum
from functools import lru_cache
from textwrap import dedent
from typing import Dict, Optional, TypeVar
from typing import Any, Dict, Optional

from click.exceptions import ClickException
from snowflake.connector import SnowflakeConnection
Expand Down Expand Up @@ -60,12 +59,9 @@ class UIParameter(Enum):
NA_FEATURE_RELEASE_CHANNELS = "FEATURE_RELEASE_CHANNELS"


T = TypeVar("T")


def get_ui_parameter(
conn: SnowflakeConnection, parameter: UIParameter, default: T
) -> str | T:
conn: SnowflakeConnection, parameter: UIParameter, default: Any
) -> Any:
"""
Returns the value of a single UI parameter.
If the parameter is not found, the default value is returned.
Expand All @@ -76,26 +72,22 @@ def get_ui_parameter(


@lru_cache()
def get_ui_parameters(conn: SnowflakeConnection) -> Dict[UIParameter, str]:
def get_ui_parameters(conn: SnowflakeConnection) -> Dict[UIParameter, Any]:
"""
Returns the UI parameters from the SYSTEM$BOOTSTRAP_DATA_REQUEST function
"""

parameters_to_fetch = sorted([param.value for param in UIParameter])
parameters_to_fetch = [param.value for param in UIParameter]

query = dedent(
f"""
select value['value']::string as PARAM_VALUE, value['name']::string as PARAM_NAME from table(flatten(
input => parse_json(SYSTEM$BOOTSTRAP_DATA_REQUEST()),
path => 'clientParamsInfo'
)) where value['name'] in ('{"', '".join(parameters_to_fetch)}');
"""
)
query = "call system$bootstrap_data_request('CLIENT_PARAMS_INFO')"
*_, cursor = conn.execute_string(query)

*_, cursor = conn.execute_string(query, cursor_class=DictCursor)
json_map = json.loads(cursor.fetchone()[0])

return {
UIParameter(row["PARAM_NAME"]): row["PARAM_VALUE"] for row in cursor.fetchall()
UIParameter(row["name"]): row["value"]
for row in json_map["clientParamsInfo"]
if row["name"] in parameters_to_fetch
}


Expand All @@ -107,12 +99,7 @@ def is_regionless_redirect(conn: SnowflakeConnection) -> bool:
assume it's regionless, as this is true for most production deployments.
"""
try:
return (
get_ui_parameter(
conn, UIParameter.NA_ENABLE_REGIONLESS_REDIRECT, "true"
).lower()
== "true"
)
return get_ui_parameter(conn, UIParameter.NA_ENABLE_REGIONLESS_REDIRECT, True)
except:
log.warning(
"Cannot determine regionless redirect; assuming True.", exc_info=True
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/cli/_plugins/nativeapp/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,5 @@ def _should_invoke_processors(self):

def _is_enabled(self, processor: ProcessorMapping) -> bool:
if processor.name.lower() == NA_SETUP_PROCESSOR:
return FeatureFlag.ENABLE_NATIVE_APP_PYTHON_SETUP.get_flag_value() is True
return FeatureFlag.ENABLE_NATIVE_APP_PYTHON_SETUP.is_enabled()
return True
16 changes: 5 additions & 11 deletions src/snowflake/cli/_plugins/nativeapp/entities/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,12 @@ def __init__(
self._is_dev_mode = install_method.is_dev_mode
self._metrics = get_cli_context().metrics
self._console = console
connection = get_sql_executor()._conn # noqa: SLF001
self._event_sharing_enabled = (
get_snowflake_facade()
.get_ui_parameter(UIParameter.NA_EVENT_SHARING_V2, "true")
.lower()
== "true"

self._event_sharing_enabled = get_snowflake_facade().get_ui_parameter(
UIParameter.NA_EVENT_SHARING_V2, True
)
self._event_sharing_enforced = (
get_snowflake_facade()
.get_ui_parameter(UIParameter.NA_ENFORCE_MANDATORY_FILTERS, "true")
.lower()
== "true"
self._event_sharing_enforced = get_snowflake_facade().get_ui_parameter(
UIParameter.NA_ENFORCE_MANDATORY_FILTERS, True
)

self._share_mandatory_events = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,17 +831,17 @@ def _get_enable_release_channels_flag(self) -> Optional[bool]:
It retrieves the value from the configuration file and checks that the feature is enabled in the account.
If return value is None, it means do not explicitly set the flag.
"""
feature_flag_from_config = FeatureFlag.ENABLE_RELEASE_CHANNELS.get_flag_value()
feature_flag_from_config = FeatureFlag.ENABLE_RELEASE_CHANNELS.get_value()
feature_enabled_in_account = (
get_snowflake_facade()
.get_ui_parameter(UIParameter.NA_FEATURE_RELEASE_CHANNELS, "enabled")
.lower()
== "enabled"
get_snowflake_facade().get_ui_parameter(
UIParameter.NA_FEATURE_RELEASE_CHANNELS, "ENABLED"
)
== "ENABLED"
)

if feature_flag_from_config is not None and not feature_enabled_in_account:
self._workspace_ctx.console.warning(
f"Cannot use feature flag {FeatureFlag.ENABLE_RELEASE_CHANNELS.name} because release channels are not enabled in the current account."
f"Ignoring feature flag {FeatureFlag.ENABLE_RELEASE_CHANNELS.name} because release channels are not enabled in the current account."
)
return None

Expand Down
41 changes: 5 additions & 36 deletions src/snowflake/cli/_plugins/nativeapp/feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum, unique
from typing import Any, NamedTuple, Optional
from enum import unique

from snowflake.cli.api.config import FEATURE_FLAGS_SECTION_PATH, get_config_value
from snowflake.cli.api.utils.types import try_cast_to_bool


class OptionalBooleanFlag(NamedTuple):
name: str
default: Optional[bool] = None


@unique
class OptionalFeatureFlagMixin(Enum):
"""
Mixin for feature flags that can be enabled, disabled, or unset.
"""

def get_flag_value(self) -> Optional[bool]:
value = self._get_raw_value()
if value is None:
return self.value.default
return try_cast_to_bool(value)

def _get_raw_value(self) -> Any:
return get_config_value(
*FEATURE_FLAGS_SECTION_PATH,
key=self.value.name.lower(),
default=None,
)
from snowflake.cli.api.feature_flags import BooleanFlag, FeatureFlagMixin


@unique
class FeatureFlag(OptionalFeatureFlagMixin):
"""
Enum for Native Apps feature flags.
"""

ENABLE_NATIVE_APP_PYTHON_SETUP = OptionalBooleanFlag(
class FeatureFlag(FeatureFlagMixin):
ENABLE_NATIVE_APP_PYTHON_SETUP = BooleanFlag(
"ENABLE_NATIVE_APP_PYTHON_SETUP", False
)
ENABLE_RELEASE_CHANNELS = OptionalBooleanFlag("ENABLE_RELEASE_CHANNELS", None)
ENABLE_RELEASE_CHANNELS = BooleanFlag("ENABLE_RELEASE_CHANNELS", None)
6 changes: 2 additions & 4 deletions src/snowflake/cli/_plugins/nativeapp/sf_sql_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from contextlib import contextmanager
from textwrap import dedent
from typing import Any, Dict, List, TypeVar
from typing import Any, Dict, List

from snowflake.cli._plugins.connection.util import UIParameter, get_ui_parameter
from snowflake.cli._plugins.nativeapp.constants import SPECIAL_COMMENT
Expand Down Expand Up @@ -588,9 +588,7 @@ def alter_application_package_properties(
f"Failed to update enable_release_channels for application package {package_name}.",
)

T = TypeVar("T")

def get_ui_parameter(self, parameter: UIParameter, default: T) -> str | T:
def get_ui_parameter(self, parameter: UIParameter, default: Any) -> Any:
"""
Returns the value of a single UI parameter.
If the parameter is not found, the default value is returned.
Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/cli/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,12 @@ def get_config_value(*path, key: str, default: Optional[Any] = Empty) -> Any:
raise


def get_config_bool_value(*path, key: str, default: Optional[Any] = Empty) -> bool:
value = get_config_value(*path, key=key, default=default)
def get_config_bool_value(*path, key: str, default: Any = Empty) -> bool | None:
value = get_config_value(*path, key=key, default=None)

if value is None and default is not Empty:
return default

try:
return try_cast_to_bool(value)
except ValueError:
Expand Down
19 changes: 15 additions & 4 deletions src/snowflake/cli/api/feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,31 @@

class BooleanFlag(NamedTuple):
name: str
default: bool = False
default: bool | None = False


@unique
class FeatureFlagMixin(Enum):
def is_enabled(self) -> bool:
def get_value(self) -> bool | None:
return get_config_bool_value(
*FEATURE_FLAGS_SECTION_PATH,
key=self.value.name.lower(),
default=self.value.default,
)

def is_disabled(self):
return not self.is_enabled()
def is_enabled(self) -> bool:
return self.get_value() is True

def is_disabled(self) -> bool:
return self.get_value() is False

def is_set(self) -> bool:
return (
get_config_bool_value(
*FEATURE_FLAGS_SECTION_PATH, key=self.value.name.lower(), default=None
)
is not None
)

def env_variable(self):
return get_env_variable_name(*FEATURE_FLAGS_SECTION_PATH, key=self.value.name)
Expand Down
68 changes: 59 additions & 9 deletions tests/api/test_feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,43 @@
from unittest import mock

import pytest
from click import ClickException
from snowflake.cli.api.feature_flags import BooleanFlag, FeatureFlagMixin


class _TestFlags(FeatureFlagMixin):
# Intentional inconsistency between constant and the enum name to make sure there's no strict relation
ENABLED_BY_DEFAULT = BooleanFlag("ENABLED_DEFAULT", True)
DISABLED_BY_DEFAULT = BooleanFlag("DISABLED_DEFAULT", False)
NON_BOOLEAN = BooleanFlag("NON_BOOLEAN", "xys") # type: ignore
NON_BOOLEAN_DEFAULT = BooleanFlag("NON_BOOLEAN", "xys") # type: ignore
NONE_AS_DEFAULT = BooleanFlag("NON_BOOLEAN", "xys") # type: ignore


def test_flag_value_has_to_be_boolean():
with pytest.raises(ClickException):
_TestFlags.NON_BOOLEAN.is_enabled()
def test_flag_value_default_non_boolean():
_TestFlags.NON_BOOLEAN_DEFAULT.is_enabled() is False
_TestFlags.NON_BOOLEAN_DEFAULT.is_disabled() is False
_TestFlags.NON_BOOLEAN_DEFAULT.get_value() == "xys"
_TestFlags.NON_BOOLEAN_DEFAULT.is_set() is True


def test_flag_value_default_is_none():
_TestFlags.NONE_AS_DEFAULT.is_enabled() is False
_TestFlags.NONE_AS_DEFAULT.is_disabled() is False
_TestFlags.NONE_AS_DEFAULT.get_value() is None
_TestFlags.NONE_AS_DEFAULT.is_set() is False


def test_flag_is_enabled():
assert _TestFlags.ENABLED_BY_DEFAULT.is_enabled() is True
assert _TestFlags.ENABLED_BY_DEFAULT.is_disabled() is False
assert _TestFlags.ENABLED_BY_DEFAULT.get_value() is True
assert _TestFlags.ENABLED_BY_DEFAULT.is_set() is False


def test_flag_is_disabled():
assert _TestFlags.DISABLED_BY_DEFAULT.is_enabled() is False
assert _TestFlags.DISABLED_BY_DEFAULT.is_disabled() is True
assert _TestFlags.DISABLED_BY_DEFAULT.get_value() is False
assert _TestFlags.DISABLED_BY_DEFAULT.is_set() is False


def test_flag_env_variable_value():
Expand All @@ -53,13 +66,50 @@ def test_flag_env_variable_value():


@mock.patch("snowflake.cli.api.config.get_config_value")
@pytest.mark.parametrize("value_from_config", [True, False])
def test_flag_from_config_file(mock_get_config_value, value_from_config):
@pytest.mark.parametrize("value_from_config", [True, False, None])
def test_is_enabled_flag_from_config_file(mock_get_config_value, value_from_config):
mock_get_config_value.return_value = value_from_config

assert _TestFlags.DISABLED_BY_DEFAULT.is_enabled() is (value_from_config or False)
mock_get_config_value.assert_called_once_with(
"cli", "features", key="disabled_default", default=None
)


@mock.patch("snowflake.cli.api.config.get_config_value")
@pytest.mark.parametrize("value_from_config", [True, False, None])
def test_is_disabled_flag_from_config_file(mock_get_config_value, value_from_config):
mock_get_config_value.return_value = value_from_config

assert _TestFlags.DISABLED_BY_DEFAULT.is_disabled() is not (
value_from_config or False
)
mock_get_config_value.assert_called_once_with(
"cli", "features", key="disabled_default", default=None
)


@mock.patch("snowflake.cli.api.config.get_config_value")
@pytest.mark.parametrize("value_from_config", [True, False, None])
def test_is_set_flag_from_config_file(mock_get_config_value, value_from_config):
mock_get_config_value.return_value = value_from_config

assert _TestFlags.DISABLED_BY_DEFAULT.is_enabled() is value_from_config
assert _TestFlags.DISABLED_BY_DEFAULT.is_set() is (value_from_config is not None)

mock_get_config_value.assert_called_once_with(
"cli", "features", key="disabled_default", default=None
)


@mock.patch("snowflake.cli.api.config.get_config_value")
@pytest.mark.parametrize("value_from_config", [True, False, None])
def test_get_value_flag_from_config_file(mock_get_config_value, value_from_config):
mock_get_config_value.return_value = value_from_config

assert _TestFlags.DISABLED_BY_DEFAULT.get_value() == (value_from_config or False)

mock_get_config_value.assert_called_once_with(
"cli", "features", key="disabled_default", default=False
"cli", "features", key="disabled_default", default=None
)


Expand Down
4 changes: 2 additions & 2 deletions tests/nativeapp/test_application_package_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_bundle(project_directory):
@mock.patch(f"{APP_PACKAGE_ENTITY}.execute_post_deploy_hooks")
@mock.patch(f"{APP_PACKAGE_ENTITY}.validate_setup_script")
@mock.patch(f"{APPLICATION_PACKAGE_ENTITY_MODULE}.sync_deploy_root_with_stage")
@mock.patch(SQL_FACADE_GET_UI_PARAMETER, return_value="enabled")
@mock.patch(SQL_FACADE_GET_UI_PARAMETER, return_value="ENABLED")
def test_deploy(
mock_get_parameter,
mock_sync,
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_deploy(
mock_validate.assert_called_once()
mock_execute_post_deploy_hooks.assert_called_once_with()
mock_get_parameter.assert_called_once_with(
UIParameter.NA_FEATURE_RELEASE_CHANNELS, "enabled"
UIParameter.NA_FEATURE_RELEASE_CHANNELS, "ENABLED"
)
assert mock_execute.mock_calls == expected

Expand Down
Loading

0 comments on commit 00ef755

Please sign in to comment.