Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove pybedtools dependency #57

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
26 changes: 13 additions & 13 deletions neusomatic/python/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dataloader import NeuSomaticDataset, matrix_transform
from utils import get_chromosomes_order, prob2phred
from merge_tsvs import merge_tsvs
from defaults import VARTYPE_CLASSES

import torch._utils
try:
Expand All @@ -51,10 +52,10 @@ def get_type(ref, alt):
return "SNP"


def call_variants(net, vartype_classes, call_loader, out_dir, model_tag, use_cuda):
def call_variants(net, call_loader, out_dir, model_tag, use_cuda):
logger = logging.getLogger(call_variants.__name__)
net.eval()
nclasses = len(vartype_classes)
nclasses = len(VARTYPE_CLASSES)
final_preds = {}
none_preds = {}
true_path = {}
Expand Down Expand Up @@ -85,15 +86,15 @@ def call_variants(net, vartype_classes, call_loader, out_dir, model_tag, use_cud
preds = {}
for i, path_ in enumerate(paths[0]):
path = path_.split("/")[-1]
preds[i] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i]]
if vartype_classes[predicted[i]] != "NONE":
preds[i] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i]]
if VARTYPE_CLASSES[predicted[i]] != "NONE":
file_name = "{}/matrices_{}/{}.png".format(
out_dir, model_tag, path)
if not os.path.exists(file_name):
imwrite(file_name, np.array(
non_transformed_matrices[i, :, :, 0:3]))
true_path[path] = file_name
final_preds[path] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i],
final_preds[path] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i],
list(map(lambda x: round(x, 4), F.softmax(
outputs1[i, :], 0).data.cpu().numpy())),
list(map(lambda x: round(x, 4), F.softmax(
Expand All @@ -103,7 +104,7 @@ def call_variants(net, vartype_classes, call_loader, out_dir, model_tag, use_cud
list(map(lambda x: round(x, 4),
outputs3.data.cpu()[i].numpy()))]
else:
none_preds[path] = [vartype_classes[predicted[i]], pos_pred[i], len_pred[i],
none_preds[path] = [VARTYPE_CLASSES[predicted[i]], pos_pred[i], len_pred[i],
list(map(lambda x: round(x, 4), F.softmax(
outputs1[i, :], 0).data.cpu().numpy())),
list(map(lambda x: round(x, 4), F.softmax(
Expand All @@ -119,7 +120,7 @@ def call_variants(net, vartype_classes, call_loader, out_dir, model_tag, use_cud


def pred_vcf_records_path(record):
path, true_path_, pred_all, chroms, vartype_classes, ref_file = record
path, true_path_, pred_all, chroms, ref_file = record
thread_logger = logging.getLogger(
"{} ({})".format(pred_vcf_records_path.__name__, multiprocessing.current_process().name))
try:
Expand Down Expand Up @@ -154,7 +155,7 @@ def pred_vcf_records_path(record):
if sum(pred_probs) < min_acceptable_probmax:
break
amx_prob = np.argmax(pred_probs)
type_pred = vartype_classes[amx_prob]
type_pred = VARTYPE_CLASSES[amx_prob]
if type_pred == "NONE":
break
center_pred = min(max(0, pred[1][0]), Iw - 1)
Expand Down Expand Up @@ -305,14 +306,14 @@ def pred_vcf_records_path(record):
return None


def pred_vcf_records(ref_file, final_preds, true_path, chroms, vartype_classes, num_threads):
def pred_vcf_records(ref_file, final_preds, true_path, chroms, num_threads):
logger = logging.getLogger(pred_vcf_records.__name__)
logger.info(
"Prepare VCF records for predicted somatic variants in this batch.")
map_args = []
for path in final_preds.keys():
map_args.append([path, true_path[path], final_preds[path],
chroms, vartype_classes, ref_file])
chroms, ref_file])

pool = multiprocessing.Pool(num_threads)
try:
Expand Down Expand Up @@ -404,7 +405,6 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads,
with pysam.FastaFile(ref_file) as rf:
chroms = rf.references

vartype_classes = ['DEL', 'INS', 'NONE', 'SNP']
data_transform = matrix_transform((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
num_channels = 119 if ensemble else 26
net = NeuSomaticNet(num_channels)
Expand Down Expand Up @@ -527,9 +527,9 @@ def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads,
continue

final_preds_, none_preds_, true_path_ = call_variants(
net, vartype_classes, call_loader, out_dir, model_tag, use_cuda)
net, call_loader, out_dir, model_tag, use_cuda)
all_vcf_records.extend(pred_vcf_records(
ref_file, final_preds_, true_path_, chroms, vartype_classes, num_threads))
ref_file, final_preds_, true_path_, chroms, num_threads))
all_vcf_records_none.extend(
pred_vcf_records_none(none_preds_, chroms))

Expand Down
15 changes: 8 additions & 7 deletions neusomatic/python/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import torch
import resource

from utils import skip_empty
from defaults import TYPE_CLASS_DICT, VARTYPE_CLASSES

FORMAT = '%(levelname)s %(asctime)-15s %(name)-20s %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

type_class_dict = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3}


class matrix_transform():

Expand Down Expand Up @@ -52,7 +53,7 @@ def candidate_loader_tsv(tsv, open_tsv, idx, i):
anns = list(map(float, fields[4:]))
else:
anns = []
label = type_class_dict[tag.split(".")[4]]
label = TYPE_CLASS_DICT[tag.split(".")[4]]
if not open_tsv:
i_f.close()
return tag, im, anns, label
Expand All @@ -65,7 +66,7 @@ def extract_info_tsv(record):
try:
n_none = 0
with open(tsv, "r") as i_f:
for line in i_f:
for line in skip_empty(i_f):
tag = line.strip().split()[2]
n_none += (1 if "NONE" in tag else 0)
n_var = L - n_none
Expand All @@ -84,7 +85,7 @@ def extract_info_tsv(record):
cnt_var = 0
with open(tsv, "r") as i_f:
i = -1
for i, line in enumerate(i_f):
for i, line in enumerate(skip_empty(i_f)):
fields = line.strip().split()
ii = int(fields[0])
assert ii == i
Expand All @@ -96,12 +97,12 @@ def extract_info_tsv(record):
var_ids.append(j)
j += 1
_, _, _, _, vartype, _, length, _, _ = tag.split(".")
count_class_t[type_class_dict[vartype]] += 1
count_class_t[TYPE_CLASS_DICT[vartype]] += 1
count_class_l[min(int(length), 3)] += 1
if ((cnt_var < max_load_candidates_var) and ("NONE" not in tag)) or (
(cnt_none < max_load_candidates_none) and ("NONE" in tag)):
im = extract_zlib(base64.b64decode(fields[3]))
label = type_class_dict[tag.split(".")[4]]
label = TYPE_CLASS_DICT[tag.split(".")[4]]
if len(fields) > 4:
anns = list(map(float, fields[4:]))
else:
Expand Down
4 changes: 4 additions & 0 deletions neusomatic/python/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Shared constants for the NeuSomatic pipeline modules."""

# Extra feature channels used in ensemble mode
# (presumably 119 ensemble channels - 26 base channels — TODO confirm against caller).
NUM_ENS_FEATURES = 93
# First header line written to every output VCF file.
VCF_HEADER = "##fileformat=VCFv4.2"
# Variant-type tag -> integer class label used as the network target.
TYPE_CLASS_DICT = {"DEL": 0, "INS": 1, "NONE": 2, "SNP": 3}
# Inverse mapping: class index -> variant-type tag (index i matches
# TYPE_CLASS_DICT values, so VARTYPE_CLASSES[TYPE_CLASS_DICT[t]] == t).
VARTYPE_CLASSES = ['DEL', 'INS', 'NONE', 'SNP']
8 changes: 5 additions & 3 deletions neusomatic/python/extract_postprocess_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import traceback
import logging

from utils import skip_empty
from defaults import VCF_HEADER

def extract_postprocess_targets(input_vcf, min_len, max_dist, pad):
logger = logging.getLogger(extract_postprocess_targets.__name__)
Expand All @@ -22,10 +24,10 @@ def extract_postprocess_targets(input_vcf, min_len, max_dist, pad):
record_sets = []
record_set = []
with open(input_vcf) as i_f, open(out_vcf, "w") as o_f, open(redo_vcf, "w") as r_f, open(redo_bed, "w") as r_b:
r_f.write("##fileformat=VCFv4.2\n")
r_f.write("{}\n".format(VCF_HEADER))
r_f.write("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n")
for line in i_f:
if len(line) < 2 or line[0] == '#':
for line in skip_empty(i_f):
if len(line) < 2:
continue

chrom, pos, _, ref, alt, _, _, _, _, gt = line.strip().split()
Expand Down
61 changes: 37 additions & 24 deletions neusomatic/python/filter_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import multiprocessing

import pysam
import pybedtools
import numpy as np

from utils import safe_read_info_dict
from utils import safe_read_info_dict, run_bedtools_cmd, vcf_2_bed, write_tsv_file, bedtools_sort, get_tmp_file, skip_empty
from defaults import VCF_HEADER


def filter_candidates(candidate_record):
Expand All @@ -27,9 +27,7 @@ def filter_candidates(candidate_record):

records = {}
with open(candidates_vcf) as v_f:
for line in v_f:
if line[0] == "#":
continue
for line in skip_empty(v_f):
if len(line.strip().split()) != 10:
raise RuntimeError(
"Bad VCF line (<10 fields): {}".format(line))
Expand Down Expand Up @@ -259,29 +257,44 @@ def filter_candidates(candidate_record):
final_records.append([chrom, pos - 1, ref, alt, line])
final_records = sorted(final_records, key=lambda x: x[0:2])
if dbsnp:
filtered_bed = pybedtools.BedTool(map(lambda x:
pybedtools.Interval(x[1][0], int(x[1][1]),
int(x[1][
1]) + 1,
x[1][2], x[1][3], str(x[0])),
enumerate(final_records))).sort()
dbsnp = pybedtools.BedTool(dbsnp).each(lambda x:
pybedtools.Interval(x[0], int(x[1]),
int(x[
1]) + 1,
x[3], x[4])).sort()
non_in_dbsnp_1 = filtered_bed.window(dbsnp, w=0, v=True)
non_in_dbsnp_2 = filtered_bed.window(dbsnp, w=0).filter(
lambda x: x[1] != x[7] or x[3] != x[9] or x[4] != x[10]).sort()
filtered_bed = get_tmp_file()
intervals = []
for x in enumerate(final_records):
intervals.append([x[1][0], int(x[1][1]), int(
x[1][1]) + 1, x[1][2], x[1][3], str(x[0])])
write_tsv_file(filtered_bed, intervals)
filtered_bed = bedtools_sort(
filtered_bed, run_logger=thread_logger)

dbsnp_tmp = get_tmp_file()
vcf_2_bed(dbsnp, dbsnp_tmp)
bedtools_sort(dbsnp_tmp, output_fn=dbsnp, run_logger=thread_logger)
non_in_dbsnp_1 = bedtools_window(
filtered_bed, dbsnp, args=" -w 0 -v", run_logger=thread_logger)
non_in_dbsnp_2 = bedtools_window(
filtered_bed, dbsnp, args=" -w 0", run_logger=thread_logger)

tmp_ = get_tmp_file()
with open(non_in_dbsnp_2) as i_f, open(tmp_, "w") as o_f:
for line in skip_empty(i_f):
x = line.strip().split()
if x[1]!=x[7] or x[3]!=x[9] or x[4]!=x[10]:
o_f.write(line)
non_in_dbsnp_2 = tmp_

non_in_dbsnp_ids = []
for x in non_in_dbsnp_1:
non_in_dbsnp_ids.append(int(x[5]))
for x in non_in_dbsnp_2:
non_in_dbsnp_ids.append(int(x[5]))
with open(non_in_dbsnp_1) as i_f:
for line in skip_empty(i_f):
x = line.strip().split("\t")
non_in_dbsnp_ids.append(int(x[5]))
with open(non_in_dbsnp_2) as i_f:
for line in skip_empty(i_f):
x = line.strip().split("\t")
non_in_dbsnp_ids.append(int(x[5]))
final_records = list(map(lambda x: x[1], filter(
lambda x: x[0] in non_in_dbsnp_ids, enumerate(final_records))))
with open(filtered_candidates_vcf, "w") as o_f:
o_f.write("##fileformat=VCFv4.2\n")
o_f.write("{}\n".format(VCF_HEADER))
o_f.write(
"#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n")
for record in final_records:
Expand Down
Loading