Merge pull request #4 from Clinical-Genomics-Lund/add-validation-notebook

Add tb and saureus validation notebooks
ryanjameskennedy authored Jul 29, 2024
2 parents ba0d57a + 0ecec81 commit 851ab88
Showing 10 changed files with 956 additions and 35 deletions.
18 changes: 16 additions & 2 deletions jasentool/cli.py
@@ -43,7 +43,7 @@ def __csv_file(group, required, help):

 def __sh_file(group, required, help):
     """Add sh_file argument to group"""
-    group.add_argument('--sh_file', required=required, help=help)
+    group.add_argument('--sh_file', required=required, default=None, help=help)

 def __bam_file(group, required):
     """Add bam_file argument to group"""
@@ -164,6 +164,11 @@ def __combined_output(group):
     group.add_argument('--combined_output', dest='combined_output', action='store_true',
                        help='combine all of the outputs into one output')

+def __generate_matrix(group):
+    """Add generate_matrix argument to group"""
+    group.add_argument('--generate_matrix', dest='generate_matrix', action='store_true',
+                       help='generate cgmlst matrix')
+
 def __save_dbs(group):
     """Save all intermediary dbs created for TBProfiler db convergence"""
     group.add_argument('--save_dbs', dest='save_dbs', action='store_true',
@@ -174,6 +179,12 @@ def __sample_sheet(group, required):
     group.add_argument('--sample_sheet', required=required, dest='sample_sheet',
                        action='store_true', help='sample sheet input')

+def __alter_sample_id(group, required):
+    """Add alter_sample_id argument to group"""
+    group.add_argument('--alter_sample_id', required=required,
+                       dest='alter_sample_id', action='store_true', default=False,
+                       help='alter sample id to be lims ID + sequencing run')
+
 def __cpus(group):
     """Add cpus argument to group"""
     group.add_argument('--cpus', dest='cpus', type=int, default=2, help='input cpus')
@@ -224,6 +235,7 @@ def get_main_parser():
             __db_collection(group, required=True)
         with arg_group(parser, 'optional arguments') as group:
             __combined_output(group)
+            __generate_matrix(group)
             __uri(group)
             __prefix(group)
             __help(group)
@@ -241,6 +253,7 @@ def get_main_parser():
             __assay(group, required=False)
             __platform(group, required=False)
             __sample_sheet(group, required=False)
+            __alter_sample_id(group, required=False)
             __help(group)

     with subparser(sub_parsers, 'convert', 'Convert file format') as parser:
@@ -255,13 +268,14 @@
     with subparser(sub_parsers, 'fix', 'Fix bjorn microbiology csv file') as parser:
         with arg_group(parser, 'required named arguments') as group:
             __csv_file(group, required=True, help='path to bjorn csv file')
-            __sh_file(group, required=True, help='path to bjorn sh file')
             __output_file(group, required=True, help='path to fixed output csv file')
         with arg_group(parser, 'optional arguments') as group:
+            __sh_file(group, required=False, help='path to bjorn sh file')
             __remote_dir(group, required=False)
             __remote_hostname(group, required=False)
             __remote(group, required=False)
             __auto_start(group, required=False)
+            __alter_sample_id(group, required=False)
             __help(group)

     with subparser(sub_parsers, 'converge', 'Converge TB mutation catalogues') as parser:
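For illustration, the reworked `fix` subcommand can now be run without a bjorn sh file, since `--sh_file` moved to the optional group. A hypothetical invocation (the entry-point name and paths are assumptions, not shown in this diff):

    jasentool fix --csv_file /path/to/bjorn.csv --output_file /path/to/fixed.csv --alter_sample_id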
7 changes: 4 additions & 3 deletions jasentool/fix.py
@@ -7,15 +7,16 @@
 class Fix:
     """Class that fixes csvs for start_nextflow_analysis.pl"""
     @staticmethod
-    def fix_csv(input_file, output_fpath):
+    def fix_csv(input_file, output_fpath, alter_sample_id):
         """Convert the provided bjorn csvs into new jasen-compatible csvs"""
         assays = []
         out_fpaths = []
         with open(input_file, 'r', encoding="utf-8") as csvfile:
             samples = pd.read_csv(csvfile)
             samples.insert(2, 'sample_name', samples['id'])
-            samples['id'] = samples['id'].str.lower() + "_" + samples['sequencing_run'].str.lower()
-            samples['assay'] = samples['species']
+            samples['id'] = samples['clarity_sample_id'].str.lower() + "_" + samples['sequencing_run'].str.lower() if alter_sample_id else samples['id']
+            if "species" in samples.columns:
+                samples['assay'] = samples['species']
             for assay, df_assay in samples.groupby('assay'):
                 out_fpath = f'{os.path.splitext(output_fpath)[0]}_{assay}.csv'
                 df_assay.to_csv(out_fpath, encoding='utf-8', index=False)
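A minimal sketch of what the new `alter_sample_id` switch does to the id column, assuming a bjorn csv with `id`, `clarity_sample_id` and `sequencing_run` columns (sample values below are made up):

    import pandas as pd

    samples = pd.DataFrame({"id": ["S1"], "clarity_sample_id": ["ACC123"], "sequencing_run": ["RUN01"]})
    alter_sample_id = True
    # With the flag set, the id becomes the lowercased LIMS id joined to the sequencing run
    samples['id'] = samples['clarity_sample_id'].str.lower() + "_" + samples['sequencing_run'].str.lower() if alter_sample_id else samples['id']
    print(samples['id'][0])  # acc123_run01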
14 changes: 7 additions & 7 deletions jasentool/main.py
@@ -84,8 +84,8 @@ def validate(self, options):
         output_fpaths = self._get_output_fpaths(input_files, options.output_dir,
                                                 options.output_file, options.prefix,
                                                 options.combined_output)
-        validate = Validate()
-        validate.run(input_files, output_fpaths, options.db_collection, options.combined_output)
+        validate = Validate(options.input_dir, options.db_collection)
+        validate.run(input_files, output_fpaths, options.combined_output, options.generate_matrix)

     def missing(self, options):
         """Execute search for missing samples from new pipeline results"""
@@ -95,16 +95,16 @@ def missing(self, options):
         db.initialize(options.db_name)
         if options.sample_sheet:
             csv_dict = missing.parse_sample_sheet(options.input_file[0], options.restore_dir)
-            utils.write_out_csv(csv_dict, options.assay, options.platform, options.output_file)
+            utils.write_out_csv(csv_dict, options.assay, options.platform, options.output_file, options.alter_sample_id)
         if options.analysis_dir:
             log_fpath = os.path.splitext(options.missing_log)[0] + ".log"
             empty_fpath = os.path.splitext(options.output_file)[0] + "_empty.csv"
             meta_dict = db.find(options.db_collection, {"metadata.QC": "OK"}, db.get_meta_fields())
             analysis_dir_fnames = missing.parse_dir(options.analysis_dir)
             csv_dict, missing_samples_txt = missing.find_missing(meta_dict, analysis_dir_fnames, options.restore_dir)
             empty_files_dict, csv_dict = missing.remove_empty_files(csv_dict)
-            utils.write_out_csv(csv_dict, options.assay, options.platform, options.output_file)
-            utils.write_out_csv(empty_files_dict, options.assay, options.platform, empty_fpath)
+            utils.write_out_csv(csv_dict, options.assay, options.platform, options.output_file, options.alter_sample_id)
+            utils.write_out_csv(empty_files_dict, options.assay, options.platform, empty_fpath, options.alter_sample_id)
             utils.write_out_txt(missing_samples_txt, log_fpath)
         if options.restore_file:
             bash_fpath = os.path.splitext(options.restore_file)[0] + ".sh"
@@ -126,8 +126,8 @@ def fix(self, options):
"""Execute fixing of file to desired format(s)"""
utils = Utils()
fix = Fix()
csv_files, assays = fix.fix_csv(options.csv_file, options.output_file)
batch_files = fix.fix_sh(options.sh_file, options.output_file, assays)
csv_files, assays = fix.fix_csv(options.csv_file, options.output_file, options.alter_sample_id)
batch_files = fix.fix_sh(options.sh_file, options.output_file, assays) if options.sh_file else options.sh_file
if (options.remote or options.auto_start) and batch_files:
utils.copy_batch_and_csv_files(batch_files, csv_files, options.remote_dir, options.remote_hostname, options.auto_start or options.remote)
if options.auto_start:
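Matrix generation is reached through the `validate` subcommand; a hypothetical invocation (the input/output flag spellings are assumptions inferred from the options object, which this diff does not show):

    jasentool validate --input_dir /path/to/jasen/results --db_collection sample_collection --output_file /path/to/validation.csv --generate_matrix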
194 changes: 194 additions & 0 deletions jasentool/matrix.py
@@ -0,0 +1,194 @@
"""Module for validating pipelines"""

import os
import sys
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from jasentool.database import Database

class Matrix:
    """Class for comparing cgMLST results of the old pipeline (cgviz) with the new pipeline (jasen)"""
    def __init__(self, input_dir, db_collection):
        self.input_dir = input_dir
        self.db_collection = db_collection

    def search(self, search_query, search_kw, search_list):
        """Search for query in a list of dicts"""
        return [element for element in search_list if element[search_kw] == search_query]

    def get_null_allele_counts(self, input_files):
        """Get null position counts"""
        null_alleles_count = {}
        sample_null_count = {}
        for input_file in input_files:
            sample_id = os.path.basename(input_file).replace("_result.json", "")
            sample_null_count[sample_id] = 0
            with open(input_file, 'r', encoding="utf-8") as fin:
                sample_json = json.load(fin)
                jasen_cgmlst = self.search("cgmlst", "type", sample_json["typing_result"])
                jasen_cgmlst_alleles = dict(jasen_cgmlst[0]["result"]["alleles"])
                for allele in jasen_cgmlst_alleles:
                    if isinstance(jasen_cgmlst_alleles[allele], str):
                        sample_null_count[sample_id] += 1
                        if allele in null_alleles_count:
                            null_alleles_count[allele] += 1
                        else:
                            null_alleles_count[allele] = 1
        print(f"The average number of missing alleles per sample is {sum(sample_null_count.values()) / len(sample_null_count.values())}")
        return null_alleles_count, sample_null_count

    def get_cgviz_cgmlst_data(self, sample_id):
        """Get sample mongodb data"""
        mdb_cgmlst = list(Database.get_cgmlst(self.db_collection, {"id": sample_id, "metadata.QC": "OK"}))
        try:
            mdb_cgmlst_alleles = mdb_cgmlst[0]["alleles"]
            return mdb_cgmlst_alleles
        except IndexError:
            print(f"IndexError: no cgMLST entry found for sample {sample_id}")
            return False

    def get_jasen_cgmlst_data(self, sample_id):
        """Get sample input file data"""
        input_file = os.path.join(self.input_dir, sample_id + "_result.json")
        with open(input_file, 'r', encoding="utf-8") as fin:
            sample_json = json.load(fin)
            jasen_cgmlst = self.search("cgmlst", "type", sample_json["typing_result"])
            jasen_cgmlst_alleles = list(jasen_cgmlst[0]["result"]["alleles"].values())
            return jasen_cgmlst_alleles

    def compare_cgmlst_alleles(self, row_cgmlst_alleles, col_cgmlst_alleles):
        """Parse through cgmlst alleles of old and new pipeline and compare results"""
        mismatch_count = 0
        null_values = ["-", "EXC", "INF", "LNF", "PLNF", "PLOT3", "PLOT5", "LOTSC", "NIPH", "NIPHEM", "PAMA", "ASM", "ALM"]
        for idx, row_allele in enumerate(row_cgmlst_alleles):
            col_allele = col_cgmlst_alleles[idx]
            if row_allele in null_values or col_allele in null_values:
                continue
            try:
                if int(row_allele) != int(col_allele):
                    mismatch_count += 1
            except ValueError:
                print(f"One of the following alleles is not in integer format: {row_allele} (row) or {col_allele} (column)")
        return mismatch_count

    def generate_matrix(self, sample_ids, get_cgmlst_data):
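        """Build a pairwise matrix of allele mismatch counts for all sample pairs"""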
        matrix_df = pd.DataFrame(index=sample_ids, columns=sample_ids)
        id_allele_dict = {sample_id: get_cgmlst_data(sample_id) for sample_id in sample_ids}
        print(f"The sample id - alleles dict is approximately {sys.getsizeof(id_allele_dict)} bytes in size")
        for row_sample in sample_ids:
            row_sample_cgmlst = id_allele_dict[row_sample]
            for col_sample in sample_ids:
                col_sample_cgmlst = id_allele_dict[col_sample]
                if row_sample_cgmlst and col_sample_cgmlst:
                    matrix_df.loc[row_sample, col_sample] = self.compare_cgmlst_alleles(row_sample_cgmlst, col_sample_cgmlst)
        return matrix_df

    def plot_heatmap(self, distance_df, output_plot_fpath):
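        """Plot a heatmap of the differential cgMLST matrix"""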
        plt.figure(figsize=(10, 8))
        sns.heatmap(distance_df, annot=True, cmap="coolwarm", center=0)
        plt.title("Differential Matrix Heatmap of cgmlst")
        plt.xlabel("Jasen")
        plt.ylabel("Cgviz")
        plt.savefig(output_plot_fpath, dpi=600)

    def plot_barplot(self, count_dict, output_plot_fpath):
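        """Plot a bar chart of per-allele null counts"""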
        filtered_dict = {k: v for k, v in count_dict.items() if v >= 1000}
        sorted_filtered_dict = dict(sorted(filtered_dict.items(), key=lambda item: item[1]))
        categories = list(sorted_filtered_dict.keys())
        counts = list(sorted_filtered_dict.values())

        print(f"The number of alleles that are null in at least 1000 samples is {len(categories)}")

        plt.figure(figsize=(10, 8))
        bars = plt.bar(categories, counts, color="skyblue")

        # Add titles and labels
        plt.xlabel("Alleles")
        plt.ylabel("Count")
        plt.title("Null Allele Count Bar Plot")

        # Rotate the x-axis labels by 90 degrees
        plt.xticks(rotation=90)

        # Add value labels on top of the bars
        for bar in bars:
            yval = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2, yval + 1, yval, ha="center", va="bottom")

        plt.tight_layout()
        plt.savefig(output_plot_fpath, dpi=600)

    def plot_matrix_boxplot(self, df, output_plot_fpath):
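        """Plot a boxplot of per-sample summed allele differences, with jittered points and outlier labels"""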
        plt.figure(figsize=(10, 8))
        counts = list(df["sum"])
        sample_ids = list(df["SampleID"])
        plt.boxplot(counts)

        # Add jittered data points
        jitter = 0.04  # Adjust the jitter as needed
        x_jitter = np.random.normal(1, jitter, size=len(counts))
        plt.scatter(x_jitter, counts, alpha=0.5, color="blue")

        # Set labels and title
        plt.xlabel("Samples")
        plt.ylabel("Sum of sample allele differences")
        plt.title("Summed differential matrix of distances between pipelines' cgMLST results")

        # Annotate outliers
        for i, count in enumerate(counts):
            if count > 250000 or count < -750000:
                if float(x_jitter[i]) < 1:
                    plt.annotate(f"{sample_ids[i]}", xy=(x_jitter[i] - 0.01, count), xytext=(x_jitter[i] - 0.01, count),
                                 horizontalalignment="right", fontsize=8)
                else:
                    plt.annotate(f"{sample_ids[i]}", xy=(x_jitter[i] - 0.01, count), xytext=(x_jitter[i] + 0.01, count),
                                 horizontalalignment="left", fontsize=8)

        plt.tight_layout()
        plt.savefig(output_plot_fpath, dpi=600)

    def plot_boxplot(self, count_dict, output_plot_fpath):
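        """Plot a boxplot of null allele counts per sample"""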
        counts = list(count_dict.values())
        plt.figure(figsize=(10, 8))  # Optional: set the figure size
        plt.boxplot(counts, vert=True, patch_artist=True)  # `vert=True` for vertical boxplot, `patch_artist=True` for filled boxes

        # Add title and labels
        plt.xlabel("Null allele count")
        plt.title("Number of null alleles per sample")

        min_value = np.min(counts)

        # Label the minimum value on the plot
        plt.annotate(f"Min: {min_value}", xy=(1, min_value), xytext=(1.05, min_value),
                     arrowprops=dict(facecolor="black", shrink=0.05),
                     horizontalalignment="left")

        plt.savefig(output_plot_fpath, dpi=600)

    def run(self, input_files, output_fpaths, generate_matrix):
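        """Generate null-allele plots and, optionally, the cgviz vs jasen distance matrix"""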
        # heatmap_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen_heatmap.png")
        output_csv_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen.csv")
        boxplot_matrix_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "summed_differential_matrix_boxplot.png")
        barplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "null_alleles_barplot.png")
        boxplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "sample_null_boxplot.png")
        null_alleles_count, sample_null_count = self.get_null_allele_counts(input_files)
        self.plot_boxplot(sample_null_count, boxplot_fpath)
        self.plot_barplot(null_alleles_count, barplot_fpath)
        if generate_matrix:
            sample_ids = [os.path.basename(input_file).replace("_result.json", "") for input_file in input_files]
            cgviz_matrix_df = self.generate_matrix(sample_ids, self.get_cgviz_cgmlst_data)
            jasen_matrix_df = self.generate_matrix(sample_ids, self.get_jasen_cgmlst_data)
            distance_df = jasen_matrix_df - cgviz_matrix_df
            distance_df = distance_df.astype(float)
            distance_df.to_csv(output_csv_fpath, index=True, header=True)
            # self.plot_heatmap(distance_df, output_plot_fpath)
        if os.path.exists(output_csv_fpath):
            distance_df = pd.read_csv(output_csv_fpath, index_col=0)
            distance_df["sum"] = distance_df.sum(axis=1)
            distance_df = distance_df.reset_index()
            distance_df.rename(columns={'index': 'SampleID'}, inplace=True)
            filtered_df = distance_df[["SampleID", "sum"]]
            self.plot_matrix_boxplot(filtered_df, boxplot_matrix_fpath)
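
For orientation, a hypothetical driver of the new Matrix class (the result-file layout, collection name, and prior database initialization are assumptions, not part of this commit):

    import glob
    from jasentool.matrix import Matrix

    # Assumes Database.initialize() has been called elsewhere, as in main.py
    input_files = sorted(glob.glob("/path/to/jasen/results/*_result.json"))
    matrix = Matrix("/path/to/jasen/results", "sample_collection")
    matrix.run(input_files, ["/path/to/output/validation.csv"], generate_matrix=True)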
13 changes: 9 additions & 4 deletions jasentool/utils.py
@@ -12,18 +12,23 @@
 class Utils:
     """Class containing utilities used throughout jasentool"""
     @staticmethod
-    def write_out_csv(csv_dict, assay, platform, out_fpath):
+    def write_out_csv(csv_dict, assay, platform, out_fpath, alter_sample_id=False):
         """Write out file as csv"""
         with open(out_fpath, 'w+', encoding="utf-8") as csvfile:
-            fieldnames = ["id", "clarity_sample_id", "group", "species", "assay",
+            fieldnames = ["id", "clarity_sample_id", "sample_name", "group", "species", "assay",
                           "platform", "sequencing_run", "read1", "read2"] #header
             writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
             writer.writeheader()
             for sample in csv_dict:
-                row_dict = {"id": sample, "clarity_sample_id": csv_dict[sample][0],
+                lims_id = csv_dict[sample][0]
+                sequencing_run = csv_dict[sample][3]
+                sample_id = str(lims_id.lower() + "_" + sequencing_run.lower()) if alter_sample_id else sample
+                row_dict = {"id": sample_id,
+                            "clarity_sample_id": lims_id,
+                            "sample_name": sample,
                             "group": csv_dict[sample][1], "species": csv_dict[sample][2],
                             "assay": assay, "platform": platform,
-                            "sequencing_run": csv_dict[sample][3],
+                            "sequencing_run": sequencing_run,
                             "read1": csv_dict[sample][4][0],
                             "read2": csv_dict[sample][4][1]} #write rows to CSV
                 writer.writerow(row_dict)
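With `alter_sample_id` enabled, a sample keyed `S1` with LIMS id `ACC123` on run `RUN01` would be written out roughly as follows (all values made up for illustration):

    id,clarity_sample_id,sample_name,group,species,assay,platform,sequencing_run,read1,read2
    acc123_run01,ACC123,S1,CG,saureus,saureus,illumina,RUN01,/path/S1_R1.fastq.gz,/path/S1_R2.fastq.gz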