diff --git a/tests/test_pyroc.py b/tests/test_pyroc.py index f3a9f14..a4e1687 100644 --- a/tests/test_pyroc.py +++ b/tests/test_pyroc.py @@ -36,6 +36,16 @@ def test_pyroc_input_parsing(ovarian_cancer_dataset): target = df['outcome'].values df = df.drop('outcome', axis=1) + # single numpy array of preds + preds = df['albumin'].values + roc = pyroc.ROC( + target, + preds, + ) + # since roc wasn't provided predictor labels, + # keys are integers (i) + assert (roc.preds[0][:n] == expected_values['albumin'][:n]).all() + # numpy arrays preds = df.values roc = pyroc.ROC(