Skip to content

Commit

Permalink
tests(restapi): add tests for job metrics & mlflowrun id
Browse files Browse the repository at this point in the history
  • Loading branch information
jtsextonMITRE committed Oct 30, 2024
1 parent 7d67b2f commit 2cbbe72
Show file tree
Hide file tree
Showing 6 changed files with 508 additions and 6 deletions.
69 changes: 69 additions & 0 deletions tests/unit/restapi/lib/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,72 @@ def remove_tag(
f"/{V1_ROOT}/{resource_route}/{resource_id}/tags/{tag_id}",
follow_redirects=True,
)


def post_metrics(
    client: FlaskClient, job_id: int, metric_name: str, metric_value: float
) -> TestResponse:
    """Post a single metric value to the job with the provided unique ID.

    Args:
        client: The Flask test client.
        job_id: The id of the Job to post metrics to.
        metric_name: The name of the metric.
        metric_value: The value of the metric.

    Returns:
        The response from the API.
    """
    # Unlike the neighbouring helpers, this request does not pass
    # follow_redirects; the raw response is handed back to the caller as-is.
    return client.post(
        f"/{V1_ROOT}/{V1_JOBS_ROUTE}/{job_id}/metrics",
        json={"name": metric_name, "value": metric_value},
    )


def post_mlflowrun(
    client: FlaskClient, job_id: int, mlflow_run_id: str
) -> TestResponse:
    """Attach an mlflow run id to a job through the REST API.

    Args:
        client: The Flask test client.
        job_id: The id of the Job that receives the run id.
        mlflow_run_id: The id of the mlflow run.

    Returns:
        The response from the API (redirects are followed).
    """
    return client.post(
        f"/{V1_ROOT}/{V1_JOBS_ROUTE}/{job_id}/mlflowRun",
        json={"mlflowRunId": mlflow_run_id},
        follow_redirects=True,
    )


def post_mlflowruns(
    client: FlaskClient, mlflowruns: dict[str, Any], registered_jobs: dict[str, Any]
) -> dict[str, Any]:
    """Attach mlflow run ids to several jobs, one POST per job.

    Args:
        client: The Flask test client.
        mlflowruns: A dictionary mapping job key to mlflow run id. Each value must
            expose a ``hex`` attribute (for example a ``uuid.UUID``), which is the
            string sent to the API.
        registered_jobs: A dictionary of registered jobs, keyed like ``mlflowruns``;
            each entry's ``"id"`` selects the job to update.

    Returns:
        The decoded JSON body of each API response, keyed by job key.
    """
    return {
        job_key: post_mlflowrun(
            client=client,
            job_id=registered_jobs[job_key]["id"],
            mlflow_run_id=run_uuid.hex,
        ).get_json()
        for job_key, run_uuid in mlflowruns.items()
    }
