-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #553 from roboflow/custom-metadata-workflow-block
Custom Metadata block for Model Monitoring users
- Loading branch information
Showing
8 changed files
with
548 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "0.15.1" | ||
__version__ = "0.15.2" | ||
|
||
|
||
if __name__ == "__main__": | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,3 +32,4 @@ | |
CLASS_NAME_KEY = "class" | ||
POLYGON_KEY = "points" | ||
TRACKER_ID_KEY = "tracker_id" | ||
INFERENCE_ID_KEY = "inference_id" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
213 changes: 213 additions & 0 deletions
213
inference/core/workflows/core_steps/sinks/roboflow/roboflow_custom_metadata.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
import hashlib | ||
from functools import partial | ||
from typing import List, Literal, Optional, Type, Union | ||
|
||
import numpy as np | ||
import supervision as sv | ||
from fastapi import BackgroundTasks | ||
from pydantic import ConfigDict, Field | ||
|
||
from inference.core.cache.base import BaseCache | ||
from inference.core.roboflow_api import add_custom_metadata, get_roboflow_workspace | ||
from inference.core.workflows.constants import INFERENCE_ID_KEY | ||
from inference.core.workflows.entities.base import Batch, OutputDefinition | ||
from inference.core.workflows.entities.types import ( | ||
BATCH_OF_CLASSIFICATION_PREDICTION_KIND, | ||
BATCH_OF_INSTANCE_SEGMENTATION_PREDICTION_KIND, | ||
BATCH_OF_KEYPOINT_DETECTION_PREDICTION_KIND, | ||
BATCH_OF_OBJECT_DETECTION_PREDICTION_KIND, | ||
BOOLEAN_KIND, | ||
STRING_KIND, | ||
StepOutputSelector, | ||
WorkflowParameterSelector, | ||
) | ||
from inference.core.workflows.prototypes.block import ( | ||
BlockResult, | ||
WorkflowBlock, | ||
WorkflowBlockManifest, | ||
) | ||
|
||
SHORT_DESCRIPTION = "Add custom metadata to Roboflow Model Monitoring dashboard" | ||
|
||
LONG_DESCRIPTION = """ | ||
Block allows users to add custom metadata to each inference result in Roboflow Model Monitoring dashboard. | ||
This is useful for adding information specific to your use case. For example, if you want to be able to | ||
filter inferences by a specific label such as location, you can attach those labels to each inference with this block. | ||
For more information on Model Monitoring at Roboflow, see https://docs.roboflow.com/deploy/model-monitoring. | ||
""" | ||
|
||
WORKSPACE_NAME_CACHE_EXPIRE = 900 # 15 min | ||
|
||
|
||
class BlockManifest(WorkflowBlockManifest): | ||
model_config = ConfigDict( | ||
json_schema_extra={ | ||
"short_description": SHORT_DESCRIPTION, | ||
"long_description": LONG_DESCRIPTION, | ||
"license": "Apache-2.0", | ||
"block_type": "sink", | ||
} | ||
) | ||
type: Literal["RoboflowCustomMetadata"] | ||
predictions: StepOutputSelector( | ||
kind=[ | ||
BATCH_OF_OBJECT_DETECTION_PREDICTION_KIND, | ||
BATCH_OF_INSTANCE_SEGMENTATION_PREDICTION_KIND, | ||
BATCH_OF_KEYPOINT_DETECTION_PREDICTION_KIND, | ||
BATCH_OF_CLASSIFICATION_PREDICTION_KIND, | ||
] | ||
) = Field( | ||
description="Reference data to extract property from", | ||
examples=["$steps.my_step.predictions"], | ||
) | ||
field_value: StepOutputSelector(kind=[STRING_KIND]) = Field( | ||
description="This is the name of the metadata field you are creating", | ||
examples=["toronto", "pass", "fail"], | ||
) | ||
field_name: str = Field( | ||
description="Name of the field to be updated in Roboflow Customer Metadata", | ||
examples=["The name of the value of the field"], | ||
) | ||
fire_and_forget: Union[bool, WorkflowParameterSelector(kind=[BOOLEAN_KIND])] = ( | ||
Field( | ||
default=True, | ||
description="Boolean flag dictating if sink is supposed to be executed in the background, " | ||
"not waiting on status of registration before end of workflow run. Use `True` if best-effort " | ||
"registration is needed, use `False` while debugging and if error handling is needed", | ||
) | ||
) | ||
|
||
@classmethod | ||
def accepts_batch_input(cls) -> bool: | ||
return True | ||
|
||
@classmethod | ||
def describe_outputs(cls) -> List[OutputDefinition]: | ||
return [ | ||
OutputDefinition(name="error_status", kind=[BOOLEAN_KIND]), | ||
OutputDefinition(name="message", kind=[STRING_KIND]), | ||
OutputDefinition( | ||
name="predictions", kind=[BATCH_OF_OBJECT_DETECTION_PREDICTION_KIND] | ||
), | ||
] | ||
|
||
|
||
class RoboflowCustomMetadataBlock(WorkflowBlock): | ||
|
||
def __init__( | ||
self, | ||
cache: BaseCache, | ||
api_key: Optional[str], | ||
background_tasks: Optional[BackgroundTasks], | ||
): | ||
self._api_key = api_key | ||
self._cache = cache | ||
self._background_tasks = background_tasks | ||
|
||
@classmethod | ||
def get_init_parameters(cls) -> List[str]: | ||
return ["api_key", "cache", "background_tasks"] | ||
|
||
@classmethod | ||
def get_manifest(cls) -> Type[WorkflowBlockManifest]: | ||
return BlockManifest | ||
|
||
async def run( | ||
self, | ||
fire_and_forget: bool, | ||
field_name: str, | ||
field_value: Batch[str], | ||
predictions: Batch[sv.Detections], | ||
) -> BlockResult: | ||
if self._api_key is None: | ||
raise ValueError( | ||
"RoboflowCustomMetadata block cannot run without Roboflow API key. " | ||
"If you do not know how to get API key - visit " | ||
"https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to " | ||
"retrieve one." | ||
) | ||
inference_ids: List[np.ndarray] = [p[INFERENCE_ID_KEY] for p in predictions] | ||
if len(inference_ids) == 0: | ||
return [ | ||
{ | ||
"error_status": True, | ||
"predictions": predictions, | ||
"message": "Custom metadata upload failed because no inference_ids were received", | ||
} | ||
] | ||
inference_ids: List[str] = list(set(np.concatenate(inference_ids).tolist())) | ||
if field_name is None: | ||
return [ | ||
{ | ||
"error_status": True, | ||
"predictions": predictions, | ||
"message": "Custom metadata upload failed because no field_name was inputted", | ||
} | ||
] | ||
if field_value is None or len(field_value) == 0: | ||
return [ | ||
{ | ||
"error_status": True, | ||
"predictions": predictions, | ||
"message": "Custom metadata upload failed because no field_value was received", | ||
} | ||
] | ||
registration_task = partial( | ||
add_custom_metadata_request, | ||
cache=self._cache, | ||
api_key=self._api_key, | ||
inference_ids=inference_ids, | ||
field_name=field_name, | ||
field_value=field_value[0], | ||
) | ||
if fire_and_forget and self._background_tasks: | ||
self._background_tasks.add_task(registration_task) | ||
else: | ||
registration_task() | ||
return [ | ||
{ | ||
"error_status": False, | ||
"predictions": predictions, | ||
"message": "Custom metadata upload was successful", | ||
} | ||
] | ||
|
||
|
||
def get_workspace_name( | ||
api_key: str, | ||
cache: BaseCache, | ||
) -> str: | ||
api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest() | ||
cache_key = f"workflows:api_key_to_workspace:{api_key_hash}" | ||
cached_workspace_name = cache.get(cache_key) | ||
if cached_workspace_name: | ||
return cached_workspace_name | ||
workspace_name_from_api = get_roboflow_workspace(api_key=api_key) | ||
cache.set( | ||
key=cache_key, value=workspace_name_from_api, expire=WORKSPACE_NAME_CACHE_EXPIRE | ||
) | ||
return workspace_name_from_api | ||
|
||
|
||
def add_custom_metadata_request( | ||
cache: BaseCache, | ||
api_key: str, | ||
inference_ids: List[str], | ||
field_name: str, | ||
field_value: str, | ||
): | ||
workspace_id = get_workspace_name(api_key=api_key, cache=cache) | ||
was_added = False | ||
try: | ||
was_added = add_custom_metadata( | ||
api_key=api_key, | ||
workspace_id=workspace_id, | ||
inference_ids=inference_ids, | ||
field_name=field_name, | ||
field_value=field_value, | ||
) | ||
except Exception as e: | ||
pass | ||
return was_added |
Oops, something went wrong.