Skip to content

Commit

Permalink
Fix cgmlst comparison and plots
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanjameskennedy committed Jul 9, 2024
1 parent 93821d6 commit 5595655
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 40 deletions.
91 changes: 55 additions & 36 deletions jasentool/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_null_allele_counts(self, input_files):
else:
null_alleles_count[allele] = 1
print(f"The average number of missing alleles per sample is {sum(sample_null_count.values()) / len(sample_null_count.values())}")
return null_alleles_count
return null_alleles_count, sample_null_count

def get_cgviz_cgmlst_data(self, sample_id):
"""Get sample mongodb data"""
Expand All @@ -61,27 +61,17 @@ def get_jasen_cgmlst_data(self, sample_id):

def compare_cgmlst_alleles(self, row_cgmlst_alleles, col_cgmlst_alleles):
"""Parse through cgmlst alleles of old and new pipeline and compare results"""
# null_values = ["-", "EXC", "INF", "LNF", "PLNF", "PLOT3", "PLOT5", "LOTSC", "NIPH", "NIPHEM", "PAMA"]
# row_cgmlst_alleles = [int(allele) if allele not in null_values else np.nan for allele in row_cgmlst_alleles]
# col_cgmlst_alleles = [int(allele) if allele not in null_values else np.nan for allele in col_cgmlst_alleles]

row_cgmlst_alleles = np.array(row_cgmlst_alleles, dtype=object)
col_cgmlst_alleles = np.array(col_cgmlst_alleles, dtype=object)

# Convert None to np.nan
row_cgmlst_alleles = np.where(row_cgmlst_alleles == "-", np.nan, row_cgmlst_alleles)
col_cgmlst_alleles = np.where(col_cgmlst_alleles == "-", np.nan, col_cgmlst_alleles)

# Convert arrays to numeric, coercing errors to np.nan
row_cgmlst_alleles = pd.to_numeric(row_cgmlst_alleles, errors='coerce')
col_cgmlst_alleles = pd.to_numeric(col_cgmlst_alleles, errors='coerce')

# Identify positions with NaNs in either array
nan_positions = np.isnan(row_cgmlst_alleles) | np.isnan(col_cgmlst_alleles)

# Calculate mismatches excluding positions with NaNs
mismatches = (row_cgmlst_alleles != col_cgmlst_alleles) & ~nan_positions
mismatch_count = np.sum(mismatches)
mismatch_count = 0
null_values = ["-", "EXC", "INF", "LNF", "PLNF", "PLOT3", "PLOT5", "LOTSC", "NIPH", "NIPHEM", "PAMA", "ASM", "ALM"]
for idx, row_allele in enumerate(row_cgmlst_alleles):
col_allele = col_cgmlst_alleles[idx]
if row_allele in null_values or col_allele in null_values:
continue
try:
if int(row_allele) != int(col_allele):
mismatch_count += 1
except ValueError:
print(f"One following alleles are not in integer format: {row_allele} (row) or {col_allele} (column)")
return mismatch_count

def generate_matrix(self, sample_ids, get_cgmlst_data):
Expand All @@ -105,34 +95,63 @@ def plot_heatmap(self, distance_df, output_plot_fpath):
plt.savefig(output_plot_fpath, dpi=600)

def plot_barplot(self, count_dict, output_plot_fpath):
categories = list(count_dict.keys())
counts = list(count_dict.values())
filtered_dict = {k: v for k, v in count_dict.items() if v >= 1000}
sorted_filtered_dict = dict(sorted(filtered_dict.items(), key=lambda item: item[1]))
categories = list(sorted_filtered_dict.keys())
counts = list(sorted_filtered_dict.values())

print(f"The number of alleles that aren't null for more than 1000 samples is {len(categories)}")

plt.figure(figsize=(10, 5))
plt.figure(figsize=(12, 6))
bars = plt.bar(categories, counts, color='skyblue')

# Add titles and labels
plt.xlabel('Alleles')
plt.ylabel('Count')
plt.title('Null Allele Count Bar Plot')

# Rotate the x-axis labels by 90 degrees
plt.xticks(rotation=90)

# Add value labels on top of the bars
for bar in bars:
yval = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, yval + 1, yval, ha='center', va='bottom')

plt.tight_layout()
plt.savefig(output_plot_fpath, dpi=600)

def run(self, input_files, output_fpaths):
output_csv_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen.csv")

