Skip to content

Commit

Permalink
Enabling a User-defined Config object as input (#20)
Browse files Browse the repository at this point in the history
* api calls now can be configured with strings
* updating whylabs_client
* test validate with group columns
* enforcing group columns should only contain one element
* putting guardrail for group_column typo
* updating the schema.json file
  • Loading branch information
murilommen authored Mar 30, 2023
1 parent 2177314 commit 599d983
Show file tree
Hide file tree
Showing 21 changed files with 561 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.0.4
current_version = 0.0.5
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
serialize =
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ jobs:
WHYLABS_API_KEY : ${{ secrets.WHYLABS_API_KEY }}
MONITOR_ID : ${{ secrets.MONITOR_ID }}
ANALYZER_ID : ${{ secrets.ANALYZER_ID }}

DEV_WHYLABS_API_KEY: ${{secrets.DEV_WHYLABS_API_KEY}}
DEV_ORG_ID: ${{secrets.DEV_ORG_ID}}
DEV_DATASET_ID: ${{secrets.DEV_DATASET_ID}}
150 changes: 78 additions & 72 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "whylabs-toolkit"
version = "0.0.4"
version = "0.0.5"
description = "Whylabs CLI and Helpers package."
authors = ["Anthony Naddeo <[email protected]>", "Murilo Mendonca <[email protected]>"]
license = "Apache-2.0 license"
Expand All @@ -10,7 +10,7 @@ include = ["whylabs_toolkit/monitor/schema/schema.json"]

[tool.poetry.dependencies]
python = "^3.8"
whylabs-client = "^0.4.2"
whylabs-client = "^0.4.4"
types-pytz = "^2022.7.1.0"
pydantic = "^1.10.4"
whylogs = "^1.1.26"
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from whylabs_toolkit.monitor.manager import MonitorSetup
from whylabs_toolkit.monitor.models import *
from whylabs_toolkit.helpers.config import UserConfig


@pytest.fixture
Expand All @@ -23,3 +24,13 @@ def existing_monitor_setup() -> MonitorSetup:
monitor_id=os.environ["MONITOR_ID"]
)
return monitor_setup

@pytest.fixture
def user_config() -> UserConfig:
    """Build a ``UserConfig`` from the ``DEV_*`` environment variables.

    The API key, org id and dataset id are read from the environment (a
    missing variable raises ``KeyError``); the host URL is hard-coded.
    """
    env = os.environ
    return UserConfig(
        api_key=env["DEV_WHYLABS_API_KEY"],
        org_id=env["DEV_ORG_ID"],
        dataset_id=env["DEV_DATASET_ID"],
        whylabs_host="https://songbird.development.whylabsdev.com/",
    )
4 changes: 2 additions & 2 deletions tests/helpers/test_entity_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_change_columns_schema():

update_data_types.update()

assert update_data_types.current_entity_schema["columns"]["temperature"]["dataType"] == "bool"
assert update_data_types.current_entity_schema["columns"]["temperature"]["data_type"] == "bool"

columns_schema = {"temperature": ColumnDataType.fractional}

Expand All @@ -130,7 +130,7 @@ def test_change_columns_schema():

update_data_types.update()

assert update_data_types.current_entity_schema["columns"]["temperature"]["dataType"] == "fractional"
assert update_data_types.current_entity_schema["columns"]["temperature"]["data_type"] == "fractional"


def test_wrong_configuration_on_data_types():
Expand Down
8 changes: 4 additions & 4 deletions tests/helpers/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ def test_update_model_time_period(models_api: ModelsApi) -> None:
update_model_metadata(dataset_id=DATASET_ID, org_id=ORG_ID, time_period="P1D")
model_meta = models_api.get_model(model_id=DATASET_ID, org_id=ORG_ID)

assert model_meta["time_period"].value == "P1D"
assert model_meta["time_period"] == "P1D"

update_model_metadata(dataset_id=DATASET_ID, org_id=ORG_ID, time_period="P1M")
model_meta = models_api.get_model(model_id=DATASET_ID, org_id=ORG_ID)

assert model_meta["time_period"].value == "P1M"
assert model_meta["time_period"] == "P1M"


def test_update_model_type(models_api: ModelsApi) -> None:
update_model_metadata(dataset_id=DATASET_ID, org_id=ORG_ID, model_type="REGRESSION")
model_meta = models_api.get_model(model_id=DATASET_ID, org_id=ORG_ID)

assert model_meta["model_type"].value == "REGRESSION"
assert model_meta["model_type"] == "REGRESSION"

update_model_metadata(dataset_id=DATASET_ID, org_id=ORG_ID, model_type="CLASSIFICATION")
model_meta = models_api.get_model(model_id=DATASET_ID, org_id=ORG_ID)

