update rec: we forgot to exclude the mfd part
veya2ztn committed Sep 19, 2024
1 parent bb18105 commit fc5adc0
Showing 22 changed files with 888 additions and 136 deletions.
4 changes: 2 additions & 2 deletions batch_running_task/batch_run.sh
@@ -10,8 +10,8 @@ do

#sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 run.sh sci_index_files.addon.filelist $(($CPU+$START)) $TOTALNUM
#sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 run_mfr.sh physics_collection/sci_index_files.remain.filelist 0 1
#sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 run_rec.sh physics_collection/sci_index_files.remain.filelist $(($CPU+$START)) $TOTALNUM
sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 batch_running_task/task_layout/run_layout_for_missing_page.sh scihub_collection/analysis/not_complete_pdf_page_id.pairlist.filelist 0 1
sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 run_rec.sh physics_collection/physics.files.final.filelist $(($CPU+$START)) $TOTALNUM
#sbatch --quotatype=spot -p AI4Chem -N1 -c8 --gres=gpu:1 batch_running_task/task_layout/run_layout_for_missing_page.sh physics_collection/analysis/not_complete_pdf_page_id.pairlist.remain.filelist $(($CPU+$START)) $TOTALNUM
## sleep 20s after every 10 job starts
if [ $(($CPU % 10)) -eq 9 ]; then
sleep 20
157 changes: 155 additions & 2 deletions batch_running_task/get_data_utils.py
@@ -89,7 +89,13 @@ def write_jsonl_to_path(data, path, client):
    if thedir:
        os.makedirs(thedir, exist_ok=True)
    with open(path, 'w') as f:
        for d in data:
            try:
                byte_object = json.dumps(d)
            except Exception:
                raise NotImplementedError(f"fail to dump {d}")
            f.write(byte_object + "\n")  # one JSON object per line (JSONL)
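For reference, a minimal round-trip sketch of the JSONL convention this helper writes (illustrative local file, no storage client involved):

import json

records = [{"track_id": "abc", "page_id": 0}, {"track_id": "def", "page_id": 3}]
with open("demo.jsonl", "w") as f:          # hypothetical local path
    for r in records:
        f.write(json.dumps(r) + "\n")       # one JSON object per line
with open("demo.jsonl") as f:
    assert [json.loads(line) for line in f] == records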


import boto3
@@ -347,4 +353,151 @@ def get_page_num_map_whole():

    for result in results:
        page_num_map_whole.update(result)
    return page_num_map_whole

output_width  = 1472   # fixed target page width  (previously pdf_metadata['width'])
output_height = 1920   # fixed target page height (previously pdf_metadata['height'])
import sys
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))

from batch_running_task.utils import convert_boxes
def build_dict(pdf_metadata_list, track_id_key="track_id"):
    pdf_metadata_dict = {}
    for pdf_metadata in pdf_metadata_list:
        track_id = pdf_metadata[track_id_key]
        height = pdf_metadata.get('height', 1920)
        width  = pdf_metadata.get('width', 1472)
        if height != output_height or width != output_width:
            ### rescale every bbox into the canonical resolution
            for pdf_page_metadata in pdf_metadata['doc_layout_result']:
                for res in pdf_page_metadata["layout_dets"]:
                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
                    bbox = [xmin, ymin, xmax, ymax]
                    bbox = convert_boxes([bbox], width, height, output_width, output_height)[0]
                    res['poly'] = [bbox[0], bbox[1], bbox[2], bbox[1],
                                   bbox[2], bbox[3], bbox[0], bbox[3]]
        pdf_metadata_dict[track_id] = {pdf_page_metadata['page_id']: pdf_page_metadata
                                       for pdf_page_metadata in pdf_metadata['doc_layout_result']}

    return pdf_metadata_dict
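convert_boxes is imported from batch_running_task/utils.py and its body is not part of this diff; a minimal sketch of the linear rescaling it appears to perform (hypothetical re-implementation, for illustration only):

from typing import List

def convert_boxes_sketch(boxes: List[List[float]], src_w: float, src_h: float,
                         dst_w: float, dst_h: float) -> List[List[float]]:
    # rescale [xmin, ymin, xmax, ymax] boxes from (src_w, src_h) to (dst_w, dst_h)
    sx, sy = dst_w / src_w, dst_h / src_h
    return [[x0 * sx, y0 * sy, x1 * sx, y1 * sy] for x0, y0, x1, y1 in boxes]

# a box detected on a 736x960 render maps onto the 1472x1920 canvas:
assert convert_boxes_sketch([[10, 20, 30, 40]], 736, 960, 1472, 1920) == [[20, 40, 60, 80]]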

def read_data_with_patch(result_path, client):
    if result_path.startswith("s3:"):
        result_path = "opendata:" + result_path
    filename = os.path.basename(result_path)
    patch_path  = os.path.join(os.path.dirname(os.path.dirname(result_path)), "det_patch_good", filename)
    missingpath = os.path.join(os.path.dirname(os.path.dirname(result_path)), "fix_missing_page_version2", filename)
    # mfr/rec patches are not merged here; read_data_with_mfr below handles them

    assert check_path_exists(result_path, client)
    result = read_json_from_path(result_path, client)
    result_dict = build_dict(result)

    patch_add_dict = build_dict(read_json_from_path(patch_path, client)) if check_path_exists(patch_path, client) else {}
    missing_dict   = build_dict(read_json_from_path(missingpath, client)) if check_path_exists(missingpath, client) else {}

    if patch_add_dict or missing_dict:
        for track_id, pdf_metadata in result_dict.items():
            for patch_dict in [patch_add_dict, missing_dict]:
                if track_id in patch_dict:
                    patch_pdf_metadata = patch_dict[track_id]
                    for page_id, pdf_page_metadata in patch_pdf_metadata.items():
                        if page_id in pdf_metadata:
                            ## merge the page results
                            pdf_metadata[page_id]["layout_dets"].extend(pdf_page_metadata["layout_dets"])
                        else:
                            pdf_metadata[page_id] = pdf_page_metadata
    for pdf_metadata in result:
        track_id = pdf_metadata['track_id']
        pdf_metadata['height'] = output_height
        pdf_metadata['width']  = output_width
        pdf_metadata['doc_layout_result'] = list(result_dict[track_id].values())
    return result
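A sketch of how this merger might be invoked; the bucket layout below is illustrative, not taken from the repo:

client = build_client()
# layout results are assumed to sit next to their patch folders, e.g.
#   .../scihub_collection/layoutV6/part-00001.jsonl             <- result_path
#   .../scihub_collection/det_patch_good/part-00001.jsonl
#   .../scihub_collection/fix_missing_page_version2/part-00001.jsonl
merged = read_data_with_patch("s3://bucket/scihub_collection/layoutV6/part-00001.jsonl", client)
print(merged[0]['track_id'], len(merged[0]['doc_layout_result']), "pages")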

def read_data_with_mfr(result_path, client):
    if result_path.startswith("s3:"):
        result_path = "opendata:" + result_path

    filename = os.path.basename(result_path)
    mfr_patchpath      = os.path.join(os.path.dirname(os.path.dirname(result_path)), "mfr_patch", filename)
    mfr_patch_bf16path = os.path.join(os.path.dirname(os.path.dirname(result_path)), "mfr_patch_bf16", filename)
    rec_patchpath      = os.path.join(os.path.dirname(os.path.dirname(result_path)), "rec_patch", filename)

    assert check_path_exists(result_path, client)
    result = read_json_from_path(result_path, client)

    mfr_patch_dict      = build_dict(read_json_from_path(mfr_patchpath, client), track_id_key='path') if check_path_exists(mfr_patchpath, client) else {}
    mfr_patch_bf16_dict = build_dict(read_json_from_path(mfr_patch_bf16path, client), track_id_key='path') if check_path_exists(mfr_patch_bf16path, client) else {}
    for pdf_metadata in tqdm(result, desc="adding patch and missing", leave=False, position=3):
        track_id = pdf_metadata['path']
        if track_id in mfr_patch_bf16_dict:      # prefer the bf16 rerun when both exist
            current_mfr_patch = mfr_patch_bf16_dict[track_id]
        elif track_id in mfr_patch_dict:
            current_mfr_patch = mfr_patch_dict[track_id]
        else:
            continue
        for pdf_page_metadata in pdf_metadata['doc_layout_result']:
            page_id = pdf_page_metadata['page_id']
            bbox_count = 0
            for bbox_metadata in pdf_page_metadata['layout_dets']:
                if bbox_metadata['category_id'] not in [13, 14]: continue  # 13/14: formula boxes (MFD)
                bbox_count += 1
            if bbox_count == 0: continue
            patch_mfr_list = current_mfr_patch[page_id]["layout_dets"]
            assert len(patch_mfr_list) == bbox_count, f"pdf={track_id} page={page_id} => bbox count {bbox_count} not equal to patch count {len(patch_mfr_list)}"
            bbox_id = 0
            for bbox_metadata in pdf_page_metadata['layout_dets']:
                if bbox_metadata['category_id'] not in [13, 14]: continue
                bbox_metadata.update(patch_mfr_list[bbox_id])  # patches are stored in traversal order
                bbox_id += 1

    rec_patch_dict = build_dict(read_json_from_path(rec_patchpath, client), track_id_key='path') if check_path_exists(rec_patchpath, client) else {}
    for pdf_metadata in tqdm(result, desc="[REC] adding patch and missing", leave=False, position=3):
        track_id = pdf_metadata['path']
        if track_id not in rec_patch_dict:
            continue
        current_rec_patch = rec_patch_dict[track_id]
        for pdf_page_metadata in pdf_metadata['doc_layout_result']:
            page_id = pdf_page_metadata['page_id']
            bbox_count = 0
            for bbox_metadata in pdf_page_metadata['layout_dets']:
                if bbox_metadata['category_id'] != 15: continue  # 15: text-line boxes handled by the rec model
                bbox_count += 1
            if bbox_count == 0: continue
            patch_rec_list = current_rec_patch[page_id]["layout_dets"]
            assert len(patch_rec_list) == bbox_count, f"pdf={track_id} page={page_id} => bbox count {bbox_count} not equal to patch count {len(patch_rec_list)}"
            bbox_id = 0
            for bbox_metadata in pdf_page_metadata['layout_dets']:
                if bbox_metadata['category_id'] != 15: continue
                bbox_metadata.update(patch_rec_list[bbox_id])
                bbox_id += 1
    return result
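Both merge loops rely on an ordering invariant (checked by the asserts above): each patch page stores exactly one entry per matching box, in the same traversal order as layout_dets. A toy illustration with hypothetical data:

page = {"page_id": 0, "layout_dets": [
    {"category_id": 1,  "poly": [0] * 8},   # non-formula box: left untouched
    {"category_id": 14, "poly": [0] * 8},   # formula box #0
    {"category_id": 13, "poly": [0] * 8},   # formula box #1
]}
patch = {"layout_dets": [{"latex": "E=mc^2"}, {"latex": "a^2+b^2=c^2"}]}

bbox_id = 0
for det in page["layout_dets"]:
    if det["category_id"] not in [13, 14]:
        continue
    det.update(patch["layout_dets"][bbox_id])   # positional alignment
    bbox_id += 1
assert page["layout_dets"][1]["latex"] == "E=mc^2"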
98 changes: 95 additions & 3 deletions batch_running_task/scihub_pdf_dataset.py
@@ -1,10 +1,35 @@

from get_data_utils import *
from utils import collect_mfdetrec_res_per_page, formula_in_text
from torch.utils.data import IterableDataset,get_worker_info,DataLoader, Dataset
from utils import Timers,convert_boxes
import torch
from utils import collect_paragraph_image_and_its_coordinate

def update_det_boxes(dt_boxes, mfdetrec_res):
    """Split OCR text boxes that overlap detected formula (MFD) boxes, so the
    rec model never sees the formula region; each hit may leave a left part
    and/or a right part of the original text box."""
    new_dt_boxes = dt_boxes
    for mf_box in mfdetrec_res:
        if 'bbox' in mf_box:
            bbox = mf_box['bbox']
        elif 'poly' in mf_box:
            xmin, ymin = int(mf_box['poly'][0]), int(mf_box['poly'][1])
            xmax, ymax = int(mf_box['poly'][4]), int(mf_box['poly'][5])
            bbox = [xmin, ymin, xmax, ymax]
        else:
            raise NotImplementedError("mf_box should have bbox or poly")
        for idx, text_box in enumerate(new_dt_boxes):
            ret, left_box, right_box = formula_in_text(bbox, text_box)
            if ret:
                new_dt_boxes.pop(idx)
                if left_box is not None:
                    new_dt_boxes.append(left_box)
                if right_box is not None:
                    new_dt_boxes.append(right_box)
                break

    return new_dt_boxes
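formula_in_text is imported from utils and not shown in this diff; a minimal sketch of the behavior update_det_boxes relies on, assuming axis-aligned boxes and a horizontal split (hypothetical re-implementation):

import numpy as np

def formula_in_text_sketch(mf_bbox, text_box, overlap_thresh=0.6):
    # mf_bbox:  [xmin, ymin, xmax, ymax] of one formula
    # text_box: 4x2 corner array, clockwise from top-left
    # returns (hit, left_box, right_box); either split part may be None
    fx0, fy0, fx1, fy1 = mf_bbox
    tx0, ty0 = text_box[0]
    tx1, ty1 = text_box[2]
    inter = max(0.0, min(fy1, ty1) - max(fy0, ty0))   # vertical overlap
    if inter / max(ty1 - ty0, 1e-6) < overlap_thresh or fx1 < tx0 or fx0 > tx1:
        return False, None, None
    left_box = right_box = None
    if fx0 > tx0:   # text remains to the left of the formula
        left_box = np.array([[tx0, ty0], [fx0, ty0], [fx0, ty1], [tx0, ty1]], dtype='float32')
    if fx1 < tx1:   # text remains to the right of the formula
        right_box = np.array([[fx1, ty0], [tx1, ty0], [tx1, ty1], [fx1, ty1]], dtype='float32')
    return True, left_box, right_box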

def clean_pdf_path(pdf_path):
return pdf_path[len("opendata:"):] if pdf_path.startswith("opendata:") else pdf_path

@@ -146,6 +171,8 @@ def __next__(self):
raise StopIteration
return output



class RecImageDataset(Dataset, DatasetUtils,ImageTransformersUtils):
error_count=0
def __init__(self, metadata_filepath,
@@ -162,7 +189,67 @@ def __len__(self):

    def __getitem__(self, index):
        pdf_metadata = self.metadata[index]
        return self.get_cropped_image_list_via_remove_mfd_part(pdf_metadata, self.client)

    @staticmethod
    def collect_location_and_dt_box_from_page_metadata(pdf_path, pdf_page_metadata):
        location_keys = []
        page_id = pdf_page_metadata['page_id']
        mfd_res_list = collect_mfdetrec_res_per_page(pdf_page_metadata['layout_dets'])  # List[Dict] like [{'bbox': [a, b, c, d]}, ...]
        for bbox_metadata in pdf_page_metadata['layout_dets']:
            if bbox_metadata['category_id'] != 15: continue  # 15: text-line boxes
            bbox_id = tuple(bbox_metadata['poly'])
            tmp_box = np.array(bbox_metadata['poly']).reshape(-1, 2)
            tmp_box = sorted_boxes(tmp_box[None])[0].astype('float32')
            dt_boxes = [tmp_box]
            if mfd_res_list:
                dt_boxes = update_det_boxes(dt_boxes, mfd_res_list)
            if len(dt_boxes) == 1 and bbox_metadata.get('text', "") != "":
                ## not split by any formula and OCR text already present => nothing to redo, skip
                continue

            for dt_box in dt_boxes:
                ### flatten the 4x2 corners of each (possibly split) box into an 8-tuple key
                sub_box_id = (dt_box[0][0], dt_box[0][1], dt_box[1][0], dt_box[1][1],
                              dt_box[2][0], dt_box[2][1], dt_box[3][0], dt_box[3][1])
                location = (clean_pdf_path(pdf_path), page_id, bbox_id, sub_box_id)
                location_keys.append(location)
        return location_keys
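Each returned key addresses one cropped text fragment; a hypothetical example of its shape:

location = (
    "s3://bucket/pdfs/example.pdf",                              # clean_pdf_path(pdf_path), illustrative
    3,                                                           # page_id
    (100, 200, 400, 200, 400, 230, 100, 230),                    # bbox_id: the original text poly
    (100.0, 200.0, 250.0, 200.0, 250.0, 230.0, 100.0, 230.0),    # sub_box_id: e.g. the left split
)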


    def get_cropped_image_list_via_remove_mfd_part(self, pdf_metadata, client):
        images_pool = {}
        pdf_path = pdf_metadata['path']
        height   = pdf_metadata['height']
        width    = pdf_metadata['width']

        if pdf_path.startswith('s3'):
            pdf_path = "opendata:" + pdf_path
        try:
            with read_pdf_from_path(pdf_path, client) as pdf:
                for pdf_page_metadata in pdf_metadata['doc_layout_result']:
                    page_id = pdf_page_metadata['page_id']
                    page = pdf.load_page(page_id)
                    ori_im = process_pdf_page_to_image(page, 200, output_width=width, output_height=height)
                    location_keys = self.collect_location_and_dt_box_from_page_metadata(pdf_path, pdf_page_metadata)
                    for location in location_keys:
                        _, _, _, sub_box_id = location
                        dt_box = np.array(sub_box_id).reshape(-1, 2)
                        img_crop = get_rotate_crop_image(ori_im, dt_box, padding=10)
                        images_pool[location] = img_crop
            return (pdf_path, images_pool)
        except KeyboardInterrupt:
            raise
        except Exception:
            traceback.print_exc()
            tqdm.write(f"[Error]: {pdf_path}")
            raise
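get_rotate_crop_image follows the usual PaddleOCR-style perspective crop; a minimal sketch, assuming the padding argument simply pushes the quad corners outward before warping (hypothetical re-implementation; the repo's version may clip differently):

import cv2
import numpy as np

def get_rotate_crop_image_sketch(img, points, padding=0):
    pts = points.astype('float32').copy()
    if padding:
        h, w = img.shape[:2]
        center = pts.mean(axis=0)
        pts += np.sign(pts - center) * padding        # expand the quad
        pts[:, 0] = pts[:, 0].clip(0, w - 1)
        pts[:, 1] = pts[:, 1].clip(0, h - 1)
    crop_w = int(max(np.linalg.norm(pts[0] - pts[1]), np.linalg.norm(pts[2] - pts[3])))
    crop_h = int(max(np.linalg.norm(pts[0] - pts[3]), np.linalg.norm(pts[1] - pts[2])))
    dst = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
    M = cv2.getPerspectiveTransform(pts, dst)
    crop = cv2.warpPerspective(img, M, (crop_w, crop_h),
                               borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC)
    if crop_w > 0 and crop_h / crop_w >= 1.5:
        crop = np.rot90(crop)                          # make tall crops horizontal
    return crop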

class DetImageDataset(Dataset, DatasetUtils,ImageTransformersUtils):
error_count=0
@@ -407,8 +494,13 @@ def build_pdf_id_and_page_id_pair(self,metadata_filepath,pdf_id_and_page_id_pair
        metadata = self.smart_read_json(metadata_filepath)
        metadata = np.array_split(metadata, partion_num)[partion_idx]
        self.metadata = metadata

        track_id_to_pdf_id = {metadata[i]['track_id']: i for i in range(len(metadata))}
        self.pdf_id_and_page_id_pair = []
        for pdf_id, page_id in pdf_id_and_page_id_pair:
            if isinstance(pdf_id, str):  # accept a track_id string as well as an integer row index
                pdf_id = track_id_to_pdf_id[pdf_id]
            self.pdf_id_and_page_id_pair.append((pdf_id, page_id))


class AddonDataset(Dataset,DatasetUtils,ImageTransformersUtils):
error_count = 0
2 changes: 2 additions & 0 deletions batch_running_task/task_layout/batch_deal_with_layout.py
@@ -24,6 +24,8 @@ class BatchLayoutConfig(BatchModeConfig):
    accelerated_mfd: bool = False
    async_mode: bool = False
    result_save_path: str = RESULT_SAVE_PATH
    use_lock: bool = True
    debug: bool = False
    def from_dict(kargs):
        return BatchLayoutConfig(**kargs)
    def to_dict(self):
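A hypothetical sketch of the two new switches (assuming BatchModeConfig supplies defaults for its remaining fields):

config = BatchLayoutConfig.from_dict({
    "use_lock": False,   # bypass the cross-job lock files entirely
    "debug": True,       # abort before model loading, to exercise only the path/lock logic
})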
@@ -25,8 +25,8 @@
device = model_configs['model_args']['device']
dpi = model_configs['model_args']['pdf_dpi']

version = "fix_missing_page_version2"
layout_model = None
mfd_model = None
client = None
@@ -42,16 +42,17 @@
    if os.path.exists(CURRENT_END_SIGN):
        break
    filename = os.path.basename(inputs_path)
    #assert "layoutV" in inputs_path
    result_save_root = os.path.join(os.path.dirname(os.path.dirname(inputs_path)), version)
    #inputs_path = os.path.join(INPUT_LOAD_PATH,filename)

    if inputs_path.startswith('s3'):
        inputs_path = "opendata:" + inputs_path
    # assert inputs_path.startswith('opendata:s3')
    # assert result_path.startswith('opendata:s3')
    if client is None:
        client = build_client()

    if not check_path_exists(inputs_path, client):
        tqdm.write(f"[Skip]: no {inputs_path} ")
        continue
@@ -90,23 +91,23 @@


    result_path = os.path.join(result_save_root, filename_with_partion)

    if args.use_lock:
        lock_path = os.path.join(LOCKSERVER, "checklocktime", filename_with_partion)
        last_start_time = check_lock_and_last_start_time(lock_path, client)
        if last_start_time and not args.redo:
            date_format = "%Y-%m-%d %H:%M:%S"
            date = datetime.strptime(last_start_time, date_format)
            deltatime = datetime.now() - date
            if deltatime < timedelta(hours=0.1):
                tqdm.write(f"[Skip]: {filename_with_partion} is locked (last start {last_start_time}, {deltatime} ago)")
                continue
        create_last_start_time_lock(os.path.join(LOCKSERVER, "createlocktime", filename_with_partion), client)

print(f"now we deal with {inputs_path} to {result_path}")
os.makedirs(os.path.dirname(result_path), exist_ok=True)

if args.debug:raise
if layout_model is None:layout_model = get_layout_model(model_configs,args.accelerated_layout)
if mfd_model is None:mfd_model = get_batch_YOLO_model(model_configs,batch_size=args.inner_batch_size,use_tensorRT=args.accelerated_mfd)
if ocrmodel is None:ocrmodel = ModifiedPaddleOCR(show_log=True)
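check_lock_and_last_start_time and create_last_start_time_lock are not shown in this diff; a minimal sketch of the timestamp-file protocol they appear to implement (local-filesystem stand-in; the real helpers go through the storage client):

import os
from datetime import datetime

DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

def create_last_start_time_lock_sketch(lock_path):
    # record that this partition was claimed just now
    os.makedirs(os.path.dirname(lock_path), exist_ok=True)
    with open(lock_path, "w") as f:
        f.write(datetime.now().strftime(DATE_FORMAT))

def check_lock_and_last_start_time_sketch(lock_path):
    # return the claim timestamp string, or None if nobody holds the lock
    if not os.path.exists(lock_path):
        return None
    with open(lock_path) as f:
        return f.read().strip()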