4 changes: 2 additions & 2 deletions tests/unit/restapi/lib/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def assert_group_ref_contents_matches_expectations(
assert group["id"] == expected_group_id


def assert_tag_ref_contents_matches_expectations(tags: dict[str, Any]) -> None:
def assert_tag_ref_contents_matches_expectations(tags: list[dict[str, Any]]) -> None:
for tag in tags:
assert isinstance(tag["id"], int)
assert isinstance(tag["name"], str)
Expand Down Expand Up @@ -325,7 +325,7 @@ def assert_retrieving_snapshots_works(
client: FlaskClient,
resource_route: str,
resource_id: int,
expected: dict[str, Any],
expected: list[dict[str, Any]],
) -> None:
"""Assert that retrieving a queue by id works.
Expand Down
194 changes: 194 additions & 0 deletions tests/unit/restapi/lib/mock_mlflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
from __future__ import annotations

import time
from typing import Any, Optional

import structlog
from structlog.stdlib import BoundLogger

# Module-level structlog logger used by the mock classes in this file.
LOGGER: BoundLogger = structlog.stdlib.get_logger()

# In-memory replacement for the tracking store: maps a run id to the list of
# metrics logged for that run, in the order they were recorded. Being module-level,
# it is shared by every MockMlflowClient instance.
# NOTE(review): nothing in this module resets this mapping, so entries persist
# for the life of the process unless a caller clears it.
active_runs: dict[str, list[MockMlflowMetric]] = {}


class MockMlflowClient(object):
    """In-memory stand-in for ``mlflow.tracking.MlflowClient``.

    Metrics are stored in the module-level ``active_runs`` mapping rather than on
    the instance, so every client instance observes the same runs.
    """

    def __init__(self) -> None:
        LOGGER.info(
            "Mocking mlflow.tracking.MlflowClient instance",
        )

    def get_run(self, id: str) -> MockMlflowRun:
        """Return a run whose ``data.metrics`` holds the latest value per metric.

        Args:
            id: The run id. An unknown id is registered with no metrics.

        Returns:
            A MockMlflowRun mapping each metric name to its most recent value.
        """
        # Note: In Mlflow, this function would usually throw an MlflowException
        # if the run id is not found. For simplicity this is left out in favor of
        # assuming the run should exist in this mock instance.

        LOGGER.info("Mocking MlflowClient.get_run() function")

        run = MockMlflowRun(id)
        latest: dict[str, MockMlflowMetric] = {}
        for metric in active_runs.setdefault(id, []):
            current = latest.get(metric.key)
            # Timestamps only have millisecond resolution, so two values logged
            # back-to-back can share one. The step breaks the tie so the value
            # logged last wins instead of the first one seen.
            if current is None or (metric.timestamp, metric.step) >= (
                current.timestamp,
                current.step,
            ):
                latest[metric.key] = metric

        # remove step and timestamp information
        for name, metric in latest.items():
            run.data.metrics[name] = metric.value
        return run

    def log_metric(
        self, id: str, key: str, value: float, timestamp: Optional[int] = None
    ) -> None:
        """Record a metric value for a run.

        Args:
            id: The run id. An unknown id is registered on first use.
            key: The metric name.
            value: The metric value.
            timestamp: Timestamp to store with the value. Defaults to the current
                time in milliseconds.
        """
        if timestamp is None:
            timestamp = time.time_ns() // 1000000

        history = active_runs.setdefault(id, [])
        history.append(
            MockMlflowMetric(
                key=key,
                value=value,
                step=self.get_step_for(id, key),
                timestamp=timestamp,
            )
        )

    def get_metric_history(self, run_id: str, key: str) -> list[MockMlflowMetric]:
        """Return every recorded value of one metric for a run, in logging order.

        An unknown run id yields an empty list, matching how ``get_run`` treats
        unknown runs.
        """
        return [metric for metric in active_runs.get(run_id, []) if metric.key == key]

    def get_step_for(self, id: str, metric: str) -> int:
        """Return the step to assign to the next value logged for ``metric``.

        This is 0 when the metric has not been logged for the run yet, otherwise
        one more than the highest step recorded so far.
        """
        recorded_steps = (
            entry.step for entry in active_runs.get(id, []) if entry.key == metric
        )
        return max(recorded_steps, default=-1) + 1


class MockMlflowRun:
    """Minimal stand-in for ``mlflow.entities.Run`` exposing ``id`` and ``data``.

    Every property access is logged so tests can trace how the mock is used.
    """

    def __init__(self, id: str) -> None:
        LOGGER.info("Mocking mlflow.entities.Run class")
        # The id is stored directly, while data is assigned through its property
        # setter (and is therefore logged).
        self._id = id
        self.data = MockMlflowRunData()

    @property
    def id(self) -> str:
        """The run id supplied at construction or via the setter."""
        LOGGER.info("Mocking mlflow.entities.Run.id getter")
        return self._id

    @id.setter
    def id(self, value: str) -> None:
        LOGGER.info("Mocking mlflow.entities.Run.id setter", value=value)
        self._id = value

    @property
    def data(self) -> MockMlflowRunData:
        """The container holding this run's metrics."""
        LOGGER.info("Mocking mlflow.entities.Run.data getter")
        return self._data

    @data.setter
    def data(self, value: MockMlflowRunData) -> None:
        LOGGER.info("Mocking mlflow.entities.Run.data setter", value=value)
        self._data = value


class MockMlflowRunData:
    """Minimal stand-in for ``mlflow.entities.RunData`` holding a metrics mapping.

    The mapping starts empty; reads and replacements of it are logged.
    """

    def __init__(self) -> None:
        LOGGER.info("Mocking mlflow.entities.RunData class")
        self._metrics: dict[str, Any] = {}

    @property
    def metrics(self) -> dict[str, Any]:
        """The mutable mapping of metric name to value."""
        LOGGER.info("Mocking mlflow.entities.RunData.metrics getter")
        return self._metrics

    @metrics.setter
    def metrics(self, value: dict[str, Any]) -> None:
        LOGGER.info("Mocking mlflow.entities.RunData.metrics setter", value=value)
        self._metrics = value


class MockMlflowMetric:
    """Minimal stand-in for ``mlflow.entities.Metric``.

    Stores one recorded value of a metric together with its step and timestamp.
    Every property read and write is logged.
    """

    def __init__(self, key: str, value: float, step: int, timestamp: int) -> None:
        LOGGER.info("Mocking mlflow.entities.Metric class")
        # Fields are written directly, bypassing the (logging) property setters.
        self._key = key
        self._value = value
        self._step = step
        self._timestamp = timestamp

    @property
    def key(self) -> str:
        """The metric name."""
        LOGGER.info("Mocking mlflow.entities.Metric.key getter")
        return self._key

    @key.setter
    def key(self, value: str) -> None:
        LOGGER.info("Mocking mlflow.entities.Metric.key setter", value=value)
        self._key = value

    @property
    def value(self) -> float:
        """The recorded metric value."""
        LOGGER.info("Mocking mlflow.entities.Metric.value getter")
        return self._value

    @value.setter
    def value(self, value: float) -> None:
        LOGGER.info("Mocking mlflow.entities.Metric.value setter", value=value)
        self._value = value

    @property
    def step(self) -> int:
        """The step index assigned to this recorded value."""
        LOGGER.info("Mocking mlflow.entities.Metric.step getter")
        return self._step

    @step.setter
    def step(self, value: int) -> None:
        LOGGER.info("Mocking mlflow.entities.Metric.step setter", value=value)
        self._step = value

    @property
    def timestamp(self) -> int:
        """The timestamp stored with this recorded value."""
        LOGGER.info("Mocking mlflow.entities.Metric.timestamp getter")
        return self._timestamp

    @timestamp.setter
    def timestamp(self, value: int) -> None:
        LOGGER.info("Mocking mlflow.entities.Metric.timestamp setter", value=value)
        self._timestamp = value


class MockMlflowException(Exception):
    """Stand-in for ``mlflow.exceptions.MlflowException`` carrying a message."""

    def __init__(self, text: str) -> None:
        # Construction is logged before the message is handed to Exception.
        LOGGER.info("Mocking mlflow.exceptions.MlflowException class")
        super().__init__(text)
2 changes: 0 additions & 2 deletions tests/unit/restapi/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# https://creativecommons.org/licenses/by/4.0/legalcode
from __future__ import annotations

import pytest

from dioptra.restapi.utils import find_non_unique


Expand Down
42 changes: 41 additions & 1 deletion tests/unit/restapi/v1/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""Fixtures representing resources needed for test suites"""
import textwrap
from collections.abc import Iterator
from pathlib import Path
from typing import Any, cast

import pytest
Expand Down Expand Up @@ -684,3 +683,44 @@ def registered_jobs(
"job2": job2_response,
"job3": job3_response,
}


@pytest.fixture
def registered_mlflowrun(
    client: FlaskClient,
    db: SQLAlchemy,
    auth_account: dict[str, Any],
    registered_jobs: dict[str, Any],
) -> dict[str, Any]:
    """Attach a freshly generated mlflow run id to job1, job2 and job3.

    ``db`` and ``auth_account`` are not referenced in the body; presumably they are
    requested so the database and an authenticated account exist before the
    requests are sent (TODO confirm).

    Returns:
        The API responses from ``actions.post_mlflowruns``, keyed by job key.
    """
    # Function-scope import, kept local as in the original fixture.
    # NOTE(review): the earlier comment cited a circular import, which does not
    # apply to the stdlib uuid module; this could move to the top of the file.
    import uuid

    run_ids = {job_key: uuid.uuid4() for job_key in ("job1", "job2", "job3")}
    return actions.post_mlflowruns(
        client=client, mlflowruns=run_ids, registered_jobs=registered_jobs
    )


@pytest.fixture
def registered_mlflowrun_incomplete(
    client: FlaskClient,
    db: SQLAlchemy,
    auth_account: dict[str, Any],
    registered_jobs: dict[str, Any],
) -> dict[str, Any]:
    """Attach a freshly generated mlflow run id to job1 and job2 only.

    ``db`` and ``auth_account`` are not referenced in the body; presumably they are
    requested so the database and an authenticated account exist before the
    requests are sent (TODO confirm).

    Returns:
        The API responses from ``actions.post_mlflowruns``, keyed by job key.
    """
    # Function-scope import, kept local as in the original fixture.
    # NOTE(review): the earlier comment cited a circular import, which does not
    # apply to the stdlib uuid module; this could move to the top of the file.
    import uuid

    # leave job3 out so we can use that in test_mlflowrun()
    run_ids = {job_key: uuid.uuid4() for job_key in ("job1", "job2")}
    return actions.post_mlflowruns(
        client=client, mlflowruns=run_ids, registered_jobs=registered_jobs
    )
Loading

0 comments on commit 2cbbe72

Please sign in to comment.