Skip to content

Commit

Permalink
Merge pull request #2 from alistairewj/type-casting
Browse files Browse the repository at this point in the history
Refactor type parsing
  • Loading branch information
alistairewj authored Nov 26, 2019
2 parents 3223438 + f64c440 commit aea9046
Showing 1 changed file with 25 additions and 42 deletions.
67 changes: 25 additions & 42 deletions pyroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
__version__ = "0.1.0"

from collections import OrderedDict
from collections.abc import Iterable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -104,51 +105,33 @@ def _parse_inputs(self, preds, target):
integers.
"""
if type(preds) is OrderedDict:
# is already an ordered dict
pass
elif type(preds) is list:
if hasattr(preds[0], '__len__'):
# convert preds into a dictionary
preds = OrderedDict(
[[i, np.asarray(p)] for i, p in enumerate(preds)]
)
elif type(preds[0]) in (float, int):
preds = OrderedDict([(0, np.asarray(preds))])
# Parse preds
if not isinstance(preds, dict):
parsed = np.asarray(preds)

# In case single list of predictions
if not isinstance(parsed[0], Iterable):
parsed = np.array([preds])

# If needed to transpose matrix
if parsed.shape[0] == len(target):
parsed = preds.T

# Use column names if it is a DataFrame
# Otherwise use increasing integers
if type(preds) == pd.DataFrame:
preds = zip(preds.columns, parsed.values)
else:
raise TypeError(
'unable to parse preds list with element type %s',
type(preds[0])
)
elif type(preds) is pd.DataFrame:
# preds is a DataFrame - convert to ordered dict keyed by column name
preds = OrderedDict(zip(preds.columns, preds.T.values))
elif type(preds) is np.ndarray:
if len(preds.shape) <= 1:
# numpy vector
preds = OrderedDict([[0, np.asarray(preds)]])
else:
# numpy matrix
preds = OrderedDict(
[[i, preds[:, i]] for i in range(preds.shape[1])]
)
elif type(preds) is dict:
# preds is a dict - convert to ordered
names = sorted(preds.keys())
preds = OrderedDict([[c, np.asarray(preds[c])] for c in names])
else:
raise ValueError(
'Unrecognized type "%s" for predictions.', str(type(preds))
)
preds = enumerate(map(np.asarray, parsed))

if type(target) is pd.Series:
target = target.values
elif type(target) in (list, tuple):
# Finally, ordered dictionary
preds = OrderedDict(preds)

# Parse target
if isinstance(target, Iterable):
target = np.asarray(target)
elif type(target) is not np.ndarray:
raise TypeError(
'target should be type np.ndarray, was %s', type(target)
)
else:
raise TypeError('Target should be iterable, was %s', type(target))

return preds, target

Expand Down

0 comments on commit aea9046

Please sign in to comment.