Skip to content

Commit

Permalink
Inject an internal _emit_metric method into predictor classes
Browse files Browse the repository at this point in the history
We use metrics internally to transport, out of the model, metadata about
how the model ran. It's private because it's handled entirely invisibly
within our system, and we make no claims of stability for how it works.
  • Loading branch information
erbridge committed Nov 6, 2024
1 parent 4de7f61 commit ed0e233
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 13 deletions.
10 changes: 9 additions & 1 deletion python/cog/base_predictor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

from .types import (
File as CogFile,
Expand All @@ -10,6 +10,14 @@


class BasePredictor(ABC):
    def __init__(
        self,
        *,
        _emit_metric: Callable[[str, Union[float, int]], None],
    ) -> None:
        # For internal use only: stores the callback used to ship runtime
        # metadata (metrics) out of the model. Underscore-prefixed and
        # keyword-only because this is not a stable, public interface.
        self._emit_metric = _emit_metric

def setup(
self,
weights: Optional[Union[CogFile, CogPath, str]] = None, # pylint: disable=unused-argument
Expand Down
27 changes: 23 additions & 4 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,38 @@ def load_slim_predictor_from_file(
return module


def get_predictor(
    module: types.ModuleType,
    class_name: str,
    *,
    emit_metric: Optional[Callable[[str, Union[float, int]], None]] = None,
) -> Any:
    """Look up `class_name` in `module` and return a ready-to-use predictor.

    The attribute may be a class or a plain function. Functions are returned
    as-is; classes are instantiated, passing the internal `_emit_metric`
    callback to constructors that accept it.
    """
    predictor = getattr(module, class_name)
    # It could be a class or a function
    if not inspect.isclass(predictor):
        return predictor

    if emit_metric is None:

        def noop_emit_metric(name: str, value: Union[float, int]) -> None:  # pylint: disable=unused-argument
            pass

        emit_metric = noop_emit_metric

    # Probe the constructor signature rather than wrapping the call in
    # `except TypeError`: a bare except around `predictor(...)` would also
    # swallow TypeErrors raised *inside* the predictor's own __init__ and
    # silently retry without metrics support, masking real bugs.
    try:
        inspect.signature(predictor).bind(_emit_metric=emit_metric)
    except TypeError:
        return predictor()
    return predictor(_emit_metric=emit_metric)


def load_predictor_from_ref(
    ref: str,
    *,
    emit_metric: Callable[[str, Union[float, int]], None],
) -> BasePredictor:
    """Load the predictor named by `ref` ("path/to/module.py:ClassName").

    The internal `emit_metric` callback is threaded through so predictors
    whose constructors accept it can emit runtime metrics.
    """
    module_path, class_name = ref.split(":", 1)
    stem = os.path.basename(module_path).split(".py", 1)[0]
    loaded = load_full_predictor_from_file(module_path, stem)
    return get_predictor(loaded, class_name, emit_metric=emit_metric)


Expand Down
8 changes: 7 additions & 1 deletion python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Union

from attrs import define, field, validators

Expand Down Expand Up @@ -29,6 +29,12 @@ class Log:
source: str = field(validator=validators.in_(["stdout", "stderr"]))


@define
class PredictionMetric:
    """Event carrying a single named metric emitted by the predictor while
    running a prediction (internal, non-stable transport of run metadata)."""

    # Metric identifier, e.g. "predict_time".
    name: str
    # Numeric metric value.
    value: Union[float, int]


@define
class PredictionOutput:
payload: Any
Expand Down
25 changes: 20 additions & 5 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import Future
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar, Union

import requests
import structlog
Expand All @@ -17,10 +17,17 @@
from ..json import upload_files
from ..types import PYDANTIC_V2
from .errors import FileUploadError, RunnerBusyError, UnknownPredictionError
from .eventtypes import Done, Log, PredictionOutput, PredictionOutputType
from .eventtypes import (
Done,
Log,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
)

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators

from .telemetry import current_trace_context
from .useragent import get_user_agent
from .webhook import SKIP_START_EVENT, webhook_caller_filtered
Expand Down Expand Up @@ -348,6 +355,11 @@ def append_logs(self, logs: str) -> None:
self._p.logs += logs
self._send_webhook(schema.WebhookEvent.LOGS)

def set_metric(self, key: str, value: Union[float, int]) -> None:
if self._p.metrics is None:
self._p.metrics = {}
self._p.metrics[key] = value

def succeeded(self) -> None:
self._log.info("prediction succeeded")
self._p.status = schema.Status.SUCCEEDED
Expand All @@ -356,9 +368,10 @@ def succeeded(self) -> None:
# that...
assert self._p.completed_at is not None
assert self._p.started_at is not None
self._p.metrics = {
"predict_time": (self._p.completed_at - self._p.started_at).total_seconds()
}
self.set_metric(
"predict_time",
(self._p.completed_at - self._p.started_at).total_seconds(),
)
self._send_webhook(schema.WebhookEvent.COMPLETED)

def failed(self, error: str) -> None:
Expand All @@ -378,6 +391,8 @@ def handle_event(self, event: _PublicEventType) -> None:
try:
if isinstance(event, Log):
self.append_logs(event.message)
elif isinstance(event, PredictionMetric):
self.set_metric(event.name, event.value)
elif isinstance(event, PredictionOutputType):
self.set_output_type(multi=event.multi)
elif isinstance(event, PredictionOutput):
Expand Down
9 changes: 8 additions & 1 deletion python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Done,
Log,
PredictionInput,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
Shutdown,
Expand Down Expand Up @@ -327,7 +328,10 @@ def send_cancel(self) -> None:
def _setup(self, redirector: AsyncStreamRedirector) -> None:
done = Done()
try:
self._predictor = load_predictor_from_ref(self._predictor_ref)
self._predictor = load_predictor_from_ref(
self._predictor_ref,
emit_metric=self._emit_metric,
)
# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)
Expand Down Expand Up @@ -511,6 +515,9 @@ def _stream_write_hook(self, stream_name: str, data: str) -> None:
else:
self._events.send(Log(data, source="stderr"))

def _emit_metric(self, name: str, value: Union[float, int]) -> None:
self._events.send(PredictionMetric(name, value))


def make_worker(predictor_ref: str, tee_output: bool = True) -> Worker:
parent_conn, child_conn = _spawn.Pipe()
Expand Down
18 changes: 18 additions & 0 deletions python/tests/server/fixtures/emit_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any, Optional, Union

from cog.base_predictor import BasePredictor
from cog.types import File as CogFile
from cog.types import Path as CogPath


class Predictor(BasePredictor):
    """Test fixture: a predictor that emits one internal metric per predict."""

    def setup(self, weights: Optional[Union[CogFile, CogPath, str]] = None):
        print("did setup")

    def predict(self, *, name: str, **kwargs: Any) -> str:  # pylint: disable=arguments-differ
        greeting = f"hello, {name}"
        print(greeting)

        # The base class injected the internal callback; prove it and use it.
        assert self._emit_metric
        self._emit_metric("foo", 123)

        return greeting
40 changes: 39 additions & 1 deletion python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
rule,
)

from cog.server.eventtypes import Done, Log, PredictionOutput, PredictionOutputType
from cog.server.eventtypes import (
Done,
Log,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
)
from cog.server.exceptions import FatalWorkerException, InvalidStateException
from cog.server.worker import Worker, _PublicEventType

Expand Down Expand Up @@ -57,6 +63,16 @@
"missing_predict",
]

# Triples of (worker config, hypothesis strategies for the predict payload,
# expected metrics dict) consumed by test_metrics via parametrize.
METRICS_FIXTURES = [
    (
        WorkerConfig("emit_metric"),
        {"name": ST_NAMES},
        {
            "foo": 123,
        },
    ),
]

OUTPUT_FIXTURES = [
(
WorkerConfig("hello_world"),
Expand Down Expand Up @@ -112,6 +128,7 @@ class Result:
stdout_lines: List[str] = field(factory=list)
stderr_lines: List[str] = field(factory=list)
heartbeat_count: int = 0
metrics: Optional[Dict[str, Any]] = None
output_type: Optional[PredictionOutputType] = None
output: Any = None
done: Optional[Done] = None
Expand All @@ -133,6 +150,10 @@ def handle_event(self, event: _PublicEventType):
elif isinstance(event, Done):
assert not self.done
self.done = event
elif isinstance(event, PredictionMetric):
if self.metrics is None:
self.metrics = {}
self.metrics[event.name] = event.value
elif isinstance(event, PredictionOutput):
assert self.output_type, "Should get output type before any output"
if self.output_type.multi:
Expand Down Expand Up @@ -215,6 +236,23 @@ def test_stream_redirector_race_condition(worker):
assert not result.done.error


@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT)
@pytest.mark.parametrize(
    "worker,payloads,expected_metrics", METRICS_FIXTURES, indirect=["worker"]
)
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(data=st.data())
def test_metrics(worker, payloads, expected_metrics, data):
    """A predictor that emits metrics should surface exactly those metrics
    on the prediction result."""
    drawn_payload = data.draw(st.fixed_dictionaries(payloads))
    result = _process(worker, lambda: worker.predict(drawn_payload))
    assert result.metrics == expected_metrics


@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT)
@pytest.mark.parametrize(
"worker,payloads,output_generator", OUTPUT_FIXTURES, indirect=["worker"]
Expand Down

0 comments on commit ed0e233

Please sign in to comment.