Skip to content

Commit

Permalink
Use a scope contextvar to to record metrics during a prediction
Browse files Browse the repository at this point in the history
We use metrics internally to transport metadata about how the model ran
out of the model. It's private because it's entirely invisible within
our system and we make no claims of stability for how it works.
  • Loading branch information
erbridge committed Nov 7, 2024
1 parent 4de7f61 commit 6cb53a0
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 10 deletions.
12 changes: 11 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from pydantic import BaseModel

from .base_predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path, Secret
from .server.scope import current_scope
from .types import (
ConcatenateIterator,
ExperimentalFeatureWarning,
File,
Input,
Path,
Secret,
)

try:
from ._version import __version__
Expand All @@ -14,6 +22,8 @@
"BaseModel",
"BasePredictor",
"ConcatenateIterator",
"current_scope",
"ExperimentalFeatureWarning",
"File",
"Input",
"Path",
Expand Down
8 changes: 7 additions & 1 deletion python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Union

from attrs import define, field, validators

Expand Down Expand Up @@ -29,6 +29,12 @@ class Log:
source: str = field(validator=validators.in_(["stdout", "stderr"]))


@define
class PredictionMetric:
name: str
value: Union[float, int]


@define
class PredictionOutput:
payload: Any
Expand Down
24 changes: 19 additions & 5 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from concurrent.futures import Future
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, TypeVar, Union

import requests
import structlog
Expand All @@ -17,7 +17,13 @@
from ..json import upload_files
from ..types import PYDANTIC_V2
from .errors import FileUploadError, RunnerBusyError, UnknownPredictionError
from .eventtypes import Done, Log, PredictionOutput, PredictionOutputType
from .eventtypes import (
Done,
Log,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
)

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators
Expand Down Expand Up @@ -348,6 +354,11 @@ def append_logs(self, logs: str) -> None:
self._p.logs += logs
self._send_webhook(schema.WebhookEvent.LOGS)

def set_metric(self, key: str, value: Union[float, int]) -> None:
if self._p.metrics is None:
self._p.metrics = {}
self._p.metrics[key] = value

def succeeded(self) -> None:
self._log.info("prediction succeeded")
self._p.status = schema.Status.SUCCEEDED
Expand All @@ -356,9 +367,10 @@ def succeeded(self) -> None:
# that...
assert self._p.completed_at is not None
assert self._p.started_at is not None
self._p.metrics = {
"predict_time": (self._p.completed_at - self._p.started_at).total_seconds()
}
self.set_metric(
"predict_time",
(self._p.completed_at - self._p.started_at).total_seconds(),
)
self._send_webhook(schema.WebhookEvent.COMPLETED)

def failed(self, error: str) -> None:
Expand All @@ -378,6 +390,8 @@ def handle_event(self, event: _PublicEventType) -> None:
try:
if isinstance(event, Log):
self.append_logs(event.message)
elif isinstance(event, PredictionMetric):
self.set_metric(event.name, event.value)
elif isinstance(event, PredictionOutputType):
self.set_output_type(multi=event.multi)
elif isinstance(event, PredictionOutput):
Expand Down
39 changes: 39 additions & 0 deletions python/cog/server/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import warnings
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Callable, Generator, Optional, Union

from ..types import ExperimentalFeatureWarning


class Scope:
def __init__(
self,
*,
record_metric: Callable[[str, Union[float, int]], None],
) -> None:
self.record_metric = record_metric


_current_scope: ContextVar[Optional[Scope]] = ContextVar("scope", default=None)


def current_scope() -> Scope:
warnings.warn(
"current_scope is an experimental internal function. It may change or be removed without warning.",
category=ExperimentalFeatureWarning,
stacklevel=1,
)
s = _current_scope.get()
if s is None:
raise RuntimeError("No scope available")
return s


@contextmanager
def scope(sc: Scope) -> Generator[None, None, None]:
s = _current_scope.set(sc)
try:
yield
finally:
_current_scope.reset(s)
12 changes: 10 additions & 2 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Done,
Log,
PredictionInput,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
Shutdown,
Expand All @@ -36,6 +37,7 @@
InvalidStateException,
)
from .helpers import AsyncStreamRedirector, StreamRedirector
from .scope import Scope, scope

