diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py index 5fbfc96..6b79cfc 100644 --- a/onnxtr/models/predictor/base.py +++ b/onnxtr/models/predictor/base.py @@ -9,7 +9,7 @@ from onnxtr.models.builder import DocumentBuilder from onnxtr.models.engine import EngineConfig -from onnxtr.utils.geometry import extract_crops, extract_rcrops, rotate_image +from onnxtr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds from ..classification import crop_orientation_predictor, page_orientation_predictor @@ -112,8 +112,8 @@ def _straighten_pages( ] ) return [ - # expand if height and width are not equal - rotate_image(page, angle, expand=page.shape[0] != page.shape[1]) + # expand if height and width are not equal, afterwards remove padding + remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1])) for page, angle in zip(pages, origin_pages_orientations) ] diff --git a/onnxtr/utils/geometry.py b/onnxtr/utils/geometry.py index 8b1f2ed..0c84e29 100644 --- a/onnxtr/utils/geometry.py +++ b/onnxtr/utils/geometry.py @@ -391,6 +391,26 @@ def rotate_image( return rot_img +def remove_image_padding(image: np.ndarray) -> np.ndarray: + """Remove black border padding from an image + + Args: + ---- + image: numpy tensor to remove padding from + + Returns: + ------- + Image with padding removed + """ + # Find the bounding box of the non-black region + rows = np.any(image, axis=1) + cols = np.any(image, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return image[rmin : rmax + 1, cmin : cmax + 1] + + def estimate_page_angle(polys: np.ndarray) -> float: """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the estimated angle ccw in degrees diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py index d78df06..eaa3e2a 100644 --- a/tests/common/test_utils_geometry.py +++ b/tests/common/test_utils_geometry.py @@ -167,6 +167,17 @@ def test_rotate_image(): assert rotated[0, :, 0].sum() <= 1 +def test_remove_image_padding(): + img = np.ones((32, 64, 3), dtype=np.float32) + padded = np.pad(img, ((10, 10), (20, 20), (0, 0))) + cropped = geometry.remove_image_padding(padded) + assert np.all(cropped == img) + + # No padding + cropped = geometry.remove_image_padding(img) + assert np.all(cropped == img) + + @pytest.mark.parametrize( "abs_geoms, img_size, rel_geoms", [