diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 786de4c..1381d54 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -34,13 +34,8 @@ jobs: run: | pip install --upgrade pip pip install -e . - pip install ruff pytest + pip install pytest - - name: Run ruff - if: always() - shell: bash - run: python -m ruff prodigy_pdf tests - - name: Run pytest if: always() shell: bash diff --git a/README.md b/README.md index fea6567..2bddf2a 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,16 @@ You can install this plugin via `pip`. pip install "prodigy-pdf @ git+https://github.com/explosion/prodigy-pdf" ``` +If you want to use the OCR recipes, you'll also want to ensure that tesseract is installed. + +```bash +# for mac +brew install tesseract + +# for ubuntu +sudo apt install tesseract-ocr +``` + To learn more about this plugin, you can check the [Prodigy docs](https://prodi.gy/docs/plugins/#pdf). ## Issues? diff --git a/prodigy_pdf/__init__.py b/prodigy_pdf/__init__.py index bee1776..ab66c7d 100644 --- a/prodigy_pdf/__init__.py +++ b/prodigy_pdf/__init__.py @@ -1,13 +1,16 @@ -from typing import List +from typing import List, Dict import base64 from io import BytesIO from pathlib import Path +from PIL import Image +import pytesseract import pypdfium2 as pdfium from prodigy import recipe, set_hashes, ControllerComponentsDict -from prodigy.components.stream import Stream -from prodigy.util import msg +from prodigy.components.stream import Stream, get_stream +from prodigy.util import msg, split_string + def page_to_image(page: pdfium.PdfPage) -> str: """Turns a PdfPage into a base64 image for Prodigy""" @@ -29,7 +32,6 @@ def generate_pdf_pages(pdf_paths: List[Path]): "image": page_to_image(page), "meta": { "page": page_number, - "pdf": pdf_path.parts[-1], "path": str(pdf_path) } }) @@ -67,9 +69,9 @@ def before_db(examples): del eg["image"] return examples - color = ["#ffff00", "#00ffff", "#ff00ff", 
"#00ff7f", "#ff6347", "#00bfff", + color = ["#00ffff", "#ff00ff", "#00ff7f", "#ff6347", "#00bfff", "#ffa500", "#ff69b4", "#7fffd4", "#ffd700", "#ffdab9", "#adff2f", - "#d2b48c", "#dcdcdc"] + "#d2b48c", "#dcdcdc", "#ffff00", ] return { "dataset": dataset, @@ -86,3 +88,108 @@ def before_db(examples): } }, } + + +def page_to_cropped_image(pil_page: Image.Image, span: Dict, scale: int): + left, upper = span['x'], span['y'] + right, lower = left + span['width'], upper + span['height'] + scaled = (left * scale, upper * scale, right * scale, lower * scale) + cropped = pil_page.crop(scaled) + with BytesIO() as buffered: + cropped.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()) + return cropped, f"data:image/jpeg;base64,{img_str.decode('utf-8')}" + + +def fold_ocr_dashes(ocr_input:str) -> str: + """ + OCR might literally add dashes at the end of the line to indicate + continuation of the word. This can be fine in some cases, but this + function can fold it all into a single string. + """ + new = "" + for line in ocr_input.split("\n"): + line = line.strip() + if not line.endswith("-"): + newline = line + " " + else: + newline = line[:-1] + new += newline + return new.strip() + + +def _validate_ocr_example(ex: Dict): + if 'meta' not in ex: + raise ValueError(f"It seems the `meta` key is missing from an example: {ex}. Did you annotate this data with `pdf.image.manual`?") + if 'path' not in ex['meta']: + raise ValueError(f"It seems the `path` key is missing from an example metadata: {ex}. Did you annotate this data with `pdf.image.manual`?") + if 'page' not in ex['meta']: + raise ValueError(f"It seems the `page` key is missing from an example metadata: {ex}. 
Did you annotate this data with `pdf.image.manual`?") + + +@recipe( + "pdf.ocr.correct", + # fmt: off + dataset=("Dataset to save answers to", "positional", None, str), + source=("Source with PDF Annotations", "positional", None, str), + labels=("Labels to consider", "option", "l", split_string), + scale=("Zoom scale. Increase above 3 to upscale the image for OCR.", "option", "s", int), + remove_base64=("Remove base64-encoded image data", "flag", "R", bool), + fold_dashes=("Removes dashes at the end of a textline and folds them with the next term.", "flag", "f", bool), + autofocus=("Autofocus on the transcript UI", "flag", "af", bool) + # fmt: on +) +def pdf_ocr_correct( + dataset: str, + source: str, + labels: List[str], + scale: int = 3, + remove_base64:bool=False, + fold_dashes:bool = False, + autofocus: bool = False +) -> ControllerComponentsDict: + """Applies OCR to annotated segments and gives a textbox for corrections.""" + stream = get_stream(source) + + def new_stream(stream): + for ex in stream: + useful_spans = [span for span in ex.get('spans', []) if span['label'] in labels] + if useful_spans: + _validate_ocr_example(ex) + pdf = pdfium.PdfDocument(ex['meta']['path']) + page = pdf.get_page(ex['meta']['page']) + pil_page = page.render(scale=scale).to_pil() + for annot in useful_spans: + cropped, img_str = page_to_cropped_image(pil_page, span=annot, scale=scale) + annot["image"] = img_str + annot["text"] = pytesseract.image_to_string(cropped) + if fold_dashes: + annot["text"] = fold_ocr_dashes(annot["text"]) + annot["transcription"] = annot["text"] + text_input_fields = { + "field_rows": 12, + "field_label": "Transcript", + "field_id": "transcription", + "field_autofocus": autofocus, + } + annot.pop('id', None) + yield set_hashes({**annot, **text_input_fields}) + + def before_db(examples): + # Remove all data URIs before storing example in the database + for eg in examples: + if eg["image"].startswith("data:"): + del eg["image"] + return examples + + blocks = 
[{"view_id": "classification"}, {"view_id": "text_input"}] + + return { + "dataset": dataset, + "stream": new_stream(stream), + "before_db": before_db if remove_base64 else None, + "view_id": "blocks", + "config": { + "blocks": blocks + }, + } diff --git a/setup.cfg b/setup.cfg index 30e603a..8601795 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] -version = 0.1.0 +version = 0.2.0 description = Recipes for PDF annotation url = https://github.com/explosion/prodigy-pdf author = Explosion @@ -11,6 +11,7 @@ python_requires = >=3.8 install_requires = pypdfium2==4.20.0 Pillow==9.4.0 + pytesseract==0.3.10 [options.entry_points] prodigy_recipes = diff --git a/tests/test_basics.py b/tests/test_basics.py index 79e494e..38aaf76 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -1,5 +1,5 @@ from pathlib import Path -from prodigy_pdf import generate_pdf_pages +from prodigy_pdf import generate_pdf_pages, fold_ocr_dashes def test_smoke_internal(): @@ -9,3 +9,23 @@ def test_smoke_internal(): assert len(pages) == 6 for page in pages: assert "data" in page['image'] + + +def test_fold_dashes(): + going_in = """ + Real-Time Strategy (RTS) games have become an increas- + ingly popular test-bed for modern artificial intelligence tech- + niques. With this rise in popularity has come the creation of + several annual competitions, in which AI agents (bots) play + the full game of StarCraft: Broodwar by Blizzard Entertain- + ment. The three major annual StarCraft AI Competitions are + the Student StarCraft AI Tournament (SSCAIT), the Com- + putational Intelligence in Games (CIG) competition, and the + Artificial Intelligence and Interactive Digital Entertainment + (AIIDE) competition. In this paper we will give an overview + of the current state of these competitions, and the bots that + compete in them. + """ + + expected = "Real-Time Strategy (RTS) games have become an increasingly popular test-bed for modern artificial intelligence techniques. 
With this rise in popularity has come the creation of several annual competitions, in which AI agents (bots) play the full game of StarCraft: Broodwar by Blizzard Entertainment. The three major annual StarCraft AI Competitions are the Student StarCraft AI Tournament (SSCAIT), the Computational Intelligence in Games (CIG) competition, and the Artificial Intelligence and Interactive Digital Entertainment (AIIDE) competition. In this paper we will give an overview of the current state of these competitions, and the bots that compete in them." + assert fold_ocr_dashes(going_in) == expected