if PYDANTIC_V2:
from .helpers import unwrap_pydantic_serialization_iterators
Expand Down Expand Up @@ -324,6 +326,9 @@ def send_cancel(self) -> None:
if self.is_alive() and self.pid:
os.kill(self.pid, signal.SIGUSR1)

def record_metric(self, name: str, value: Union[float, int]) -> None:
self._events.send(PredictionMetric(name, value))

def _setup(self, redirector: AsyncStreamRedirector) -> None:
done = Done()
try:
Expand Down Expand Up @@ -360,7 +365,7 @@ def _loop(
predict: Callable[..., Any],
redirector: StreamRedirector,
) -> None:
with redirector:
with scope(self._loop_scope()), redirector:
while True:
ev = self._events.recv()
if isinstance(ev, Cancel):
Expand All @@ -383,7 +388,7 @@ async def _aloop(

task = None

with redirector:
with scope(self._loop_scope()), redirector:
while True:
ev = await self._events.recv()
if isinstance(ev, Cancel) and task and self._cancelable:
Expand Down Expand Up @@ -448,6 +453,9 @@ async def _apredict(
self._events.send(PredictionOutputType(multi=False))
self._events.send(PredictionOutput(payload=make_encodeable(output)))

def _loop_scope(self) -> Scope:
return Scope(record_metric=self.record_metric)

@contextlib.contextmanager
def _handle_predict_error(
self, redirector: Union[AsyncStreamRedirector, StreamRedirector]
Expand Down
4 changes: 4 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
FILENAME_MAX_LENGTH = 200


class ExperimentalFeatureWarning(Warning):
pass


class CogConfig(TypedDict): # pylint: disable=too-many-ancestors
build: "CogBuildConfig"
image: NotRequired[str]
Expand Down
13 changes: 13 additions & 0 deletions python/tests/server/fixtures/record_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from cog import current_scope


class Predictor:
def setup(self):
print("did setup")

def predict(self, *, name: str) -> str:
print(f"hello, {name}")

current_scope().record_metric("foo", 123)

return f"hello, {name}"
40 changes: 39 additions & 1 deletion python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
rule,
)

from cog.server.eventtypes import Done, Log, PredictionOutput, PredictionOutputType
from cog.server.eventtypes import (
Done,
Log,
PredictionMetric,
PredictionOutput,
PredictionOutputType,
)
from cog.server.exceptions import FatalWorkerException, InvalidStateException
from cog.server.worker import Worker, _PublicEventType

Expand Down Expand Up @@ -57,6 +63,16 @@
"missing_predict",
]

METRICS_FIXTURES = [
(
WorkerConfig("record_metric"),
{"name": ST_NAMES},
{
"foo": 123,
},
),
]

OUTPUT_FIXTURES = [
(
WorkerConfig("hello_world"),
Expand Down Expand Up @@ -112,6 +128,7 @@ class Result:
stdout_lines: List[str] = field(factory=list)
stderr_lines: List[str] = field(factory=list)
heartbeat_count: int = 0
metrics: Optional[Dict[str, Any]] = None
output_type: Optional[PredictionOutputType] = None
output: Any = None
done: Optional[Done] = None
Expand All @@ -133,6 +150,10 @@ def handle_event(self, event: _PublicEventType):
elif isinstance(event, Done):
assert not self.done
self.done = event
elif isinstance(event, PredictionMetric):
if self.metrics is None:
self.metrics = {}
self.metrics[event.name] = event.value
elif isinstance(event, PredictionOutput):
assert self.output_type, "Should get output type before any output"
if self.output_type.multi:
Expand Down Expand Up @@ -215,6 +236,23 @@ def test_stream_redirector_race_condition(worker):
assert not result.done.error


@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT)
@pytest.mark.parametrize(
"worker,payloads,expected_metrics", METRICS_FIXTURES, indirect=["worker"]
)
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(data=st.data())
def test_metrics(worker, payloads, expected_metrics, data):
"""
We should get the metrics we expect from predictors that emit metrics.
"""
payload = data.draw(st.fixed_dictionaries(payloads))

result = _process(worker, lambda: worker.predict(payload))

assert result.metrics == expected_metrics


@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT)
@pytest.mark.parametrize(
"worker,payloads,output_generator", OUTPUT_FIXTURES, indirect=["worker"]
Expand Down

0 comments on commit 6cb53a0

Please sign in to comment.