Skip to content

Commit

Permalink
feat: add 2 new rotation flags in the ocr_predictor (#632)
Browse files Browse the repository at this point in the history
* feat: add 2 new rot flags in ocr_predictor

* fix: style

* feat: add README update

* refactor: requested changes

* fix: kwargs error
  • Loading branch information
charlesmindee authored Nov 19, 2021
1 parent 8b6dac3 commit 74ff9ff
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images()
result = model(doc)
```

### Dealing with rotated documents
Should you use docTR on documents that include rotated pages, or pages with multiple box orientations,
you have multiple options to handle it:

- If you only use straight document pages with straight words (horizontal, same reading direction),
consider passing `assume_straight_pages=True` to the ocr_predictor. It will directly fit straight boxes
on your page and return straight boxes, which makes it the fastest option.

- If you want the predictor to output straight boxes (no matter the orientation of your pages, the final localizations
will be converted to straight boxes), you need to pass `export_as_straight_boxes=True` to the predictor. Otherwise, if `assume_straight_pages=False`, the predictor will return rotated bounding boxes (potentially with an angle of 0°).

If both options are set to False, the predictor will always fit and return rotated boxes.


To interpret your model's predictions, you can visualize them interactively as follows:

```python
Expand Down
5 changes: 5 additions & 0 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class OCRPredictor(nn.Module, _OCRPredictor):
Args:
det_predictor: detection module
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
"""

def __init__(
Expand Down
4 changes: 4 additions & 0 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class OCRPredictor(NestedObject, _OCRPredictor):
Args:
det_predictor: detection module
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
"""
_children_names = ['det_predictor', 'reco_predictor']

Expand Down
29 changes: 25 additions & 4 deletions doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@
__all__ = ["ocr_predictor"]


def _predictor(det_arch: str, reco_arch: str, pretrained: bool, det_bs=2, reco_bs=128) -> OCRPredictor:
def _predictor(
det_arch: str,
reco_arch: str,
pretrained: bool,
det_bs: int = 2,
reco_bs: int = 128,
**kwargs,
) -> OCRPredictor:

# Detection
det_predictor = detection_predictor(det_arch, pretrained=pretrained, batch_size=det_bs)

# Recognition
reco_predictor = recognition_predictor(reco_arch, pretrained=pretrained, batch_size=reco_bs)

return OCRPredictor(det_predictor, reco_predictor)
return OCRPredictor(det_predictor, reco_predictor, **kwargs)


def ocr_predictor(
det_arch: str = 'db_resnet50',
reco_arch: str = 'crnn_vgg16_bn',
pretrained: bool = False,
assume_straight_pages: bool = True,
export_as_straight_boxes: bool = False,
**kwargs: Any
) -> OCRPredictor:
"""End-to-end OCR architecture using one model for localization, and another for text recognition.
Expand All @@ -39,11 +48,23 @@ def ocr_predictor(
>>> out = model([input_page])
Args:
arch: name of the architecture to use ('db_sar_vgg', 'db_sar_resnet', 'db_crnn_vgg', 'db_crnn_resnet')
det_arch: name of the detection architecture to use (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
reco_arch: name of the recognition architecture to use (e.g. 'crnn_vgg16_bn', 'sar_resnet31')
pretrained: If True, returns a model pre-trained on our OCR dataset
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
Returns:
OCR predictor
"""

return _predictor(det_arch, reco_arch, pretrained, **kwargs)
return _predictor(
det_arch,
reco_arch,
pretrained,
assume_straight_pages=assume_straight_pages,
export_as_straight_boxes=export_as_straight_boxes,
**kwargs,
)

0 comments on commit 74ff9ff

Please sign in to comment.