Skip to content

Commit

Permalink
specify column type for goldstandard file
Browse files Browse the repository at this point in the history
  • Loading branch information
vpchung committed May 8, 2024
1 parent 13be801 commit 73971a3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
24 changes: 19 additions & 5 deletions score.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
- ROC curve
- PR curve
"""

from glob import glob
import argparse
import json
import os

import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

from glob import glob
# Column name -> dtype mappings handed to `pd.read_csv` (as `usecols` and
# `dtype`) when loading the goldstandard and predictions files.
GOLDSTANDARD_COLS = {
    "id": str,
    "disease": int,
}
PREDICTION_COLS = {
    "id": str,
    "disease_probability": np.float64,
}


def get_args():
"""Set up command-line interface and get arguments."""
Expand All @@ -37,7 +40,10 @@ def extract_gs_file(folder):
"""Extract gold standard file from folder."""
files = glob(os.path.join(folder, "*"))
if len(files) != 1:
raise ValueError(f"Expected exactly one gold standard file in folder. Got {len(files)}. Exiting.")
raise ValueError(
"Expected exactly one gold standard file in folder. "
f"Got {len(files)}. Exiting."
)

return files[0]

Expand All @@ -52,8 +58,16 @@ def main():
gold_file = extract_gs_file(args.goldstandard_folder)

if res.get("validation_status") == "VALIDATED":
# Load only the expected columns, with explicit dtypes. Each file must be
# paired with its own schema: the predictions file carries
# `disease_probability`, the goldstandard file carries `disease` (these are
# the columns the `score(...)` call below reads from each frame).
pred = pd.read_csv(
    args.predictions_file,
    usecols=list(PREDICTION_COLS),
    dtype=PREDICTION_COLS,
)
gold = pd.read_csv(
    gold_file,
    usecols=list(GOLDSTANDARD_COLS),
    dtype=GOLDSTANDARD_COLS,
)
scores = score(gold, "disease", pred, "disease_probability")
status = "SCORED"
else:
Expand Down
52 changes: 22 additions & 30 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,24 @@
- `id` is a string
- `disease_probability` is a float between 0 and 1
"""

from glob import glob
import argparse
import json
import os

import numpy as np
import pandas as pd

from glob import glob

# Column name -> dtype used when reading the goldstandard file.
GOLDSTANDARD_COLS = {"id": str, "disease": int}
# Column name -> dtype expected in the predictions file.
EXPECTED_COLS = {"id": str, "disease_probability": np.float64}


def get_args():
    """Set up command-line interface and get arguments.

    Returns:
        argparse.Namespace with ``predictions_file``, ``goldstandard_folder``
        and ``output`` attributes.
    """
    parser = argparse.ArgumentParser()
    # Each option is registered exactly once; registering the same option
    # string twice makes argparse raise a conflicting-option error.
    parser.add_argument(
        "-p",
        "--predictions_file",
        type=str,
        required=True,
        help="Path to the predictions file.",
    )
    parser.add_argument(
        "-g",
        "--goldstandard_folder",
        type=str,
        required=True,
        help="Folder containing the goldstandard file.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="results.json",
        help="Path of the JSON results file to write.",
    )
    return parser.parse_args()


Expand All @@ -38,7 +31,7 @@ def check_dups(pred):
duplicates = pred.duplicated(subset=["id"])
if duplicates.any():
return (
f"Found {duplicates.sum()} duplicate participant ID(s): "
f"Found {duplicates.sum()} duplicate ID(s): "
f"{pred[duplicates].id.to_list()}"
)
return ""
Expand All @@ -50,7 +43,7 @@ def check_missing_ids(gold, pred):
missing_ids = gold.index.difference(pred.index)
if missing_ids.any():
return (
f"Found {missing_ids.shape[0]} missing participant ID(s): "
f"Found {missing_ids.shape[0]} missing ID(s): "
f"{missing_ids.to_list()}"
)
return ""
Expand All @@ -62,7 +55,7 @@ def check_unknown_ids(gold, pred):
unknown_ids = pred.index.difference(gold.index)
if unknown_ids.any():
return (
f"Found {unknown_ids.shape[0]} unknown participant ID(s): "
f"Found {unknown_ids.shape[0]} unknown ID(s): "
f"{unknown_ids.to_list()}"
)
return ""
Expand All @@ -72,34 +65,34 @@ def check_nan_values(pred):
"""Check for NAN predictions."""
missing_probs = pred["disease_probability"].isna().sum()
if missing_probs:
return (
f"'disease_probability' column contains {missing_probs} NaN value(s)."
)
return f"'disease_probability' column contains {missing_probs} NaN value(s)."
return ""


def check_prob_values(pred):
    """Check that probabilities are between [0, 1].

    Args:
        pred: DataFrame with a ``disease_probability`` column.

    Returns:
        An error message if any value is below 0 or above 1, otherwise an
        empty string.
    """
    probs = pred["disease_probability"]
    # Two explicit comparisons (rather than a negated range test) so that
    # NaN values are not flagged here: both comparisons are False for NaN.
    if (probs < 0).any() or (probs > 1).any():
        return "'disease_probability' values should be between [0, 1]."
    return ""


def extract_gs_file(folder):
    """Extract goldstandard file from folder.

    Args:
        folder: Directory expected to contain exactly one entry.

    Returns:
        Path of the single entry found in ``folder``.

    Raises:
        ValueError: If ``folder`` does not contain exactly one entry
            matched by ``*``.
    """
    # `*` does not match names starting with a dot, so hidden files are
    # not counted.
    files = glob(os.path.join(folder, "*"))
    if len(files) != 1:
        raise ValueError(
            "Expected exactly one goldstandard file in folder. "
            f"Got {len(files)}. Exiting."
        )
    return files[0]


def validate(gold_folder, pred_file):
"""Validate predictions file against goldstandard."""
errors = []

gold_file = extract_gs_file(gold_folder)
gold = pd.read_csv(gold_file, index_col="id")
gold = pd.read_csv(gold_file, dtype=GOLDSTANDARD_COLS, index_col="id")
try:
pred = pd.read_csv(
pred_file,
Expand Down Expand Up @@ -140,10 +133,9 @@ def main():
# truncate validation errors if >500 (character limit for sending email)
if len(invalid_reasons) > 500:
invalid_reasons = invalid_reasons[:496] + "..."
res = json.dumps({
"validation_status": status,
"validation_errors": invalid_reasons
})
res = json.dumps(
{"validation_status": status, "validation_errors": invalid_reasons}
)

with open(args.output, "w") as out:
out.write(res)
Expand Down

0 comments on commit 73971a3

Please sign in to comment.