resubmitting the PR for later merging
ArthurKantor committed May 31, 2023
1 parent e909a8d commit 2c2f7b8
Showing 3 changed files with 57 additions and 19 deletions.
19 changes: 9 additions & 10 deletions doctr/models/detection/predictor/pytorch.py
@@ -13,6 +13,8 @@

__all__ = ["DetectionPredictor"]

+from doctr.utils.gpu import select_gpu_device
+

class DetectionPredictor(nn.Module):
"""Implements an object able to localize text elements in a document
@@ -27,29 +29,26 @@ def __init__(
        pre_processor: PreProcessor,
        model: nn.Module,
    ) -> None:

        super().__init__()
        self.model = model.eval()
        self.pre_processor = pre_processor
        self.postprocessor = self.model.postprocessor
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "":
-            self.device = torch.device("cpu")
-        elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0:
-            self.device = torch.device("cuda")
-        if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")):
-
+        detected_device, selected_device = select_gpu_device()
+        if "onnx" in str((type(self.model))):
+            selected_device = 'cpu'
            # self.model = nn.DataParallel(self.model)
            # self.model = self.model.half()
-            self.model = self.model.to(self.device)
+        self.device = torch.device(selected_device)
+        self.model = self.model.to(self.device)

    @torch.no_grad()
    def forward(
        self,
        pages: List[Union[np.ndarray, torch.Tensor]],
-        return_model_output = False,
+        return_model_output=False,
        **kwargs: Any,
    ) -> List[np.ndarray]:

        # Dimension check
        if any(page.ndim != 3 for page in pages):
            raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
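With this change, DetectionPredictor resolves its device internally through select_gpu_device() instead of parsing CUDA_VISIBLE_DEVICES inline, and an ONNX-wrapped model is pinned to CPU. A minimal usage sketch, assuming doctr's detection_predictor factory with a pretrained model (exact output types vary across doctr versions):

import numpy as np
from doctr.models import detection_predictor

# Device selection now happens inside the constructor: cuda, mps or cpu is
# picked automatically, and an ONNX-exported model would stay on CPU.
predictor = detection_predictor(pretrained=True)
page = np.zeros((1024, 1024, 3), dtype=np.uint8)  # dummy multi-channel 2D page
boxes = predictor([page])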
17 changes: 8 additions & 9 deletions doctr/models/recognition/predictor/pytorch.py
@@ -10,6 +10,7 @@
from torch import nn
import os
from doctr.models.preprocessor import PreProcessor
+from doctr.utils.gpu import select_gpu_device

from ._utils import remap_preds, split_crops

@@ -31,20 +32,19 @@ def __init__(
        model: nn.Module,
        split_wide_crops: bool = True,
    ) -> None:

        super().__init__()
        self.pre_processor = pre_processor
        self.model = model.eval()
        self.postprocessor = self.model.postprocessor
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "":
-            self.device = torch.device("cpu")
-        elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0:
-            self.device = torch.device("cuda")
-        if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")):
-
+        detected_device, selected_device = select_gpu_device()
+        if "onnx" in str((type(self.model))):
+            selected_device = 'cpu'
            # self.model = nn.DataParallel(self.model)
-            self.model = self.model.to(self.device)
            # self.model = self.model.half()
+        self.device = torch.device(selected_device)
+        self.model = self.model.to(self.device)

        self.split_wide_crops = split_wide_crops
        self.critical_ar = 8  # Critical aspect ratio
        self.dil_factor = 1.4  # Dilation factor to overlap the crops
@@ -56,7 +56,6 @@ def forward(
        crops: Sequence[Union[np.ndarray, torch.Tensor]],
        **kwargs: Any,
    ) -> List[Tuple[str, float]]:
-
        if len(crops) == 0:
            return []
        # Dimension check
40 changes: 40 additions & 0 deletions doctr/utils/gpu.py
@@ -0,0 +1,40 @@
import logging
import os
from typing import Tuple
import torch

log = logging.getLogger(__name__)


def select_gpu_device() -> Tuple[str, str]:
    """Tries to find either a CUDA or an ARM MPS GPU accelerator and chooses the most
    appropriate one, honoring the environment variables (CUDA_VISIBLE_DEVICES), if any
    have been set.

    Returns:
        tuple (best_detected_device, selected_device), where best_detected_device
        reflects the capabilities of the system and selected_device is the device that
        should be used (it might be cpu even if best_detected_device is e.g. cuda).
    """
    if torch.cuda.is_available():
        detected_gpu_device = 'cuda'
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        detected_gpu_device = 'mps'
    else:
        detected_gpu_device = 'cpu'

    selected_gpu_device = detected_gpu_device
    match detected_gpu_device:  # various exceptions to the above
        case 'cuda':
            if os.environ.get("CUDA_VISIBLE_DEVICES") == "":
                selected_gpu_device = 'cpu'
        case 'mps':
            # FIXME: a detected mps device falls back to cpu here because of the many
            # bugs in the mps implementation of the LSTM in torch 1.13. As of 5/29/2023,
            # they appear to be actively fixed upstream. torch 2.0.1 was also tried;
            # while the bugs look different, it is still broken. Revisit when later
            # versions of torch are available.
            selected_gpu_device = 'cpu'
        case 'cpu':
            pass

    log.info(f"{detected_gpu_device=} {selected_gpu_device=}")
    return detected_gpu_device, selected_gpu_device
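For reference, a hedged sketch of how a caller consumes the helper (note that the match statement requires Python 3.10 or later); the values returned depend on the host's accelerators and environment:

import torch
from doctr.utils.gpu import select_gpu_device

detected, selected = select_gpu_device()
# detected is the best accelerator found ('cuda', 'mps' or 'cpu'); selected is the
# device that should actually be used, and may be 'cpu' even when detected is
# 'cuda' or 'mps' (empty CUDA_VISIBLE_DEVICES, or the mps LSTM workaround above).
model = torch.nn.Linear(4, 2).eval().to(torch.device(selected))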
