diff --git a/pyroc.py b/pyroc.py index 4de7a03..41f7bce 100644 --- a/pyroc.py +++ b/pyroc.py @@ -6,6 +6,7 @@ __version__ = "0.1.0" from collections import OrderedDict +from collections.abc import Iterable import numpy as np import pandas as pd @@ -104,51 +105,33 @@ def _parse_inputs(self, preds, target): integers. """ - if type(preds) is OrderedDict: - # is already a ordered dict - pass - elif type(preds) is list: - if hasattr(preds[0], '__len__'): - # convert preds into a dictionary - preds = OrderedDict( - [[i, np.asarray(p)] for i, p in enumerate(preds)] - ) - elif type(preds[0]) in (float, int): - preds = OrderedDict([(0, np.asarray(preds))]) + # Parse preds + if not isinstance(preds, dict): + parsed = np.asarray(preds) + + # In case single list of predictions + if not isinstance(parsed[0], Iterable): + parsed = np.array([preds]) + + # If needed to transpose matrix + if parsed.shape[0] == len(target): + parsed = preds.T + + # Use column names if it is a DataFrame + # Otherwise use increasing integers + if type(preds) == pd.DataFrame: + preds = zip(preds.columns, parsed.values) else: - raise TypeError( - 'unable to parse preds list with element type %s', - type(preds[0]) - ) - elif type(preds) is pd.DataFrame: - # preds is a dict - convert to ordered - preds = OrderedDict(zip(preds.columns, preds.T.values)) - elif type(preds) is np.ndarray: - if len(preds.shape) <= 1: - # numpy vector - preds = OrderedDict([[0, np.asarray(preds)]]) - else: - # numpy matrix - preds = OrderedDict( - [[i, preds[:, i]] for i in range(preds.shape[1])] - ) - elif type(preds) is dict: - # preds is a dict - convert to ordered - names = sorted(preds.keys()) - preds = OrderedDict([[c, np.asarray(preds[c])] for c in names]) - else: - raise ValueError( - 'Unrecognized type "%s" for predictions.', str(type(preds)) - ) + preds = enumerate(map(np.asarray, parsed)) - if type(target) is pd.Series: - target = target.values - elif type(target) in (list, tuple): + # Finnally, ordered dictionary + preds = OrderedDict(preds) + + # Parse target + if isinstance(target, Iterable): target = np.asarray(target) - elif type(target) is not np.ndarray: - raise TypeError( - 'target should be type np.ndarray, was %s', type(target) - ) + else: + raise TypeError('Target should be iterable, was %s', type(target)) return preds, target