Skipping multiprocess pool for single thread scenario #74

Open · wants to merge 1 commit into base: `accelerate_preprocess_new`
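This PR applies one refactor across eight modules: when `num_threads == 1`, the worker function is called directly over `map_args` in a list comprehension, skipping `multiprocessing.Pool` entirely; where a pool is still needed, it is managed with a `with` statement instead of manual `pool.close()` calls. A minimal sketch of the pattern, using a hypothetical worker `square` rather than any function from this repo:

```python
import multiprocessing


def square(x):
    return x * x


def run_all(xs, num_threads):
    if num_threads == 1:
        # No pool at all: avoids process start-up overhead and keeps
        # tracebacks and debuggers usable in the single-threaded case.
        return [square(x) for x in xs]
    # Pool.__exit__() calls terminate(), which is safe here because the
    # blocking .get() has already collected every result.
    with multiprocessing.Pool(num_threads) as pool:
        return pool.map_async(square, xs).get()


if __name__ == "__main__":
    print(run_all(list(range(8)), num_threads=2))
```

`Pool` has supported the context-management protocol since Python 3.3, so the `with` form also removes the duplicated `pool.close()` calls that previously appeared in both the success and the exception paths.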
25 changes: 11 additions & 14 deletions neusomatic/python/call.py
```diff
@@ -350,19 +350,15 @@ def pred_vcf_records(ref_file, final_preds, chroms, num_threads):
         map_args.append([path, final_preds[path],
                          chroms, ref_file])
 
-    if num_threads == 1:
-        all_vcf_records = []
-        for w in map_args:
-            all_vcf_records.append(pred_vcf_records_path(w))
-    else:
-        pool = multiprocessing.Pool(num_threads)
     try:
-        all_vcf_records = pool.map_async(
-            pred_vcf_records_path, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            all_vcf_records = [pred_vcf_records_path(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                all_vcf_records = pool.map_async(
+                    pred_vcf_records_path, map_args).get()
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
@@ -786,13 +782,14 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads,
     current_L = 0
     candidate_files = []
 
-    pool = multiprocessing.Pool(num_threads)
     try:
-        all_records = pool.map_async(single_thread_call, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            all_records = [single_thread_call(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                all_records = pool.map_async(single_thread_call, map_args).get()
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
```
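A side effect of the `with` form, visible in both hunks above: the `except` branches no longer call `pool.close()`, because the context manager tears the pool down even while an exception is propagating. A hypothetical standalone demo of that behavior (not code from the PR):

```python
import multiprocessing


def might_fail(x):
    if x == 3:
        raise ValueError("bad record")
    return x


if __name__ == "__main__":
    try:
        # .get() re-raises the worker's exception in the parent process;
        # Pool.__exit__() still runs terminate() on the way out.
        with multiprocessing.Pool(2) as pool:
            results = pool.map_async(might_fail, range(5)).get()
    except Exception as inst:
        print("worker failed:", inst)  # no manual pool.close() needed
```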
21 changes: 7 additions & 14 deletions neusomatic/python/dataloader.py
```diff
@@ -203,24 +203,17 @@ def __init__(self, roots, max_load_candidates, transform=None,
                 max_load_, nclasses_t, nclasses_l, self.matrix_dtype])
             Ls_.append(self.Ls[i_b])
         logger.info("Len's of tsv files in this batch: {}".format(Ls_))
-        if len(map_args) == 1:
-            records_ = [extract_info_tsv(map_args[0])]
-        else:
+        try:
             if num_threads == 1:
-                records_ = []
-                for w in map_args:
-                    records_.append(extract_info_tsv(w))
+                records_ = [extract_info_tsv(w) for w in map_args]
             else:
-                pool = multiprocessing.Pool(num_threads)
-                try:
+                with multiprocessing.Pool(num_threads) as pool:
                     records_ = pool.map_async(
                         extract_info_tsv, map_args).get()
-                    pool.close()
-                except Exception as inst:
-                    pool.close()
-                    logger.error(inst)
-                    traceback.print_exc()
-                    raise Exception
+        except Exception as inst:
+            logger.error(inst)
+            traceback.print_exc()
+            raise Exception
 
         for o in records_:
             if o is None:
```
9 changes: 5 additions & 4 deletions neusomatic/python/extend_features.py
```diff
@@ -330,7 +330,6 @@ def extend_features(candidates_vcf,
     n_variants = len(all_variants)
     logger.info("Number of variants: {}".format(n_variants))
     split_len = (n_variants + num_threads - 1) // num_threads
-    pool = multiprocessing.Pool(num_threads)
     map_args = []
     nei_cluster = []
     batch = []
@@ -382,8 +381,11 @@
             "tBAM_REF_InDel_1bp", "tBAM_ALT_InDel_3bp", "tBAM_ALT_InDel_2bp", "tBAM_ALT_InDel_1bp", "InDel_Length"])
 
     try:
-        ext_features = pool.map_async(extract_features, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            ext_features = [extract_features(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                ext_features = pool.map_async(extract_features, map_args).get()
         with open(output_tsv, "w") as o_f:
             o_f.write("\t".join(header) + "\n")
             for features in ext_features:
@@ -392,7 +394,6 @@
                     "\t".join(map(lambda x: str(x).replace("nan", "0"), w)) + "\n")
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
 
```
16 changes: 6 additions & 10 deletions neusomatic/python/generate_dataset.py
```diff
@@ -2103,13 +2103,11 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be
         for w in map_args:
             records_data.append(find_records(w))
     else:
-        pool = multiprocessing.Pool(num_threads)
         try:
-            records_data = pool.map_async(find_records, map_args).get()
-            pool.close()
+            with multiprocessing.Pool(num_threads) as pool:
+                records_data = pool.map_async(find_records, map_args).get()
         except Exception as inst:
             logger.error(inst)
-            pool.close()
             traceback.print_exc()
             raise Exception
     for o in records_data:
@@ -2184,16 +2182,14 @@
     if num_threads == 1:
         records_done_ = [parallel_generation([map_args_records, matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed])]
     else:
-        pool = multiprocessing.Pool(num_threads)
         try:
             split_len=max(1,len_records//num_threads)
-            records_done_ = pool.map_async(
-                parallel_generation, [[map_args_records[i_split:i_split+(split_len)],matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed]
-                                      for i_split in range(0, len_records, split_len)]).get()
+            with multiprocessing.Pool(num_threads) as pool:
+                records_done_ = pool.map_async(
+                    parallel_generation, [[map_args_records[i_split:i_split+(split_len)],matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed]
+                                          for i_split in range(0, len_records, split_len)]).get()
         except Exception as inst:
             logger.error(inst)
-            pool.close()
             traceback.print_exc()
             raise Exception
```

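The second hunk keeps the base branch's chunking scheme: `map_args_records` is cut into contiguous slices of `split_len` records so each worker processes a chunk rather than a single record. A standalone illustration of the slicing idiom, with a hypothetical `records` list:

```python
records = list(range(10))
num_threads = 3

# One slice of split_len records per task; integer division means the
# remainder spills into one extra, shorter chunk (here 4 chunks: 3, 3, 3, 1).
split_len = max(1, len(records) // num_threads)
chunks = [records[i:i + split_len] for i in range(0, len(records), split_len)]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```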
15 changes: 7 additions & 8 deletions neusomatic/python/long_read_indelrealign.py
```diff
@@ -783,7 +783,6 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be
                          num_threads):
     logger = logging.getLogger(parallel_correct_bam.__name__)
     if num_threads > 1:
-        pool = multiprocessing.Pool(num_threads)
         bam_header = output_bam[:-4] + ".header"
         with open(bam_header, "w") as h_f:
             h_f.write(pysam.view("-H", input_bam,))
@@ -795,10 +794,9 @@
                 (work, input_bam, realign_bed_file, ref_fasta_file, chrom))
 
         try:
-            sams = pool.map_async(correct_bam_chrom, map_args).get()
-            pool.close()
+            with multiprocessing.Pool(num_threads) as pool:
+                sams = pool.map_async(correct_bam_chrom, map_args).get()
         except Exception as inst:
-            pool.close()
             logger.error(inst)
             traceback.print_exc()
             raise Exception
@@ -1356,7 +1354,6 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r
         map(lambda x: [x[0], int(x[1]), int(x[2])], target_regions))
 
     get_var = True if output_vcf else False
-    pool = multiprocessing.Pool(num_threads)
     map_args = []
     for target_region in target_regions:
         map_args.append((work, ref_fasta_file, target_region, pad, chunk_size,
@@ -1368,10 +1365,12 @@
 
     shuffle(map_args)
     try:
-        realign_output = pool.map_async(run_realignment, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            realign_output = [run_realignment(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                realign_output = pool.map_async(run_realignment, map_args).get()
     except Exception as inst:
-        pool.close()
         logger.error(inst)
         traceback.print_exc()
         raise Exception
```
31 changes: 17 additions & 14 deletions neusomatic/python/preprocess.py
```diff
@@ -50,7 +50,6 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam,
     if filtered_candidates_vcf:
         logger.info("Filter candidates.")
         if restart or not os.path.exists(filtered_candidates_vcf):
-            pool = multiprocessing.Pool(num_threads)
             map_args = []
             for i, (raw_vcf, count_bed, split_region_bed) in enumerate(scan_outputs):
                 filtered_vcf = os.path.join(os.path.dirname(
@@ -59,12 +58,14 @@
                                   min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, del_merge_min_af,
                                   ins_merge_min_af, merge_r))
             try:
-                filtered_candidates_vcfs = pool.map_async(
-                    filter_candidates, map_args).get()
-                pool.close()
+                if num_threads == 1:
+                    filtered_candidates_vcfs = [filter_candidates(w) for w in map_args]
+                else:
+                    with multiprocessing.Pool(num_threads) as pool:
+                        filtered_candidates_vcfs = pool.map_async(
+                            filter_candidates, map_args).get()
             except Exception as inst:
                 logger.error(inst)
-                pool.close()
                 traceback.print_exc()
                 raise Exception
 
@@ -141,13 +142,14 @@ def get_ensemble_beds(work, reference, ensemble_bed, split_regions, matrix_base_
         ensemble_beds.append(ensemble_bed_region_file)
         map_args.append((reference, ensemble_bed, split_region_,
                          ensemble_bed_region_file, matrix_base_pad))
-    pool = multiprocessing.Pool(num_threads)
     try:
-        outputs = pool.map_async(get_ensemble_region, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            outputs = [get_ensemble_region(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                outputs = pool.map_async(get_ensemble_region, map_args).get()
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
     for o in outputs:
@@ -647,14 +649,15 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp,
                              strict_labeling,
                              tsv_batch_size])
 
-    pool = multiprocessing.Pool(num_threads)
     try:
-        done_gen = pool.map_async(
-            generate_dataset_region_parallel, map_args_gen).get()
-        pool.close()
+        if num_threads == 1:
+            done_gen = [generate_dataset_region_parallel(w) for w in map_args_gen]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                done_gen = pool.map_async(
+                    generate_dataset_region_parallel, map_args_gen).get()
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
 
```
8 changes: 3 additions & 5 deletions neusomatic/python/resolve_variants.py
```diff
@@ -478,16 +478,14 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file,
         out_variants_list = []
         i = 0
         while i < len(map_args):
-            pool = multiprocessing.Pool(num_threads)
             batch_i_s = i
             batch_i_e = min(i + n_per_bacth, len(map_args))
-            out_variants_list.extend(pool.map_async(
-                find_resolved_variants, map_args[batch_i_s:batch_i_e]).get())
+            with multiprocessing.Pool(num_threads) as pool:
+                out_variants_list.extend(pool.map_async(
+                    find_resolved_variants, map_args[batch_i_s:batch_i_e]).get())
             i = batch_i_e
-            pool.close()
     except Exception as inst:
         logger.error(inst)
-        pool.close()
         traceback.print_exc()
         raise Exception
     else:
```
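resolve_variants.py is the one call site that pools inside a loop: each batch of up to `n_per_bacth` arguments gets a fresh pool, and the `with` form guarantees a batch's workers are cleaned up before the next iteration starts. A self-contained sketch of that shape, with a hypothetical `work` function:

```python
import multiprocessing


def work(x):
    return x * 2


if __name__ == "__main__":
    map_args = list(range(10))
    n_per_batch = 4
    out = []
    i = 0
    while i < len(map_args):
        batch_end = min(i + n_per_batch, len(map_args))
        # A fresh pool per batch bounds worker lifetime and memory use.
        with multiprocessing.Pool(2) as pool:
            out.extend(pool.map_async(work, map_args[i:batch_end]).get())
        i = batch_end
    print(out)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```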
9 changes: 5 additions & 4 deletions neusomatic/python/scan_alignments.py
```diff
@@ -198,12 +198,13 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam,
                                  i), "count.bed.gz"),
                     os.path.join(work, "work.{}".format(i), "region.bed")]
 
-    pool = multiprocessing.Pool(num_threads)
     try:
-        outputs = pool.map_async(run_scan_alignments, map_args).get()
-        pool.close()
+        if num_threads == 1:
+            outputs = [run_scan_alignments(w) for w in map_args]
+        else:
+            with multiprocessing.Pool(num_threads) as pool:
+                outputs = pool.map_async(run_scan_alignments, map_args).get()
     except Exception as inst:
-        pool.close()
         logger.error(inst)
         traceback.print_exc()
         raise Exception
```