Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Replace LGPL Unidecode with more permissive anyascii #30

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions doctr/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cv2
import numpy as np
from scipy.optimize import linear_sum_assignment
from unidecode import unidecode
from anyascii import anyascii

__all__ = [
"TextMatch",
Expand All @@ -32,16 +32,16 @@ def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:

Returns:
a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
unidecode counterparts and their lower-case unidecode counterparts match
anyascii counterparts and their lower-case anyascii counterparts match
"""
raw_match = word1 == word2
caseless_match = word1.lower() == word2.lower()
unidecode_match = unidecode(word1) == unidecode(word2)
anyascii_match = anyascii(word1) == anyascii(word2)

# Warning: the order is important here otherwise the pair ("EUR", "€") cannot be matched
unicase_match = unidecode(word1).lower() == unidecode(word2).lower()
unicase_match = anyascii(word1).lower() == anyascii(word2).lower()

return raw_match, caseless_match, unidecode_match, unicase_match
return raw_match, caseless_match, anyascii_match, unicase_match


class TextMatch:
Expand Down Expand Up @@ -92,10 +92,10 @@ def update(
raise AssertionError("prediction size does not match with ground-truth labels size")

for gt_word, pred_word in zip(gt, pred):
_raw, _caseless, _unidecode, _unicase = string_match(gt_word, pred_word)
_raw, _caseless, _anyascii, _unicase = string_match(gt_word, pred_word)
self.raw += int(_raw)
self.caseless += int(_caseless)
self.unidecode += int(_unidecode)
self.anyascii += int(_anyascii)
self.unicase += int(_unicase)

self.total += len(gt)
Expand All @@ -104,23 +104,23 @@ def summary(self) -> Dict[str, float]:
"""Computes the aggregated metrics

Returns:
a dictionary with the exact match score for the raw data, its lower-case counterpart, its unidecode
counterpart and its lower-case unidecode counterpart
a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
counterpart and its lower-case anyascii counterpart
"""
if self.total == 0:
raise AssertionError("you need to update the metric before getting the summary")

return dict(
raw=self.raw / self.total,
caseless=self.caseless / self.total,
unidecode=self.unidecode / self.total,
anyascii=self.anyascii / self.total,
unicase=self.unicase / self.total,
)

def reset(self) -> None:
self.raw = 0
self.caseless = 0
self.unidecode = 0
self.anyascii = 0
self.unicase = 0
self.total = 0

Expand Down Expand Up @@ -531,10 +531,10 @@ def update(
is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh
# String comparison
for gt_idx, pred_idx in zip(gt_indices[is_kept], pred_indices[is_kept]):
_raw, _caseless, _unidecode, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
_raw, _caseless, _anyascii, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
self.raw_matches += int(_raw)
self.caseless_matches += int(_caseless)
self.unidecode_matches += int(_unidecode)
self.anyascii_matches += int(_anyascii)
self.unicase_matches += int(_unicase)

self.num_gts += gt_boxes.shape[0]
Expand All @@ -551,15 +551,15 @@ def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]
recall = dict(
raw=self.raw_matches / self.num_gts if self.num_gts > 0 else None,
caseless=self.caseless_matches / self.num_gts if self.num_gts > 0 else None,
unidecode=self.unidecode_matches / self.num_gts if self.num_gts > 0 else None,
anyascii=self.anyascii_matches / self.num_gts if self.num_gts > 0 else None,
unicase=self.unicase_matches / self.num_gts if self.num_gts > 0 else None,
)

# Precision
precision = dict(
raw=self.raw_matches / self.num_preds if self.num_preds > 0 else None,
caseless=self.caseless_matches / self.num_preds if self.num_preds > 0 else None,
unidecode=self.unidecode_matches / self.num_preds if self.num_preds > 0 else None,
anyascii=self.anyascii_matches / self.num_preds if self.num_preds > 0 else None,
unicase=self.unicase_matches / self.num_preds if self.num_preds > 0 else None,
)

Expand All @@ -574,7 +574,7 @@ def reset(self) -> None:
self.tot_iou = 0.0
self.raw_matches = 0
self.caseless_matches = 0
self.unidecode_matches = 0
self.anyascii_matches = 0
self.unicase_matches = 0


Expand Down
6 changes: 3 additions & 3 deletions doctr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from matplotlib.figure import Figure
from PIL import Image, ImageDraw
from unidecode import unidecode
from anyascii import anyascii

from .common_types import BoundingBox, Polygon4P
from .fonts import get_font
Expand Down Expand Up @@ -295,8 +295,8 @@ def synthesize_page(
try:
d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))
except UnicodeEncodeError:
# When character cannot be encoded, use its unidecode version
d.text((0, 0), unidecode(word["value"]), font=font, fill=(0, 0, 0))
# When character cannot be encoded, use its anyascii version
d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0))

# Colorize if draw_proba
if draw_proba:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ dependencies = [
"Pillow>=8.3.2", # cf. https://github.com/advisories/GHSA-98vv-pw6r-q6q4
"defusedxml>=0.7.0",
"mplcursors>=0.3",
"unidecode>=1.0.0",
"anyascii==0.3.2",
"tqdm>=4.30.0",
"rapidfuzz>=1.6.0",
"huggingface-hub>=0.4.0",
"openvino==2022.3.1"
"openvino==2024.1.0"
]

[project.optional-dependencies]
Expand Down
26 changes: 13 additions & 13 deletions tests/common/test_utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@


@pytest.mark.parametrize(
"gt, pred, raw, caseless, unidecode, unicase",
"gt, pred, raw, caseless, anyascii, unicase",
[
[["grass", "56", "True", "EUR"], ["grass", "56", "true", "€"], 0.5, 0.75, 0.75, 1],
[["éléphant", "ça"], ["elephant", "ca"], 0, 0, 1, 1],
],
)
def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
def test_text_match(gt, pred, raw, caseless, anyascii, unicase):
metric = metrics.TextMatch()
with pytest.raises(AssertionError):
metric.summary()
Expand All @@ -20,10 +20,10 @@ def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
metric.update(["a", "b"], ["c"])

metric.update(gt, pred)
assert metric.summary() == dict(raw=raw, caseless=caseless, unidecode=unidecode, unicase=unicase)
assert metric.summary() == dict(raw=raw, caseless=caseless, anyascii=anyascii, unicase=unicase)

metric.reset()
assert metric.raw == metric.caseless == metric.unidecode == metric.unicase == metric.total == 0
assert metric.raw == metric.caseless == metric.anyascii == metric.unicase == metric.total == 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -210,8 +210,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]]],
[["elephant"]],
0.5,
{"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
{"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
1,
],
[ # Bad match
Expand All @@ -220,8 +220,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]]],
[["elephant"]],
0.5,
{"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
{"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
1,
],
[ # Good match
Expand All @@ -230,8 +230,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5], [0.6, 0.6, 0.7, 0.7]]],
[["€", "e"]],
0.2,
{"raw": 0, "caseless": 0, "unidecode": 1, "unicase": 1},
{"raw": 0, "caseless": 0, "unidecode": 0.5, "unicase": 0.5},
{"raw": 0, "caseless": 0, "anyascii": 1, "unicase": 1},
{"raw": 0, "caseless": 0, "anyascii": 0.5, "unicase": 0.5},
0.13,
],
[ # No preds on 2nd sample
Expand All @@ -240,8 +240,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mea
[[[0, 0, 0.5, 0.5]], None],
[["elephant"], []],
0.5,
{"raw": 0, "caseless": 0.5, "unidecode": 0, "unicase": 0.5},
{"raw": 0, "caseless": 1, "unidecode": 0, "unicase": 1},
{"raw": 0, "caseless": 0.5, "anyascii": 0, "unicase": 0.5},
{"raw": 0, "caseless": 1, "anyascii": 0, "unicase": 1},
1,
],
],
Expand All @@ -258,7 +258,7 @@ def test_ocr_metric(gt_boxes, gt_words, pred_boxes, pred_words, iou_thresh, reca
assert _mean_iou == mean_iou
metric.reset()
assert metric.num_gts == metric.num_preds == metric.tot_iou == 0
assert metric.raw_matches == metric.caseless_matches == metric.unidecode_matches == metric.unicase_matches == 0
assert metric.raw_matches == metric.caseless_matches == metric.anyascii_matches == metric.unicase_matches == 0
# Shape check
with pytest.raises(AssertionError):
metric.update(
Expand Down
Loading