[Fix] fix default cuda config (#21)
felixdittrich92 authored Jun 28, 2024
1 parent 8e8ddf4 commit 890ae43
Showing 18 changed files with 70 additions and 73 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -7,7 +7,7 @@
[![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
[![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
-[![Pypi](https://img.shields.io/badge/pypi-v0.3.0-blue.svg)](https://pypi.org/project/OnnxTR/)
+[![Pypi](https://img.shields.io/badge/pypi-v0.3.1-blue.svg)](https://pypi.org/project/OnnxTR/)

> :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide an ONNX pipeline for docTR. For feature requests that are not directly related to the ONNX pipeline, please refer to the base project.
@@ -77,8 +77,8 @@ from onnxtr.models import ocr_predictor, EngineConfig
model = ocr_predictor(
det_arch='fast_base', # detection architecture
reco_arch='vitstr_base', # recognition architecture
-det_bs=4, # detection batch size
-reco_bs=1024, # recognition batch size
+det_bs=2, # detection batch size
+reco_bs=512, # recognition batch size
assume_straight_pages=True, # set to `False` if the pages are not straight (rotation, perspective, etc.) (default: True)
straighten_pages=False, # set to `True` if the pages should be straightened before final processing (default: False)
# Preprocessing related parameters
@@ -151,7 +151,7 @@ general_options.enable_cpu_mem_arena = False
# NOTE: The following would force the pipeline to run only on the GPU; if no GPU is available, it will raise an error
# List of strings e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"] or a list of tuples with the provider and its options e.g.
# [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
providers = [("CUDAExecutionProvider", {"device_id": 0})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/
providers = [("CUDAExecutionProvider", {"device_id": 0, "cudnn_conv_algo_search": "DEFAULT"})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/

engine_config = EngineConfig(
session_options=general_options,
@@ -183,7 +183,7 @@ model = ocr_predictor(det_arch=det_model, reco_arch=reco_model)

## Models architectures

-Credits where it's due: this repository is implementing, among others, architectures from published research papers.
+Credits where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models:

### Text Detection

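As a minimal sketch of how the updated README defaults fit together (not taken from the repository): it builds an `EngineConfig` with the new `cudnn_conv_algo_search: "DEFAULT"` CUDA options plus a CPU fallback and hands it to `detection_predictor`, whose `engine_cfg` parameter appears in `onnxtr/models/detection/zoo.py` below. It assumes `onnxruntime-gpu` is installed and imports `detection_predictor` from its module path; the public re-export may differ.

```python
# Sketch only: wiring the updated CUDA defaults into an EngineConfig.
import onnxruntime as ort

from onnxtr.models import EngineConfig
from onnxtr.models.detection.zoo import detection_predictor

general_options = ort.SessionOptions()
general_options.enable_cpu_mem_arena = False

# CUDA first (with the fixed "DEFAULT" conv algo search), CPU as fallback
providers = [
    ("CUDAExecutionProvider", {"device_id": 0, "cudnn_conv_algo_search": "DEFAULT"}),
    ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"}),
]

engine_config = EngineConfig(session_options=general_options, providers=providers)
det_predictor = detection_predictor(arch="fast_base", engine_cfg=engine_config)
```

Listing `CPUExecutionProvider` second mirrors the README snippet above and keeps CPU inference available when no GPU is present.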
5 changes: 1 addition & 4 deletions onnxtr/contrib/base.py
@@ -6,8 +6,8 @@
from typing import Any, List, Optional

import numpy as np
+import onnxruntime as ort

-from onnxtr.file_utils import requires_package
from onnxtr.utils.data import download_from_url


@@ -44,9 +44,6 @@ def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = Non
-------
Any: the ONNX loaded model
"""
requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.")
import onnxruntime as ort

if not url and not model_path:
raise ValueError("You must provide either a url or a model_path")
onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type]
8 changes: 4 additions & 4 deletions onnxtr/models/classification/models/mobilenet.py
@@ -51,7 +51,7 @@ class MobileNetV3(Engine):
def __init__(
self,
model_path: str,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
cfg: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
@@ -69,7 +69,7 @@ def _mobilenet_v3(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> MobileNetV3:
# Patch the url
@@ -81,7 +81,7 @@ def mobilenet_v3_small_crop_orientation(
def mobilenet_v3_small_crop_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
@@ -111,7 +111,7 @@ def mobilenet_v3_small_crop_orientation(
def mobilenet_v3_small_page_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
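The signature change repeated throughout this file (and the model files below) swaps an eagerly built `EngineConfig()` default for `Optional[EngineConfig] = None`. A toy sketch of why that matters, using a stand-in class rather than the real `EngineConfig`: a default expression in a `def` line is evaluated once at import time, so the old form constructed the default ONNX Runtime configuration even when callers passed their own, while `None` defers construction until the function actually runs.

```python
from typing import Optional


class EngineConfig:  # stand-in for onnxtr.models.engine.EngineConfig
    def __init__(self) -> None:
        print("building default providers")  # in OnnxTR this would touch ONNX Runtime


def eager(engine_cfg: EngineConfig = EngineConfig()) -> EngineConfig:
    # default evaluated once, at import/definition time
    return engine_cfg


def lazy(engine_cfg: Optional[EngineConfig] = None) -> EngineConfig:
    # default resolved only when actually needed
    return engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
```

The `isinstance` resolution mirrors the updated `Engine.__init__` in `onnxtr/models/engine.py` further down.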
10 changes: 5 additions & 5 deletions onnxtr/models/classification/zoo.py
@@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any, List
+from typing import Any, List, Optional

from onnxtr.models.engine import EngineConfig

@@ -17,7 +17,7 @@


def _orientation_predictor(
-arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any
+arch: str, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
@@ -26,7 +26,7 @@ def _orientation_predictor(
_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
kwargs["batch_size"] = kwargs.get("batch_size", 512 if "crop" in arch else 2)
input_shape = _model.cfg["input_shape"][1:]
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs),
@@ -38,7 +38,7 @@ def _orientation_predictor(
def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation",
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> OrientationPredictor:
"""Crop orientation classification architecture.
@@ -66,7 +66,7 @@ def crop_orientation_predictor(
def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation",
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> OrientationPredictor:
"""Page orientation classification architecture.
10 changes: 5 additions & 5 deletions onnxtr/models/detection/models/differentiable_binarization.py
@@ -56,7 +56,7 @@ class DBNet(Engine):
def __init__(
self,
model_path: str,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
bin_thresh: float = 0.3,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
@@ -93,7 +93,7 @@ def _dbnet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DBNet:
# Patch the url
@@ -105,7 +105,7 @@ def db_resnet34(
def db_resnet34(
model_path: str = default_cfgs["db_resnet34"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -134,7 +134,7 @@ def db_resnet34(
def db_resnet50(
model_path: str = default_cfgs["db_resnet50"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -163,7 +163,7 @@ def db_resnet50(
def db_mobilenet_v3_large(
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
10 changes: 5 additions & 5 deletions onnxtr/models/detection/models/fast.py
@@ -54,7 +54,7 @@ class FAST(Engine):
def __init__(
self,
model_path: str,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
@@ -92,7 +92,7 @@ def _fast(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> FAST:
if load_in_8_bit:
@@ -104,7 +104,7 @@ def _fast(
def fast_tiny(
model_path: str = default_cfgs["fast_tiny"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -133,7 +133,7 @@ def fast_tiny(
def fast_small(
model_path: str = default_cfgs["fast_small"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -162,7 +162,7 @@ def fast_small(
def fast_base(
model_path: str = default_cfgs["fast_base"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
10 changes: 5 additions & 5 deletions onnxtr/models/detection/models/linknet.py
@@ -56,7 +56,7 @@ class LinkNet(Engine):
def __init__(
self,
model_path: str,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
bin_thresh: float = 0.1,
box_thresh: float = 0.1,
assume_straight_pages: bool = True,
@@ -94,7 +94,7 @@ def _linknet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> LinkNet:
# Patch the url
@@ -106,7 +106,7 @@ def linknet_resnet18(
def linknet_resnet18(
model_path: str = default_cfgs["linknet_resnet18"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -135,7 +135,7 @@ def linknet_resnet18(
def linknet_resnet34(
model_path: str = default_cfgs["linknet_resnet34"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -164,7 +164,7 @@ def linknet_resnet34(
def linknet_resnet50(
model_path: str = default_cfgs["linknet_resnet50"]["url"],
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> LinkNet:
"""LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
8 changes: 4 additions & 4 deletions onnxtr/models/detection/zoo.py
@@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any
+from typing import Any, Optional

from .. import detection
from ..engine import EngineConfig
@@ -29,7 +29,7 @@ def _predictor(
arch: Any,
assume_straight_pages: bool = True,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DetectionPredictor:
if isinstance(arch, str):
@@ -48,7 +48,7 @@ def _predictor(

kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 4)
kwargs["batch_size"] = kwargs.get("batch_size", 2)
predictor = DetectionPredictor(
PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
_model,
@@ -60,7 +60,7 @@ def detection_predictor(
arch: Any = "fast_base",
assume_straight_pages: bool = True,
load_in_8_bit: bool = False,
-engine_cfg: EngineConfig = EngineConfig(),
+engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> DetectionPredictor:
"""Text detection architecture.
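The lowered `batch_size` defaults in this zoo (and in `onnxtr/models/classification/zoo.py` above) only apply when the caller passes nothing, since they go through `kwargs.get("batch_size", ...)`. A hedged sketch of restoring the previous values — import paths are taken from the changed files (public re-exports may differ), and `batch_size` is assumed to be forwarded through `**kwargs` as the diff suggests:

```python
from onnxtr.models.classification.zoo import crop_orientation_predictor
from onnxtr.models.detection.zoo import detection_predictor

# Explicit values win over the new defaults because of kwargs.get(...)
det = detection_predictor(arch="fast_base", batch_size=4)  # previous detection default
crop = crop_orientation_predictor(batch_size=128)           # previous crop-orientation default
```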
6 changes: 3 additions & 3 deletions onnxtr/models/engine.py
@@ -49,7 +49,7 @@ def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
{
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"cudnn_conv_algo_search": "DEFAULT",
"do_copy_in_default_stream": True,
},
),
@@ -87,8 +87,8 @@ class Engine:
**kwargs: additional arguments to be passed to `download_from_url`
"""

-def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None:
-    engine_cfg = engine_cfg or EngineConfig()
+def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
+    engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
self.session_options = engine_cfg.session_options
self.providers = engine_cfg.providers
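The first hunk is the "default cuda config" fix itself: the built-in CUDA provider options now request cuDNN's default convolution algorithm instead of the exhaustive benchmarking search, which can add a long warm-up and extra workspace memory at session creation. A sketch (not repository code) of what these defaults amount to when passed straight to ONNX Runtime — the model path is a placeholder, and the CPU fallback entry mirrors the README example above:

```python
import onnxruntime as ort

providers = [
    (
        "CUDAExecutionProvider",
        {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "DEFAULT",  # was "EXHAUSTIVE" before this commit
            "do_copy_in_default_stream": True,
        },
    ),
    ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"}),
]

# "model.onnx" is a placeholder; CPUExecutionProvider is listed as a fallback,
# mirroring the EngineConfig defaults sketched in the README section.
session = ort.InferenceSession("model.onnx", sess_options=ort.SessionOptions(), providers=providers)
```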
2 changes: 1 addition & 1 deletion onnxtr/models/predictor/base.py
@@ -50,7 +50,7 @@ def __init__(
symmetric_pad: bool = True,
detect_orientation: bool = False,
load_in_8_bit: bool = False,
-clf_engine_cfg: EngineConfig = EngineConfig(),
+clf_engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> None:
self.assume_straight_pages = assume_straight_pages
4 changes: 2 additions & 2 deletions onnxtr/models/predictor/predictor.py
@@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

-from typing import Any, List
+from typing import Any, List, Optional

import numpy as np

@@ -52,7 +52,7 @@ def __init__(
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
-clf_engine_cfg: EngineConfig = EngineConfig(),
+clf_engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> None:
self.det_predictor = det_predictor