# chop_correction.py
import argparse
import os
import glob
import json
import multiprocessing as mp
from others.crop import find_best_match, count_w
def load_json(file_path):
    """Read ``file_path`` as UTF-8 and return the decoded JSON object.

    If the file's contents are not valid JSON, an error message is printed
    and ``None`` is returned instead of raising. Other errors (for example a
    missing file) are not caught here and propagate to the caller.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            return json.load(handle)
    except json.JSONDecodeError:
        print(f"Error: '{file_path}' is not a valid JSON")
    return None
def process_json(args, js_path):
    """Post-process one paragraph JSON file and save the result under ``args.o``.

    For every paragraph whose ``"status"`` is ``"corrected"``, the
    ``"content"`` and ``"content_"`` values are passed to ``find_best_match``.
    The returned match and score are stored on the paragraph as
    ``"content_c"`` and ``"halu_score"``. Diagnostics are printed when
    ``count_w`` (presumably a word count -- confirm in ``others.crop``) of the
    match is smaller than that of ``"content_"``, or when ``"content"``
    exceeds the match by more than 6.

    The (possibly modified) list is written to ``args.o`` under the same
    base name as the input. A file that ``load_json`` could not parse is
    skipped and produces no output.

    Args:
        args: Parsed command-line namespace; only ``args.o`` (the output
            directory) is read here.
        js_path: Path of the input JSON file.
    """
    para_list = load_json(js_path)
    if para_list is None:
        # load_json has already reported the parse error. Without this guard
        # the loop below would raise TypeError when iterating over None.
        return

    for para in para_list:
        if para["status"] != "corrected":
            continue

        para_text = para["content"]
        para_llm = para["content_"]
        best_match, h_score = find_best_match(para_text, para_llm)
        para["content_c"] = best_match
        para["halu_score"] = h_score

        match_count = count_w(best_match)
        if match_count < count_w(para_llm):
            print("cutoff:", js_path)
            print("before:", para_llm)
            print("after_:", best_match)
        if count_w(para_text) > match_count + 6:
            # The original "content" is noticeably longer than the chopped match.
            print("OCR longer ", js_path)

    # exist_ok=True already tolerates an existing directory; a separate
    # os.path.exists() check is redundant and racy when several pool workers
    # reach this point at the same time.
    os.makedirs(args.o, exist_ok=True)
    save_target = os.path.join(args.o, os.path.basename(js_path))
    with open(save_target, "w", encoding="utf-8") as file:
        json.dump(para_list, file, indent=4, ensure_ascii=False)
# Intended to be run on one book's directory at a time, not on all books at once.
def main():
    """Parse the command line and process every ``*.json`` file in parallel.

    Positional arguments:
        json_dir: Directory that is searched (non-recursively) for
            ``*.json`` input files.
        o: Directory where the processed JSON files are written.

    Each matching file is handed to ``process_json`` in a worker process.
    """
    parser = argparse.ArgumentParser(description="postprocessing after matching")
    parser.add_argument("json_dir", help="dir to load input json files")
    parser.add_argument("o", help="where to out/save json files")
    args = parser.parse_args()

    json_paths = glob.glob(os.path.join(args.json_dir, "*.json"))

    # The pool uses four fewer workers than there are CPUs (presumably to
    # leave headroom for other work). Clamp to at least one worker: on a
    # machine with four or fewer CPUs the subtraction yields a value < 1,
    # which mp.Pool rejects with ValueError.
    worker_count = max(1, mp.cpu_count() - 4)

    # Use multiprocessing to process JSON files in parallel
    with mp.Pool(processes=worker_count) as pool:
        pool.starmap(
            process_json,
            [(args, js_path) for js_path in json_paths],
        )
# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()