Skip to content

Commit

Permalink
fix: format and lint
Browse files (browse the repository at this point in the history)
  • Loading branch information
dhdaines committed Jul 16, 2024
1 parent 5cf98a1 commit 25e789e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
3 changes: 2 additions & 1 deletion alexi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .label import Identificateur
from .search import search
from .segment import DEFAULT_MODEL as DEFAULT_SEGMENT_MODEL
-from .segment import Segmenteur, RNNSegmenteur
+from .segment import RNNSegmenteur, Segmenteur

LOGGER = logging.getLogger("alexi")
VERSION = "0.4.0"
Expand Down Expand Up @@ -59,6 +59,7 @@ def convert_main(args: argparse.Namespace):

def segment_main(args: argparse.Namespace):
"""Segmenter un CSV"""
+crf: Segmenteur
if args.model.suffix == ".pt":
crf = RNNSegmenteur(args.model)
else:
Expand Down
4 changes: 3 additions & 1 deletion alexi/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from alexi.label import Identificateur
from alexi.link import Resolver
from alexi.segment import DEFAULT_MODEL as DEFAULT_SEGMENT_MODEL
-from alexi.segment import DEFAULT_MODEL_NOSTRUCT, Segmenteur, RNNSegmenteur
+from alexi.segment import DEFAULT_MODEL_NOSTRUCT, RNNSegmenteur, Segmenteur
from alexi.types import T_obj

LOGGER = logging.getLogger("extract")
Expand Down Expand Up @@ -329,6 +329,8 @@ def make_doc_tree(docs: list[Document], outdir: Path) -> dict[str, dict[str, str


class Extracteur:
+crf: Segmenteur
+
def __init__(
self,
outdir: Path,
Expand Down
26 changes: 13 additions & 13 deletions alexi/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,16 @@

import joblib # type: ignore
import torch
+from allennlp_light.modules.conditional_random_field import (
+ConditionalRandomFieldWeightTrans,
+)
from torch import nn
from torch.nn.utils.rnn import (
PackedSequence,
pack_padded_sequence,
pad_packed_sequence,
pad_sequence,
)
-from allennlp_light.modules.conditional_random_field import (
-ConditionalRandomFieldWeightEmission,
-ConditionalRandomFieldWeightTrans,
-ConditionalRandomFieldWeightLannoy,
-)

from alexi.convert import FIELDNAMES
from alexi.format import line_breaks
Expand Down Expand Up @@ -366,12 +364,12 @@ def make_rnn_features(
page: Iterable[T_obj],
features: str = "text+layout+structure",
labels: str = "literal",
-):
-features = list(
+) -> tuple[list[T_obj], list[str]]:
+rnn_features = list(
dict((name, val) for name, _, val in (w.partition("=") for w in feats))
-for feats in page2features(page, features)
+for feats in page2features(list(page), features)
)
-for f, w in zip(features, page):
+for f, w in zip(rnn_features, page):
f["line:left"] = float(f["line:left"]) / float(w["page_width"])
f["line:top"] = float(f["line:top"]) / float(w["page_height"])
f["v:top"] = float(w["top"]) / float(w["page_height"])
Expand All @@ -384,9 +382,9 @@ def make_rnn_features(
w["page_height"]
)

-add_deltas(features)
-labels = list(page2labels(page, labels))
-return features, labels
+add_deltas(rnn_features)
+rnn_labels = list(page2labels(page, labels))
+return rnn_features, rnn_labels


FEATNAMES = [
Expand Down Expand Up @@ -769,7 +767,9 @@ def __call__(self, words: Iterable[dict[str, Any]]) -> Iterable[dict[str, Any]]:
yield word


-class RNNSegmenteur:
+class RNNSegmenteur(Segmenteur):
+model: RNN
+
def __init__(self, model: PathLike = DEFAULT_RNN_MODEL, device="cpu"):
model = Path(model)
self.device = torch.device(device)
Expand Down

0 comments on commit 25e789e

Please sign in to comment.