feat(back): Add new GroundingDINO model (#6)
* feat(back): Rename segment_anything module to github

This will allow adding new models like GroundingDINO and YOLO-World in the same module
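
For illustration, the import path change this implies; the old path is inferred from the rename and the new one from the `pixano_inference/github/__init__.py` diff below:

```python
# Presumed import path before this commit:
# from pixano_inference.segment_anything import SAM, MobileSAM

# After this commit, all GitHub-hosted models share a single submodule:
from pixano_inference.github import SAM, GroundingDINO, MobileSAM
```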

* fix(back): Add type hint for Image PixanoType

The from_dict method is declared with a general PixanoType return type; adding an Image type hint gives access to the Image class methods
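
For illustration, a minimal sketch of the pattern this refers to (the function name and its arguments are hypothetical):

```python
from pixano.core import Image


def load_view_image(row_dict: dict, uri_prefix: str):
    # from_dict() is declared as returning the base PixanoType, so an explicit
    # Image hint lets type checkers see Image-specific attributes and methods
    im: Image = Image.from_dict(row_dict)
    im.uri_prefix = uri_prefix  # Image-specific attribute
    return im.as_cv2()          # Image-specific method
```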

* feat(back): Add GroundingDINO

* chore: Add GroundingDINO requirements

Update torch versions as required by GroundingDINO, and pin them with compatible-release specifiers (~=, allowing patch-level updates only) to prevent issues
Update TensorFlow versions to prevent issues with those models as well

* fix(back): Add prompt arg to other models

* fix(back): Fix preannotate for SAM and MobileSAM

Bounding box coordinates and confidence are already Python types and do not need to be converted from torch with Tensor.item()
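
For illustration, a minimal sketch of the failure mode behind this fix; the mask record mirrors the structure used in the MobileSAM diff below, with made-up placeholder values:

```python
# Example mask record as returned by SamAutomaticMaskGenerator.generate():
# coordinates and confidence are already plain Python numbers, not tensors.
mask = {"bbox": [12, 34, 56, 78], "predicted_iou": 0.91}

# mask["bbox"][0].item()  # AttributeError: 'int' object has no attribute 'item'
coords = [int(coord) for coord in mask["bbox"]]  # plain cast instead of .item()
confidence = float(mask["predicted_iou"])
```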

* fix(ci): Add GroundingDINO to lint action

* feat(docs): Add GroundingDINO to documentation

* fix(back): Remove top-level GitHub imports to prevent errors
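
For context, a minimal sketch of what a lazy-import helper like the `attempt_import` utility used in the diffs below might look like; the actual `pixano_inference.utils.attempt_import` implementation may differ:

```python
import importlib
from types import ModuleType


def attempt_import(module: str, package: str | None = None) -> ModuleType:
    """Import a module at call time, with an install hint if it is missing."""
    try:
        return importlib.import_module(module)
    except ImportError as err:
        raise ImportError(
            f"Please install the missing package with "
            f"`pip install {package or module}`"
        ) from err
```

Importing the GitHub models lazily inside each method keeps `pixano_inference` importable even when these optional dependencies are not installed.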

* chore: Update CHANGELOG.md

* fix(back): Fix GroundingDINO imports
cpvannier authored Mar 18, 2024
1 parent 4c89472 commit 7e12737
Showing 15 changed files with 309 additions and 31 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/lint.yml
@@ -41,15 +41,12 @@ jobs:
python-version: "3.10"

# Install PyTorch and TensorFlow CPU versions manually to prevent installing CUDA
# Install SAM and MobileSAM manually as they cannot be included in PyPI
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pylint
python -m pip install torch~=2.2.0 torchaudio~=2.2.0 torchvision~=0.17.0 --index-url https://download.pytorch.org/whl/cpu
python -m pip install tensorflow-cpu~=2.15.0
python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
python -m pip install .
- name: Lint backend code with Pylint
12 changes: 11 additions & 1 deletion CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to Pixano will be documented in this file.

## [Unreleased]

### Added

- Add **new GroundingDINO model** for object detection with text prompts (pixano/pixano-inference#6)

### Fixed

- Remove top-level imports for GitHub models to prevent import errors (pixano/pixano-inference#6)
- Fix preannotation with SAM and MobileSAM (pixano/pixano-inference#6)
- Add type hints for Image PixanoType (pixano/pixano-inference#6)

## [0.3.0] - 2024-02-29

### Added
@@ -20,7 +30,7 @@ All notable changes to Pixano will be documented in this file.
- **Breaking:** Remove SAM and MobileSAM dependencies to allow publishing to PyPI (pixano/pixano-inference#14)
- **Breaking:** Update to Pixano 0.5.0
- **Breaking:** Update InferenceModel `id` attribute to `model_id` to stop redefining built-in `id`
- **Breaking:** Update submodule names to `pytorch` and `tensorflow`
- **Breaking:** Update submodule names to `pytorch`, `tensorflow`, and `github`
- Update README with a small header description listing main features and more detailed installation instructions
- Generate API reference on documentation website automatically
- Add cross-references to Pixano, TensorFlow, and Hugging Face Transformers in the API reference
3 changes: 2 additions & 1 deletion docs/getting_started/installing_pixano_inference.md
@@ -14,9 +14,10 @@ pip install pixano
pip install pixano-inference
```

To use the inference models available through GitHub, install the following additional packages:
To use the inference models available through GitHub, install the following additional packages as needed:

```shell
python -m pip install segment-anything@git+https://github.com/facebookresearch/segment-anything
python -m pip install mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM
python -m pip install groundingdino@git+https://github.com/IDEA-Research/GroundingDINO
```
2 changes: 2 additions & 0 deletions pixano_inference/github/__init__.py
@@ -11,10 +11,12 @@
#
# http://www.cecill.info

from .groundingdino import GroundingDINO
from .mobile_sam import MobileSAM
from .sam import SAM

__all__ = [
"SAM",
"GroundingDINO",
"MobileSAM",
]
145 changes: 145 additions & 0 deletions pixano_inference/github/groundingdino.py
@@ -0,0 +1,145 @@
# @Copyright: CEA-LIST/DIASI/SIALV/LVA (2023)
# @Author: CEA-LIST/DIASI/SIALV/LVA <[email protected]>
# @License: CECILL-C
#
# This software is a collaborative computer program whose purpose is to
# generate and explore labeled data for computer vision applications.
# This software is governed by the CeCILL-C license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL-C
# license as circulated by CEA, CNRS and INRIA at the following URL
#
# http://www.cecill.info

from pathlib import Path

import pyarrow as pa
import shortuuid
from pixano.core import BBox, Image
from pixano.models import InferenceModel
from torchvision.ops import box_convert

from pixano_inference.utils import attempt_import


class GroundingDINO(InferenceModel):
"""GroundingDINO Model
Attributes:
name (str): Model name
model_id (str): Model ID
device (str): Model GPU or CPU device
description (str): Model description
model (torch.nn.Module): PyTorch model
checkpoint_path (Path): Model checkpoint path
config_path (Path): Model config path
"""

def __init__(
self,
checkpoint_path: Path,
config_path: Path,
model_id: str = "",
device: str = "cuda",
) -> None:
"""Initialize model
Args:
checkpoint_path (Path): Model checkpoint path (download from https://github.com/IDEA-Research/GroundingDINO)
config_path (Path): Model config path (download from https://github.com/IDEA-Research/GroundingDINO)
model_id (str, optional): Previously used ID, generate new ID if "". Defaults to "".
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cuda".
"""

# Import GroundingDINO
gd_inf = attempt_import(
"groundingdino.util.inference",
"groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
)

super().__init__(
name="GroundingDINO",
model_id=model_id,
device=device,
description="Fom GitHub, GroundingDINO model.",
)

# Model
self.model = gd_inf.load_model(
config_path.as_posix(),
checkpoint_path.as_posix(),
)
self.model.to(self.device)

def preannotate(
self,
batch: pa.RecordBatch,
views: list[str],
uri_prefix: str,
threshold: float = 0.0,
prompt: str = "",
) -> list[dict]:
"""Inference pre-annotation for a batch
Args:
batch (pa.RecordBatch): Input batch
views (list[str]): Dataset views
uri_prefix (str): URI prefix for media files
threshold (float, optional): Confidence threshold. Defaults to 0.0.
prompt (str, optional): Annotation text prompt. Defaults to "".
Returns:
list[dict]: Processed rows
"""

rows = []

# Import GroundingDINO
gd_inf = attempt_import(
"groundingdino.util.inference",
"groundingdino@git+https://github.com/IDEA-Research/GroundingDINO",
)

for view in views:
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix

_, image = gd_inf.load_image(im.path.as_posix())

# Inference
bbox_tensor, logit_tensor, category_list = gd_inf.predict(
model=self.model,
image=image,
caption=prompt,
box_threshold=0.35,
text_threshold=0.25,
)

# Convert bounding boxes from cxcywh to xywh
bbox_tensor = box_convert(
boxes=bbox_tensor, in_fmt="cxcywh", out_fmt="xywh"
)
bbox_list = [[coord.item() for coord in bbox] for bbox in bbox_tensor]

# Process model outputs
rows.extend(
[
{
"id": shortuuid.uuid(),
"item_id": batch["id"][x].as_py(),
"view_id": view,
"bbox": BBox.from_xywh(
bbox_list[i],
confidence=logit_tensor[i].item(),
).to_dict(),
"category": category_list[i],
}
for i in range(len(category_list))
if logit_tensor[i].item() > threshold
]
)

return rows
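
For illustration, a hypothetical usage sketch based on the class above; the file paths, the `batch` record batch, and the prompt are assumptions, not part of this commit:

```python
from pathlib import Path

from pixano_inference.github import GroundingDINO

# Checkpoint and config files can be downloaded from
# https://github.com/IDEA-Research/GroundingDINO (paths here are placeholders)
model = GroundingDINO(
    checkpoint_path=Path("checkpoints/groundingdino_swint_ogc.pth"),
    config_path=Path("configs/GroundingDINO_SwinT_OGC.py"),
    device="cuda",
)

# `batch` is assumed to be a pa.RecordBatch with an "image" view column,
# as in the other pixano-inference models
rows = model.preannotate(
    batch,
    views=["image"],
    uri_prefix="/path/to/media",
    threshold=0.35,
    prompt="a dog. a cat.",
)
```
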
45 changes: 35 additions & 10 deletions pixano_inference/github/mobile_sam.py
@@ -20,13 +20,13 @@
import pyarrow as pa
import shortuuid
import torch
from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from mobile_sam.utils.onnx import SamOnnxModel
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic
from pixano.core import BBox, CompressedRLE, Image
from pixano.models import InferenceModel

from pixano_inference.utils import attempt_import


class MobileSAM(InferenceModel):
"""MobileSAM
@@ -54,6 +54,11 @@ def __init__(
device (str, optional): Model GPU or CPU device (e.g. "cuda", "cpu"). Defaults to "cpu".
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

super().__init__(
name="Mobile_SAM",
model_id=model_id,
@@ -62,7 +67,7 @@ def __init__(
)

# Model
self.model = sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model = mobile_sam.sam_model_registry["vit_t"](checkpoint=checkpoint_path)
self.model.to(device=self.device)

# Model path
@@ -74,6 +79,7 @@ def preannotate(
views: list[str],
uri_prefix: str,
threshold: float = 0.0,
prompt: str = "",
) -> list[dict]:
"""Inference pre-annotation for a batch
@@ -82,25 +88,32 @@ def preannotate(
views (list[str]): Dataset views
uri_prefix (str): URI prefix for media files
threshold (float, optional): Confidence threshold. Defaults to 0.0.
prompt (str, optional): Annotation text prompt. Defaults to "".
Returns:
list[dict]: Processed rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = []
_ = prompt # This model does not use prompts

for view in views:
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im = Image.from_dict(batch[view][x].as_py())
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix
im = im.as_cv2()
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Inference
with torch.no_grad():
generator = SamAutomaticMaskGenerator(self.model)
generator = mobile_sam.SamAutomaticMaskGenerator(self.model)
output = generator.generate(im)

# Process model outputs
@@ -112,8 +125,8 @@
"item_id": batch["id"][x].as_py(),
"view_id": view,
"bbox": BBox.from_xywh(
[coord.item() for coord in output[i]["bbox"]],
confidence=output[i]["predicted_iou"].item(),
[int(coord) for coord in output[i]["bbox"]],
confidence=float(output[i]["predicted_iou"]),
)
.normalize(h, w)
.to_dict(),
@@ -145,6 +158,11 @@ def precompute_embeddings(
pa.RecordBatch: Embedding rows
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

rows = [
{
"id": batch["id"][x].as_py(),
@@ -156,14 +174,14 @@
# Iterate manually
for x in range(batch.num_rows):
# Preprocess image
im = Image.from_dict(batch[view][x].as_py())
im: Image = Image.from_dict(batch[view][x].as_py())
im.uri_prefix = uri_prefix
im = im.as_cv2()
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Inference
with torch.no_grad():
predictor = SamPredictor(self.model)
predictor = mobile_sam.SamPredictor(self.model)
predictor.set_image(im)
img_embedding = predictor.get_image_embedding().cpu().numpy()

Expand All @@ -181,6 +199,11 @@ def export_to_onnx(self, library_dir: Path):
library_dir (Path): Dataset library directory
"""

# Import MobileSAM
mobile_sam = attempt_import(
"mobile_sam", "mobile-sam@git+https://github.com/ChaoningZhang/MobileSAM"
)

# Model directory
model_dir = library_dir / "models"
model_dir.mkdir(parents=True, exist_ok=True)
@@ -189,7 +212,9 @@
self.model.to("cpu")

# Export settings
onnx_model = SamOnnxModel(self.model, return_single_mask=True)
onnx_model = mobile_sam.utils.onnx.SamOnnxModel(
self.model, return_single_mask=True
)
dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},