def plot_boxplot(self, count_dict, output_plot_fpath):
counts = list(count_dict.values())
plt.figure(figsize=(10, 6)) # Optional: set the figure size
plt.boxplot(counts, vert=True, patch_artist=True) # `vert=True` for vertical boxplot, `patch_artist=True` for filled boxes

# Add title and labels
plt.xlabel('Null allele count')
plt.title('Number of null alleles per sample')

min_value = np.min(counts)

# Label the minimum value on the plot
plt.annotate(f'Min: {min_value}', xy=(1, min_value), xytext=(1.05, min_value),
arrowprops=dict(facecolor='black', shrink=0.05),
horizontalalignment='left')

plt.savefig(output_plot_fpath, dpi=600)

def run(self, input_files, output_fpaths, generate_matrix):
# heatmap_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen_heatmap.png")
barplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "null_alleles_barplot.png")
null_alleles_count = self.get_null_allele_counts(input_files)
boxplot_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "sample_null_boxplot.png")
null_alleles_count, sample_null_count = self.get_null_allele_counts(input_files)
self.plot_boxplot(sample_null_count, boxplot_fpath)
self.plot_barplot(null_alleles_count, barplot_fpath)
sample_ids = [os.path.basename(input_file).replace("_result.json", "") for input_file in input_files]
cgviz_matrix_df = self.generate_matrix(sample_ids, self.get_cgviz_cgmlst_data)
jasen_matrix_df = self.generate_matrix(sample_ids, self.get_jasen_cgmlst_data)
distance_df = jasen_matrix_df - cgviz_matrix_df
distance_df = distance_df.astype(float)
distance_df.to_csv(output_csv_fpath, index=True, header=True)
# self.plot_heatmap(distance_df, output_plot_fpath)
if generate_matrix:
output_csv_fpath = os.path.join(os.path.dirname(output_fpaths[0]), "cgviz_vs_jasen.csv")
sample_ids = [os.path.basename(input_file).replace("_result.json", "") for input_file in input_files]
cgviz_matrix_df = self.generate_matrix(sample_ids, self.get_cgviz_cgmlst_data)
jasen_matrix_df = self.generate_matrix(sample_ids, self.get_jasen_cgmlst_data)
distance_df = jasen_matrix_df - cgviz_matrix_df
distance_df = distance_df.astype(float)
distance_df.to_csv(output_csv_fpath, index=True, header=True)
# self.plot_heatmap(distance_df, output_plot_fpath)
8 changes: 4 additions & 4 deletions jasentool/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,22 @@ def compare_data(self, sample_id, old_data, new_data):
cgmlst_alleles = self.compare_cgmlst_alleles(old_data["cgmlst_alleles"], new_data["cgmlst_alleles"])
return True, f"{sample_id},{pvl_comp},{mlst_seqtype_comp},{mlst_alleles},{cgmlst_alleles}"

def run(self, input_files, output_fpaths, combined_output):
def run(self, input_files, output_fpaths, combined_output, generate_matrix):
"""Execute validation of new pipeline (jasen)"""
utils = Utils()
matrix = Matrix(self.input_dir, self.db_collection)
csv_output = "sample_id,pvl,mlst_seqtype,mlst_allele_matches(%),cgmlst_allele_matches(%)"
mlst_at_header = "old_arcC,new_arcC,old_aroE,new_aroE,old_glpF,new_glpF,old_gmk,new_gmk,old_pta,new_pta,old_tpi,new_tpi,old_yqiL,new_yqiL"
failed_csv_output = f"sample_id,old_mlst_seqtype,new_mlst_allele_matches(%),{mlst_at_header}"
matrix.run(input_files, output_fpaths)
matrix.run(input_files, output_fpaths, generate_matrix)
for input_idx, input_file in enumerate(input_files):
with open(input_file, 'r', encoding="utf-8") as fin:
sample_json = json.load(fin)
sample_id = self.get_sample_id(sample_json)
if not self._check_exists(self.db_collection, sample_id):
if not self._check_exists(sample_id):
print(f"The sample provided ({sample_id}) does not exist in the provided database ({Database.db_name}) or collection ({self.db_collection}).")
continue
mdb_data_dict = self.get_mdb_cgv_data(self.db_collection, sample_id)
mdb_data_dict = self.get_mdb_cgv_data(sample_id)
if mdb_data_dict:
#species_name = self.get_species_name(sample_json)
fin_data_dict = self.get_fin_data(sample_json)
Expand Down

0 comments on commit 5595655

Please sign in to comment.