Skip to content

Commit

Permalink
[Sync] page orientation integration (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Jun 12, 2024
1 parent b5b17bc commit 495151c
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jp

### Putting it together

Let's use the default pretrained model for an example:
Let's use the default `ocr_predictor` model for an example:

```python
from onnxtr.io import DocumentFile
Expand Down
73 changes: 55 additions & 18 deletions onnxtr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import numpy as np
from langdetect import LangDetectException, detect_langs

from onnxtr.utils.geometry import rotate_image

__all__ = ["estimate_orientation", "get_language"]


Expand All @@ -29,42 +31,63 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
return max(w / h, h / w)


def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> int:
def estimate_orientation(
img: np.ndarray,
general_page_orientation: Optional[Tuple[int, float]] = None,
n_ct: int = 70,
ratio_threshold_for_lines: float = 3,
min_confidence: float = 0.2,
lower_area: int = 100,
) -> int:
"""Estimate the angle of the general document orientation based on the
lines of the document and the assumption that they should be horizontal.
Args:
----
img: the img or bitmap to analyze (H, W, C)
general_page_orientation: the general orientation of the page (angle [0, 90, 180, 270 (-90)], confidence)
estimated by a model
n_ct: the number of contours used for the orientation estimation
ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
min_confidence: the minimum confidence to consider the general_page_orientation
lower_area: the minimum area of a contour to be considered
Returns:
-------
the angle of the general document orientation
the estimated angle of the page (clockwise, negative for left side rotation, positive for right side rotation)
"""
assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported"
max_value = np.max(img)
min_value = np.min(img)
if max_value <= 1 and min_value >= 0 or (max_value <= 255 and min_value >= 0 and img.shape[-1] == 1):
thresh = img.astype(np.uint8)
if max_value <= 255 and min_value >= 0 and img.shape[-1] == 3:
thresh = None
# Convert image to grayscale if necessary
if img.shape[-1] == 3:
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = cv2.medianBlur(gray_img, 5)
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

# try to merge words in lines
(h, w) = img.shape[:2]
k_x = max(1, (floor(w / 100)))
k_y = max(1, (floor(h / 100)))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
thresh = cv2.dilate(thresh, kernel, iterations=1)
else:
thresh = img.astype(np.uint8) # type: ignore[assignment]

page_orientation, orientation_confidence = general_page_orientation or (None, 0.0)
if page_orientation and orientation_confidence >= min_confidence:
# We rotate the image to the general orientation which improves the detection
# No expand needed bitmap is already padded
thresh = rotate_image(thresh, -page_orientation) # type: ignore
else: # That's only required if we do not work on the detection models bin map
# try to merge words in lines
(h, w) = img.shape[:2]
k_x = max(1, (floor(w / 100)))
k_y = max(1, (floor(h / 100)))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
thresh = cv2.dilate(thresh, kernel, iterations=1)

# extract contours
contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

# Sort contours
contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
# Filter & Sort contours
contours = sorted(
[contour for contour in contours if cv2.contourArea(contour) > lower_area],
key=get_max_width_length_ratio,
reverse=True,
)

angles = []
for contour in contours[:n_ct]:
Expand All @@ -75,10 +98,24 @@ def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_li
angles.append(angle - 90)

if len(angles) == 0:
return 0 # in case no angles is found
estimated_angle = 0 # in case no angles is found
else:
median = -median_low(angles)
return round(median) if abs(median) != 0 else 0
estimated_angle = -round(median) if abs(median) != 0 else 0

# combine with the general orientation and the estimated angle
if page_orientation and orientation_confidence >= min_confidence:
# special case where the estimated angle is mostly wrong:
# case 1: - and + swapped
# case 2: estimated angle is completely wrong
# so in this case we prefer the general page orientation
if abs(estimated_angle) == abs(page_orientation):
return page_orientation
estimated_angle = estimated_angle if page_orientation == 0 else page_orientation + estimated_angle
if estimated_angle > 180:
estimated_angle -= 360

return estimated_angle # return the clockwise angle (negative - left side rotation, positive - right side rotation)


def rectify_crops(
Expand Down
62 changes: 59 additions & 3 deletions onnxtr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np

from onnxtr.models.builder import DocumentBuilder
from onnxtr.utils.geometry import extract_crops, extract_rcrops
from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image

from .._utils import rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor
from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor, page_orientation_predictor
from ..classification.predictor import OrientationPredictor
from ..detection.zoo import ARCHS as DETECTION_ARCHS
from ..recognition.zoo import ARCHS as RECOGNITION_ARCHS
Expand All @@ -31,18 +31,22 @@ class _OCRPredictor:
accordingly. Doing so will improve performances for documents with page-uniform rotations.
preserve_aspect_ratio: if True, resize preserving the aspect ratio (with padding)
symmetric_pad: if True and preserve_aspect_ratio is True, pas the image symmetrically.
detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
**kwargs: keyword args of `DocumentBuilder`
"""

crop_orientation_predictor: Optional[OrientationPredictor]
page_orientation_predictor: Optional[OrientationPredictor]

def __init__(
self,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
load_in_8_bit: bool = False,
**kwargs: Any,
) -> None:
Expand All @@ -51,11 +55,63 @@ def __init__(
self.crop_orientation_predictor = (
None if assume_straight_pages else crop_orientation_predictor(load_in_8_bit=load_in_8_bit)
)
self.page_orientation_predictor = (
page_orientation_predictor(load_in_8_bit=load_in_8_bit)
if detect_orientation or straighten_pages or not assume_straight_pages
else None
)
self.doc_builder = DocumentBuilder(**kwargs)
self.preserve_aspect_ratio = preserve_aspect_ratio
self.symmetric_pad = symmetric_pad
self.hooks: List[Callable] = []

def _general_page_orientations(
self,
pages: List[np.ndarray],
) -> List[Tuple[int, float]]:
_, classes, probs = zip(self.page_orientation_predictor(pages)) # type: ignore[misc]
# Flatten to list of tuples with (value, confidence)
page_orientations = [
(orientation, prob)
for page_classes, page_probs in zip(classes, probs)
for orientation, prob in zip(page_classes, page_probs)
]
return page_orientations

def _get_orientations(
self, pages: List[np.ndarray], seg_maps: List[np.ndarray]
) -> Tuple[List[Tuple[int, float]], List[int]]:
general_pages_orientations = self._general_page_orientations(pages)
origin_page_orientations = [
estimate_orientation(seq_map, general_orientation)
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
]
return general_pages_orientations, origin_page_orientations

def _straighten_pages(
self,
pages: List[np.ndarray],
seg_maps: List[np.ndarray],
general_pages_orientations: Optional[List[Tuple[int, float]]] = None,
origin_pages_orientations: Optional[List[int]] = None,
) -> List[np.ndarray]:
general_pages_orientations = (
general_pages_orientations if general_pages_orientations else self._general_page_orientations(pages)
)
origin_pages_orientations = (
origin_pages_orientations
if origin_pages_orientations
else [
estimate_orientation(seq_map, general_orientation)
for seq_map, general_orientation in zip(seg_maps, general_pages_orientations)
]
)
return [
# We exapnd if the page is wider than tall and the angle is 90 or -90
rotate_image(page, angle, expand=page.shape[1] > page.shape[0] and abs(angle) == 90)
for page, angle in zip(pages, origin_pages_orientations)
]

@staticmethod
def _generate_crops(
pages: List[np.ndarray],
Expand Down
26 changes: 15 additions & 11 deletions onnxtr/models/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np

from onnxtr.io.elements import Document
from onnxtr.models._utils import estimate_orientation, get_language
from onnxtr.models._utils import get_language
from onnxtr.models.detection.predictor import DetectionPredictor
from onnxtr.models.recognition.predictor import RecognitionPredictor
from onnxtr.utils.geometry import detach_scores, rotate_image
from onnxtr.utils.geometry import detach_scores
from onnxtr.utils.repr import NestedObject

from .base import _OCRPredictor
Expand Down Expand Up @@ -55,7 +55,13 @@ def __init__(
self.det_predictor = det_predictor
self.reco_predictor = reco_predictor
_OCRPredictor.__init__(
self, assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, **kwargs
self,
assume_straight_pages,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
detect_orientation,
**kwargs,
)
self.detect_orientation = detect_orientation
self.detect_language = detect_language
Expand All @@ -80,19 +86,17 @@ def __call__(
for out_map in out_maps
]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
]
else:
orientations = None
general_pages_orientations = None
origin_pages_orientations = None
if self.straighten_pages:
origin_page_orientations = (
origin_page_orientations
if self.detect_orientation
else [estimate_orientation(seq_map) for seq_map in seg_maps]
)
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)

# forward again to get predictions on straight pages
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ changelog = "https://github.com/felixdittrich92/OnnxTR/releases"
zip-safe = true

[tool.setuptools.packages.find]
exclude = ["tests*", "scripts*"]
exclude = ["docs*", "tests*", "scripts*"]

[tool.setuptools.package-data]
doctr = ["py.typed"]

[tool.mypy]
files = "onnxtr/"
Expand Down
21 changes: 14 additions & 7 deletions tests/common/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,35 @@ def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip):

# test binarized image
angle = estimate_orientation(mock_bitmap)
assert abs(angle - 30.0) < 1.0
assert abs(angle) - 30 < 1.0

angle = estimate_orientation(mock_bitmap * 255)
assert abs(angle - 30.0) < 1.0
assert abs(angle) - 30.0 < 1.0

angle = estimate_orientation(mock_image)
assert abs(angle - 30.0) < 1.0
assert abs(angle) - 30.0 < 1.0

rotated = geometry.rotate_image(mock_image, -angle)
rotated = geometry.rotate_image(mock_image, angle)
angle_rotated = estimate_orientation(rotated)
assert abs(angle_rotated) < 1.0
assert abs(angle_rotated) == 0

mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip)
assert (estimate_orientation(mock_tilted_payslip) - 30.0) < 1.0
assert estimate_orientation(mock_tilted_payslip) == -30

rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True)
angle_rotated = estimate_orientation(rotated)
assert abs(angle_rotated) < 1.0

with pytest.raises(AssertionError):
estimate_orientation(np.ones((10, 10, 10)))

# test with general_page_orientation
assert estimate_orientation(mock_bitmap, (90, 0.9)) in range(140, 160)

rotated = geometry.rotate_image(mock_tilted_payslip, -30)
assert estimate_orientation(rotated, (0, 0.9)) in range(-10, 10)

assert estimate_orientation(mock_image, (0, 0.9)) - 30 < 1.0


def test_get_lang():
sentence = "This is a test sentence."
Expand Down
5 changes: 5 additions & 0 deletions tests/common/test_models_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,13 @@ def test_ocrpredictor(mock_pdf, assume_straight_pages, straighten_pages):

if assume_straight_pages:
assert predictor.crop_orientation_predictor is None
if predictor.detect_orientation or predictor.straighten_pages:
assert isinstance(predictor.page_orientation_predictor, NestedObject)
else:
assert predictor.page_orientation_predictor is None
else:
assert isinstance(predictor.crop_orientation_predictor, NestedObject)
assert isinstance(predictor.page_orientation_predictor, NestedObject)

out = predictor(doc)
assert isinstance(out, Document)
Expand Down

0 comments on commit 495151c

Please sign in to comment.