diff --git a/neusomatic/python/call.py b/neusomatic/python/call.py index be1b40a..abe053e 100755 --- a/neusomatic/python/call.py +++ b/neusomatic/python/call.py @@ -350,19 +350,15 @@ def pred_vcf_records(ref_file, final_preds, chroms, num_threads): map_args.append([path, final_preds[path], chroms, ref_file]) - if num_threads == 1: - all_vcf_records = [] - for w in map_args: - all_vcf_records.append(pred_vcf_records_path(w)) - else: - pool = multiprocessing.Pool(num_threads) try: - all_vcf_records = pool.map_async( - pred_vcf_records_path, map_args).get() - pool.close() + if num_threads == 1: + all_vcf_records = [pred_vcf_records_path(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + all_vcf_records = pool.map_async( + pred_vcf_records_path, map_args).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception @@ -786,13 +782,14 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads, current_L = 0 candidate_files = [] - pool = multiprocessing.Pool(num_threads) try: - all_records = pool.map_async(single_thread_call, map_args).get() - pool.close() + if num_threads == 1: + all_records = [single_thread_call(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + all_records = pool.map_async(single_thread_call, map_args).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception diff --git a/neusomatic/python/dataloader.py b/neusomatic/python/dataloader.py index 0fe5470..504bd6f 100755 --- a/neusomatic/python/dataloader.py +++ b/neusomatic/python/dataloader.py @@ -203,24 +203,17 @@ def __init__(self, roots, max_load_candidates, transform=None, max_load_, nclasses_t, nclasses_l, self.matrix_dtype]) Ls_.append(self.Ls[i_b]) logger.info("Len's of tsv files in this batch: {}".format(Ls_)) - if len(map_args) == 1: - records_ = [extract_info_tsv(map_args[0])] - else: + try: if num_threads == 1: - records_ = [] - for w in map_args: - records_.append(extract_info_tsv(w)) + records_ = [extract_info_tsv(w) for w in map_args] else: - pool = multiprocessing.Pool(num_threads) - try: + with multiprocessing.Pool(num_threads) as pool: records_ = pool.map_async( extract_info_tsv, map_args).get() - pool.close() - except Exception as inst: - pool.close() - logger.error(inst) - traceback.print_exc() - raise Exception + except Exception as inst: + logger.error(inst) + traceback.print_exc() + raise Exception for o in records_: if o is None: diff --git a/neusomatic/python/extend_features.py b/neusomatic/python/extend_features.py index 5d4b482..54f43d4 100755 --- a/neusomatic/python/extend_features.py +++ b/neusomatic/python/extend_features.py @@ -330,7 +330,6 @@ def extend_features(candidates_vcf, n_variants = len(all_variants) logger.info("Number of variants: {}".format(n_variants)) split_len = (n_variants + num_threads - 1) // num_threads - pool = multiprocessing.Pool(num_threads) map_args = [] nei_cluster = [] batch = [] @@ -382,8 +381,11 @@ def extend_features(candidates_vcf, "tBAM_REF_InDel_1bp", "tBAM_ALT_InDel_3bp", "tBAM_ALT_InDel_2bp", "tBAM_ALT_InDel_1bp", "InDel_Length"]) try: - ext_features = pool.map_async(extract_features, map_args).get() - pool.close() + if num_threads == 1: + ext_features = [extract_features(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + ext_features = pool.map_async(extract_features, map_args).get() with open(output_tsv, "w") as o_f: o_f.write("\t".join(header) + "\n") for features in ext_features: @@ -392,7 +394,6 @@ def extend_features(candidates_vcf, "\t".join(map(lambda x: str(x).replace("nan", "0"), w)) + "\n") except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception diff --git a/neusomatic/python/generate_dataset.py b/neusomatic/python/generate_dataset.py index 4a9e93d..d37193a 100755 --- a/neusomatic/python/generate_dataset.py +++ b/neusomatic/python/generate_dataset.py @@ -2103,13 +2103,11 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be for w in map_args: records_data.append(find_records(w)) else: - pool = multiprocessing.Pool(num_threads) try: - records_data = pool.map_async(find_records, map_args).get() - pool.close() + with multiprocessing.Pool(num_threads) as pool: + records_data = pool.map_async(find_records, map_args).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception for o in records_data: @@ -2184,16 +2182,14 @@ def generate_dataset(work, truth_vcf_file, mode, tumor_pred_vcf_file, region_be if num_threads == 1: records_done_ = [parallel_generation([map_args_records, matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed])] else: - pool = multiprocessing.Pool(num_threads) try: split_len=max(1,len_records//num_threads) - records_done_ = pool.map_async( - parallel_generation, [[map_args_records[i_split:i_split+(split_len)],matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed] - for i_split in range(0, len_records, split_len)]).get() - pool.close() + with multiprocessing.Pool(num_threads) as pool: + records_done_ = pool.map_async( + parallel_generation, [[map_args_records[i_split:i_split+(split_len)],matrix_base_pad, chrom_lengths, tumor_count_bed, normal_count_bed] + for i_split in range(0, len_records, split_len)]).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception diff --git a/neusomatic/python/long_read_indelrealign.py b/neusomatic/python/long_read_indelrealign.py index 53c1635..e5358aa 100755 --- a/neusomatic/python/long_read_indelrealign.py +++ b/neusomatic/python/long_read_indelrealign.py @@ -783,7 +783,6 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be num_threads): logger = logging.getLogger(parallel_correct_bam.__name__) if num_threads > 1: - pool = multiprocessing.Pool(num_threads) bam_header = output_bam[:-4] + ".header" with open(bam_header, "w") as h_f: h_f.write(pysam.view("-H", input_bam,)) @@ -795,10 +794,9 @@ def parallel_correct_bam(work, input_bam, output_bam, ref_fasta_file, realign_be (work, input_bam, realign_bed_file, ref_fasta_file, chrom)) try: - sams = pool.map_async(correct_bam_chrom, map_args).get() - pool.close() + with multiprocessing.Pool(num_threads) as pool: + sams = pool.map_async(correct_bam_chrom, map_args).get() except Exception as inst: - pool.close() logger.error(inst) traceback.print_exc() raise Exception @@ -1356,7 +1354,6 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r map(lambda x: [x[0], int(x[1]), int(x[2])], target_regions)) get_var = True if output_vcf else False - pool = multiprocessing.Pool(num_threads) map_args = [] for target_region in target_regions: map_args.append((work, ref_fasta_file, target_region, pad, chunk_size, @@ -1368,10 +1365,12 @@ def long_read_indelrealign(work, input_bam, output_bam, output_vcf, output_not_r shuffle(map_args) try: - realign_output = pool.map_async(run_realignment, map_args).get() - pool.close() + if num_threads == 1: + realign_output = [run_realignment(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + realign_output = pool.map_async(run_realignment, map_args).get() except Exception as inst: - pool.close() logger.error(inst) traceback.print_exc() raise Exception diff --git a/neusomatic/python/preprocess.py b/neusomatic/python/preprocess.py index bc5aa9a..be3b974 100755 --- a/neusomatic/python/preprocess.py +++ b/neusomatic/python/preprocess.py @@ -50,7 +50,6 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, if filtered_candidates_vcf: logger.info("Filter candidates.") if restart or not os.path.exists(filtered_candidates_vcf): - pool = multiprocessing.Pool(num_threads) map_args = [] for i, (raw_vcf, count_bed, split_region_bed) in enumerate(scan_outputs): filtered_vcf = os.path.join(os.path.dirname( @@ -59,12 +58,14 @@ def process_split_region(tn, work, region, reference, mode, alignment_bam, min_ao, snp_min_af, snp_min_bq, snp_min_ao, ins_min_af, del_min_af, del_merge_min_af, ins_merge_min_af, merge_r)) try: - filtered_candidates_vcfs = pool.map_async( - filter_candidates, map_args).get() - pool.close() + if num_threads == 1: + filtered_candidates_vcfs = [filter_candidates(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + filtered_candidates_vcfs = pool.map_async( + filter_candidates, map_args).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception @@ -141,13 +142,14 @@ def get_ensemble_beds(work, reference, ensemble_bed, split_regions, matrix_base_ ensemble_beds.append(ensemble_bed_region_file) map_args.append((reference, ensemble_bed, split_region_, ensemble_bed_region_file, matrix_base_pad)) - pool = multiprocessing.Pool(num_threads) try: - outputs = pool.map_async(get_ensemble_region, map_args).get() - pool.close() + if num_threads == 1: + outputs = [get_ensemble_region(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + outputs = pool.map_async(get_ensemble_region, map_args).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception for o in outputs: @@ -647,14 +649,15 @@ def preprocess(work, mode, reference, region_bed, tumor_bam, normal_bam, dbsnp, strict_labeling, tsv_batch_size]) - pool = multiprocessing.Pool(num_threads) try: - done_gen = pool.map_async( - generate_dataset_region_parallel, map_args_gen).get() - pool.close() + if num_threads == 1: + done_gen = [generate_dataset_region_parallel(w) for w in map_args_gen] + else: + with multiprocessing.Pool(num_threads) as pool: + done_gen = pool.map_async( + generate_dataset_region_parallel, map_args_gen).get() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception diff --git a/neusomatic/python/resolve_variants.py b/neusomatic/python/resolve_variants.py index 4e7c009..3cf1902 100755 --- a/neusomatic/python/resolve_variants.py +++ b/neusomatic/python/resolve_variants.py @@ -478,16 +478,14 @@ def resolve_variants(input_bam, resolved_vcf, reference, target_vcf_file, out_variants_list = [] i = 0 while i < len(map_args): - pool = multiprocessing.Pool(num_threads) batch_i_s = i batch_i_e = min(i + n_per_bacth, len(map_args)) - out_variants_list.extend(pool.map_async( - find_resolved_variants, map_args[batch_i_s:batch_i_e]).get()) + with multiprocessing.Pool(num_threads) as pool: + out_variants_list.extend(pool.map_async( + find_resolved_variants, map_args[batch_i_s:batch_i_e]).get()) i = batch_i_e - pool.close() except Exception as inst: logger.error(inst) - pool.close() traceback.print_exc() raise Exception else: diff --git a/neusomatic/python/scan_alignments.py b/neusomatic/python/scan_alignments.py index 3e7b56a..21e7210 100755 --- a/neusomatic/python/scan_alignments.py +++ b/neusomatic/python/scan_alignments.py @@ -198,12 +198,13 @@ def scan_alignments(work, merge_d_for_scan, scan_alignments_binary, input_bam, i), "count.bed.gz"), os.path.join(work, "work.{}".format(i), "region.bed")] - pool = multiprocessing.Pool(num_threads) try: - outputs = pool.map_async(run_scan_alignments, map_args).get() - pool.close() + if num_threads == 1: + outputs = [run_scan_alignments(w) for w in map_args] + else: + with multiprocessing.Pool(num_threads) as pool: + outputs = pool.map_async(run_scan_alignments, map_args).get() except Exception as inst: - pool.close() logger.error(inst) traceback.print_exc() raise Exception