Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(back): Add new GroundingDINO model #6

Merged
merged 12 commits into from
Mar 18, 2024
3 changes: 0 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,12 @@ jobs:
python-version: "3.10"

# Install PyTorch and TensorFlow CPU versions manually to prevent installing CUDA
# Install SAM and MobileSAM manually as they cannot be included in PyPI
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pylint
python -m pip install torch~=2.2.0 torchaudio~=2.2.0 torchvision~=0.17.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install tensorflow-cpu~=2.15.0
python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
python -m pip install .

- name: Lint backend code with Pylint
Expand Down
12 changes: 11 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ All notable changes to Pixano will be documented in this file.

## [Unreleased]

### Added

- Add **new GroundingDINO model** for object detection (bounding boxes) with text prompts (pixano/pixano-inference#6)

### Fixed

- Remove top-level imports for GitHub models to prevent import errors (pixano/pixano-inference#6)
- Fix preannotation with SAM and MobileSAM (pixano/pixano-inference#6)
- Add type hints for Image PixanoType (pixano/pixano-inference#6)

## [0.3.0] - 2024-02-29

### Added
Expand All @@ -20,7 +30,7 @@ All notable changes to Pixano will be documented in this file.
- **Breaking:** Remove SAM and MobileSAM dependencies to allow publishing to PyPI (pixano/pixano-inference#14)
- **Breaking:** Update to Pixano 0.5.0
- **Breaking:** Update InferenceModel `id` attribute to `model_id` to stop redefining built-in `id`
- **Breaking:** Update submodule names to `pytorch` and `tensorflow`
- **Breaking:** Update submodule names to `pytorch`, `tensorflow`, and `github`
- Update README with a small header description listing main features and more detailed installation instructions
- Generate API reference on documentation website automatically
- Add cross-references to Pixano, TensorFlow, and Hugging Face Transformers in the API reference
Expand Down
3 changes: 2 additions & 1 deletion docs/getting_started/installing_pixano_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ pip install pixano
pip install pixano-inference
```

To use the inference models available through GitHub, install the following additional packages:
To use the inference models available through GitHub, install the following additional packages as needed:

```shell
python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
python -m pip install groundingdino@git+https://github.com/IDEA-Research/GroundingDINO
```
2 changes: 2 additions & 0 deletions pixano_inference/github/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
#
# http://www.cecill.info

from .groundingdino import GroundingDINO
from .mobile_sam import MobileSAM
from .sam import SAM

__all__ = [
"SAM",
"GroundingDINO",
"MobileSAM",
]
145 changes: 145 additions & 0 deletions pixano_inference/github/groundingdino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
# @Author: CEA-LIST/DIASI/SIALV/LVA <[email protected]>
# @License: CECILL-C
#
# This software is a collaborative computer program whose purpose is to
# generate and explore labeled data for computer vision applications.
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
#
# http://www.cecill.info

from pathlib import Path

import pyarrow as pa
import shortuuid
from pixano.core import BBox, Image
from pixano.models import InferenceModel
from torchvision.ops import box_convert

from pixano_inference.utils import attempt_import


class GroundingDINO(InferenceModel):
    """GroundingDINO Model

    Text-prompted detection model loaded from the GroundingDINO GitHub package.
    The package is imported lazily (inside the methods) so that this module can
    be imported even when GroundingDINO is not installed.

    Attributes:
        name (str): Model name
        model_id (str): Model ID
        device (str): Model GPU or CPU device
        description (str): Model description
        model (torch.nn.Module): PyTorch model
        checkpoint_path (Path): Model checkpoint path
        config_path (Path): Model config path
    """

    def __init__(
        self,
        checkpoint_path: Path,
        config_path: Path,
        model_id: str = "",
        device: str = "cuda",
    ) -> None:
        """Initialize model

        Args:
            checkpoint_path (Path): Model checkpoint path (download from https://github.com/IDEA-Research/GroundingDINO)
            config_path (Path): Model config path (download from https://github.com/IDEA-Research/GroundingDINO)
            model_id (str, optional): Previously used ID, generate new ID if "". Defaults to "".
            device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cuda".
        """

        # Import GroundingDINO
        gd_inf = attempt_import(
            "groundingdino.util.inference",
            "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
        )

        super().__init__(
            name="GroundingDINO",
            model_id=model_id,
            device=device,
            description="From GitHub, GroundingDINO model.",
        )

        # Model paths (documented as attributes of this class)
        self.checkpoint_path = checkpoint_path
        self.config_path = config_path

        # Model
        # Forward the requested device: GroundingDINO's loader defaults to
        # "cuda", which breaks CPU-only setups.
        self.model = gd_inf.load_model(
            config_path.as_posix(),
            checkpoint_path.as_posix(),
            device=self.device,
        )
        self.model.to(self.device)

    def preannotate(
        self,
        batch: pa.RecordBatch,
        views: list[str],
        uri_prefix: str,
        threshold: float = 0.0,
        prompt: str = "",
    ) -> list[dict]:
        """Inference pre-annotation for a batch

        Args:
            batch (pa.RecordBatch): Input batch
            views (list[str]): Dataset views
            uri_prefix (str): URI prefix for media files
            threshold (float, optional): Confidence threshold. Defaults to 0.0.
            prompt (str, optional): Annotation text prompt. Defaults to "".

        Returns:
            list[dict]: Processed rows
        """

        rows = []

        # Import GroundingDINO
        gd_inf = attempt_import(
            "groundingdino.util.inference",
            "groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
        )

        for view in views:
            # Iterate manually
            for x in range(batch.num_rows):
                # Preprocess image
                im: Image = Image.from_dict(batch[view][x].as_py())
                im.uri_prefix = uri_prefix

                # load_image returns a pair; only the second element is fed
                # to the model
                _, image = gd_inf.load_image(im.path.as_posix())

                # Inference
                # The box/text thresholds are fixed here; `threshold` is
                # applied afterwards on the returned confidences.
                # Run on the device the model was created for instead of
                # GroundingDINO's default ("cuda").
                bbox_tensor, logit_tensor, category_list = gd_inf.predict(
                    model=self.model,
                    image=image,
                    caption=prompt,
                    box_threshold=0.35,
                    text_threshold=0.25,
                    device=self.device,
                )

                # Convert bounding boxes from cxcywh to xywh
                bbox_list = box_convert(
                    boxes=bbox_tensor, in_fmt="cxcywh", out_fmt="xywh"
                ).tolist()
                confidence_list = logit_tensor.tolist()

                # Process model outputs, keeping detections above the threshold
                item_id = batch["id"][x].as_py()
                rows.extend(
                    {
                        "id": shortuuid.uuid(),
                        "item_id": item_id,
                        "view_id": view,
                        "bbox": BBox.from_xywh(
                            bbox,
                            confidence=confidence,
                        ).to_dict(),
                        "category": category,
                    }
                    for bbox, confidence, category in zip(
                        bbox_list, confidence_list, category_list
                    )
                    if confidence > threshold
                )

        return rows
45 changes: 35 additions & 10 deletions pixano_inference/github/mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import pyarrow as pa
import shortuuid
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from mobile_sam.utils.onnx import SamOnnxModel
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from pixano.core import BBox, CompressedRLE, Image
from pixano.models import InferenceModel

from pixano_inference.utils import attempt_import


class MobileSAM(InferenceModel):
"""MobileSAM
Expand Down Expand Up @@ -54,6 +54,11 @@ def __init__(
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cpu".
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

super().__init__(
name="Mobile_SAM",
model_id=model_id,
Expand All @@ -62,7 +67,7 @@ def __init__(
)

# Model
self.model = sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model = mobile_sam.sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model.to(device=self.device)

# Model path
Expand All @@ -74,6 +79,7 @@ def preannotate(
views: list[str],
uri_prefix: str,
threshold: float = 0.0,
prompt: str = "",
) -> list[dict]:
"""Inference pre-annotation for a batch

Expand All @@ -82,25 +88,32 @@ def preannotate(
views (list[str]): Dataset views
uri_prefix (str): URI prefix for media files
threshold (float, optional): Confidence threshold. Defaults to 0.0.
prompt (str, optional): Annotation text prompt. Defaults to "".

Returns:
list[dict]: Processed rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = []
_ = prompt # This model does not use prompts

for view in views:
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im = Image.from_dict(batch[view][x].as_py())
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix
im = im.as_cv2()
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Inference
with torch.no_grad():
generator = SamAutomaticMaskGenerator(self.model)
generator = mobile_sam.SamAutomaticMaskGenerator(self.model)
output = generator.generate(im)

# Process model outputs
Expand All @@ -112,8 +125,8 @@ def preannotate(
"item_id": batch["id"][x].as_py(),
"view_id": view,
"bbox": BBox.from_xywh(
[coord.item() for coord in output[i]["bbox"]],
confidence=output[i]["predicted_iou"].item(),
[int(coord) for coord in output[i]["bbox"]],
confidence=float(output[i]["predicted_iou"]),
)
.normalize(h, w)
.to_dict(),
Expand Down Expand Up @@ -145,6 +158,11 @@ def precompute_embeddings(
pa.RecordBatch: Embedding rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = [
{
"id": batch["id"][x].as_py(),
Expand All @@ -156,14 +174,14 @@ def precompute_embeddings(
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im = Image.from_dict(batch[view][x].as_py())
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix
im = im.as_cv2()
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Inference
with torch.no_grad():
predictor = SamPredictor(self.model)
predictor = mobile_sam.SamPredictor(self.model)
predictor.set_image(im)
img_embedding = predictor.get_image_embedding().cpu().numpy()

Expand All @@ -181,6 +199,11 @@ def export_to_onnx(self, library_dir: Path):
library_dir (Path): Dataset library directory
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

# Model directory
model_dir = library_dir / "models"
model_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -189,7 +212,9 @@ def export_to_onnx(self, library_dir: Path):
self.model.to("cpu")

# Export settings
onnx_model = SamOnnxModel(self.model, return_single_mask=True)
onnx_model = mobile_sam.utils.onnx.SamOnnxModel(
self.model, return_single_mask=True
)
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
Expand Down
Loading