assert model_meta["model_type"].value == "CLASSIFICATION"
assert model_meta["model_type"] == "CLASSIFICATION"
4 changes: 2 additions & 2 deletions tests/helpers/test_monitor_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
get_model_granularity,
get_monitor_config
)
from whylabs_toolkit.helpers.utils import get_models_api
from whylabs_toolkit.helpers.utils import get_monitor_api
from whylabs_toolkit.monitor.models import Granularity


Expand Down Expand Up @@ -49,7 +49,7 @@
class BaseTestMonitor:
@classmethod
def setup_class(cls) -> None:
api = get_models_api()
api = get_monitor_api()
api.put_monitor(
org_id=ORG_ID,
dataset_id=DATASET_ID,
Expand Down
1 change: 1 addition & 0 deletions tests/helpers/test_smoke.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
def test_import() -> None:
    """Smoke test: the imports below must succeed without raising."""
    # The imported names are intentionally unused; only importability matters.
    import whylabs_toolkit.helpers.client
    from whylabs_toolkit.monitor import MonitorManager, MonitorSetup
13 changes: 13 additions & 0 deletions tests/helpers/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from whylabs_toolkit.helpers.config import UserConfig
from whylabs_toolkit.helpers.utils import get_dataset_profile_api, get_models_api, get_notification_api


def test_get_apis_with_different_config(user_config: UserConfig) -> None:
    """Each API factory must pick up the API key from the passed-in config."""
    # Same factories, same order as before: dataset profile, models, notifications.
    api_factories = (get_dataset_profile_api, get_models_api, get_notification_api)
    for build_api in api_factories:
        api = build_api(config=user_config)
        configured_key = api.api_client.configuration.api_key["ApiKeyAuth"]
        assert configured_key == user_config.api_key
12 changes: 8 additions & 4 deletions tests/monitor/manager/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_dump(self, manager: MonitorManager) -> None:
def test_validate(self, manager: MonitorManager) -> None:
assert manager.validate()

def test_failing_validation(self, monitor_setup) -> None:
def test_failing_validation(self, monitor_setup: MonitorSetup) -> None:
monitor_setup.actions = [EmailRecipient(id="some_long_id", destination="[email protected]")]
monitor_setup.config.mode = "weird_mode"
monitor_setup.config.mode = "weird_mode" # type: ignore
monitor_setup.apply()

manager = MonitorManager(setup=monitor_setup)
Expand Down Expand Up @@ -70,9 +70,13 @@ def setUp(self) -> None:
self.notifications_api = MagicMock()
self.notifications_api.list_notification_actions.return_value = []

self.models_api = MagicMock()
self.monitor_api = MagicMock()

self.monitor_manager = MonitorManager(setup = self.monitor_setup, notifications_api=self.notifications_api, models_api=self.models_api)
self.monitor_manager = MonitorManager(
setup = self.monitor_setup,
notifications_api=self.notifications_api,
monitor_api=self.monitor_api
)


def test_notification_actions_are_updated(self) -> None:
Expand Down
28 changes: 24 additions & 4 deletions tests/monitor/manager/test_monitor_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from whylabs_toolkit.monitor.models import *
from tests.helpers.test_monitor_helpers import BaseTestMonitor
from whylabs_toolkit.monitor.manager.credentials import MonitorCredentials
from whylabs_toolkit.monitor import MonitorSetup
from whylabs_toolkit.helpers.config import UserConfig

def test_set_fixed_dates_baseline(monitor_setup):

def test_set_fixed_dates_baseline(monitor_setup: MonitorSetup) -> None:
monitor_setup.set_fixed_dates_baseline(
start_date=datetime(2023,1,1),
end_date=datetime(2023,1,2)
Expand Down Expand Up @@ -77,10 +80,10 @@ def test_set_and_exclude_columns_keep_state(monitor_setup):


class TestExistingMonitor(BaseTestMonitor):
def test_existing_monitor_monitor_setup_with_id(self, existing_monitor_setup):
def test_existing_monitor_monitor_setup_with_id(self, existing_monitor_setup) -> None:
assert isinstance(existing_monitor_setup.config, StddevConfig)

def test_create_monitor_from_existing_monitor_id(self, existing_monitor_setup):
def test_create_monitor_from_existing_monitor_id(self, existing_monitor_setup) -> None:
assert existing_monitor_setup.monitor.id == os.environ["MONITOR_ID"]

new_credentials = MonitorCredentials(monitor_id="new_monitor_id")
Expand All @@ -91,7 +94,7 @@ def test_create_monitor_from_existing_monitor_id(self, existing_monitor_setup):
assert existing_monitor_setup.monitor.id == "new_monitor_id"
assert existing_monitor_setup.analyzer.id == "new_monitor_id-analyzer"

def test_validate_if_columns_exist_before_setting(existing_monitor_setup):
def test_validate_if_columns_exist_before_setting(existing_monitor_setup: MonitorSetup) -> None:
with pytest.raises(ValueError) as e:
existing_monitor_setup.exclude_target_columns(columns=["test_exclude_column"])
assert e.value == f"test_exclude_column is not present on {existing_monitor_setup.credentials.dataset_id}"
Expand All @@ -101,3 +104,20 @@ def test_validate_if_columns_exist_before_setting(existing_monitor_setup):
assert e.value == f"test_set_column is not present on {existing_monitor_setup.credentials.dataset_id}"


def test_setup_with_passed_in_credentials(user_config: UserConfig) -> None:
    """A ``MonitorSetup`` given an explicit config reports that config's org id."""
    setup = MonitorSetup(monitor_id="different_id", config=user_config)

    assert setup.credentials.org_id == user_config.org_id


def test_setup_with_group_of_columns(monitor_setup) -> None:
    """Set and exclude ``group:`` column entries, then apply the setup.

    There is no explicit assertion: the test passes when none of the calls raise.
    """
    included = ["group:discrete"]
    excluded = ["group:output", "other_feature"]

    monitor_setup.set_target_columns(columns=included)
    monitor_setup.exclude_target_columns(columns=excluded)
    monitor_setup.apply()

def test_setup_with_wrong_group_column_type(monitor_setup) -> None:
    """``set_target_columns`` must raise ``ValueError`` for ``group:inputs``."""
    invalid_columns = ["group:inputs"]
    with pytest.raises(ValueError):
        monitor_setup.set_target_columns(columns=invalid_columns)
32 changes: 24 additions & 8 deletions whylabs_toolkit/helpers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,39 @@ class ConfigVars(Enum):


class Config:
@staticmethod
def get_whylabs_api_key() -> str:
def get_whylabs_api_key(self) -> str:
return Validations.require(ConfigVars.WHYLABS_API_KEY)

@staticmethod
def get_whylabs_host() -> str:
def get_whylabs_host(self) -> str:
return Validations.get_or_default(ConfigVars.WHYLABS_HOST)

@staticmethod
def get_default_org_id() -> str:
def get_default_org_id(self) -> str:
return Validations.require(ConfigVars.ORG_ID)

@staticmethod
def get_default_dataset_id() -> str:
def get_default_dataset_id(self) -> str:
return Validations.require(ConfigVars.DATASET_ID)


class UserConfig(Config):
    """A ``Config`` whose values are supplied explicitly by the caller.

    Every getter inherited from ``Config`` is overridden to return the value
    stored at construction time instead of going through ``Validations``.
    """

    def __init__(self, api_key: str, org_id: str, dataset_id: str, whylabs_host: str = ConfigVars.WHYLABS_HOST.value):
        """Store the credentials and host.

        Args:
            api_key: value returned by ``get_whylabs_api_key``.
            org_id: value returned by ``get_default_org_id``.
            dataset_id: value returned by ``get_default_dataset_id``.
            whylabs_host: value returned by ``get_whylabs_host``; defaults to
                ``ConfigVars.WHYLABS_HOST.value``.
        """
        self.api_key, self.whylabs_host = api_key, whylabs_host
        self.org_id, self.dataset_id = org_id, dataset_id

    def get_whylabs_api_key(self) -> str:
        """Return the API key given at construction."""
        return self.api_key

    def get_whylabs_host(self) -> str:
        """Return the host given at construction."""
        return self.whylabs_host

    def get_default_org_id(self) -> str:
        """Return the org id given at construction."""
        return self.org_id

    def get_default_dataset_id(self) -> str:
        """Return the dataset id given at construction."""
        return self.dataset_id


class Validations:
@staticmethod
def require(env: ConfigVars) -> str:
Expand Down
9 changes: 7 additions & 2 deletions whylabs_toolkit/helpers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
from whylabs_client.model.time_period import TimePeriod

from whylabs_toolkit.helpers.utils import get_models_api
from whylabs_toolkit.helpers.config import Config

logger = logging.getLogger(__name__)


def update_model_metadata(
dataset_id: str, org_id: Optional[str] = None, time_period: Optional[str] = None, model_type: Optional[str] = None
dataset_id: str,
org_id: Optional[str] = None,
time_period: Optional[str] = None,
model_type: Optional[str] = None,
config: Config = Config(),
) -> None:
"""
Update model attributes like model type and period.
"""
api = get_models_api()
api = get_models_api(config=config)

model_metadata = api.get_model(org_id=org_id, model_id=dataset_id)
logger.debug(f"Updating dataset with current metadata: \n {model_metadata}")
Expand Down
42 changes: 22 additions & 20 deletions whylabs_toolkit/helpers/monitor_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from whylabs_client.exceptions import NotFoundException

from whylabs_toolkit.helpers.config import Config
from whylabs_toolkit.helpers.utils import get_models_api
from whylabs_toolkit.helpers.utils import get_monitor_api, get_models_api
from whylabs_toolkit.monitor.models import Granularity

BASE_ENDPOINT = "https://api.whylabsapp.com"
Expand All @@ -15,37 +15,39 @@
# TODO create deactivate_monitor


def get_monitor_config(org_id: str, dataset_id: str) -> Any:
api = get_models_api()
def get_monitor_config(org_id: str, dataset_id: str, config: Config = Config()) -> Any:
api = get_monitor_api(config=config)
monitor_config = api.get_monitor_config_v3(org_id=org_id, dataset_id=dataset_id)
return monitor_config


def get_monitor(monitor_id: str, org_id: Optional[str], dataset_id: Optional[str]) -> Any:
def get_monitor(monitor_id: str, org_id: Optional[str], dataset_id: Optional[str], config: Config = Config()) -> Any:
if not org_id:
org_id = Config().get_default_org_id()
org_id = config.get_default_org_id()
if not dataset_id:
dataset_id = Config().get_default_dataset_id()
api = get_models_api()
dataset_id = config.get_default_dataset_id()
api = get_monitor_api(config=config)
return api.get_monitor(org_id=org_id, dataset_id=dataset_id, monitor_id=monitor_id)


def get_analyzer_ids(org_id: str, dataset_id: str, monitor_id: str) -> Any:
monitor_config = get_monitor_config(org_id=org_id, dataset_id=dataset_id)
def get_analyzer_ids(org_id: str, dataset_id: str, monitor_id: str, config: Config = Config()) -> Any:
monitor_config = get_monitor_config(org_id=org_id, dataset_id=dataset_id, config=config)
for item in monitor_config["monitors"]:
if item["id"] == monitor_id:
resp = item["analyzerIds"]
return resp


def get_analyzers(monitor_id: str, org_id: Optional[str], dataset_id: Optional[str]) -> List[Any]:
def get_analyzers(
monitor_id: str, org_id: Optional[str], dataset_id: Optional[str], config: Config = Config()
) -> List[Any]:
if not org_id:
org_id = Config().get_default_org_id()
org_id = config.get_default_org_id()
if not dataset_id:
dataset_id = Config().get_default_dataset_id()
api = get_models_api()
dataset_id = config.get_default_dataset_id()
api = get_monitor_api(config=config)
analyzers = []
analyzer_ids = get_analyzer_ids(org_id=org_id, dataset_id=dataset_id, monitor_id=monitor_id)
analyzer_ids = get_analyzer_ids(org_id=org_id, dataset_id=dataset_id, monitor_id=monitor_id, config=config)
if analyzer_ids:
for analyzer in analyzer_ids:
analyzers.append(api.get_analyzer(org_id=org_id, dataset_id=dataset_id, analyzer_id=analyzer))
Expand All @@ -54,8 +56,8 @@ def get_analyzers(monitor_id: str, org_id: Optional[str], dataset_id: Optional[s
raise NotFoundException


def get_model_granularity(org_id: str, dataset_id: str) -> Optional[Granularity]:
api = get_models_api()
def get_model_granularity(org_id: str, dataset_id: str, config: Config = Config()) -> Optional[Granularity]:
api = get_models_api(config=config)
model_meta = api.get_model(org_id=org_id, model_id=dataset_id)

time_period_to_gran = {
Expand All @@ -66,15 +68,15 @@ def get_model_granularity(org_id: str, dataset_id: str) -> Optional[Granularity]
}

for key, value in time_period_to_gran.items():
if key in model_meta["time_period"].value:
if key in model_meta["time_period"]:
return value
return None


def delete_monitor(org_id: str, dataset_id: str, monitor_id: str) -> None:
api = get_models_api()
def delete_monitor(org_id: str, dataset_id: str, monitor_id: str, config: Config = Config()) -> None:
api = get_monitor_api(config=config)
try:
analyzer_ids = get_analyzer_ids(org_id=org_id, dataset_id=dataset_id, monitor_id=monitor_id)
analyzer_ids = get_analyzer_ids(org_id=org_id, dataset_id=dataset_id, monitor_id=monitor_id, config=config)
if analyzer_ids is None:
return
for analyzer_id in analyzer_ids:
Expand Down
Loading

0 comments on commit 599d983

Please sign in to comment.