Skip to content

Commit

Permalink
ordinal clf
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Dec 10, 2024
1 parent 5f48bb8 commit a5f0d40
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
65 changes: 64 additions & 1 deletion docs/simba.model_mixin.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,69 @@
Model mixin
----------------------------------------------

Utilities for fit, inference, and evaluation of classifiers.

.. autoclass:: simba.mixins.train_model_mixin.TrainModelMixin
:members:
:undoc-members:
:undoc-members:

Batch random forest inference
----------------------------------------------

.. autoclass:: simba.model.inference_batch.InferenceBatch
:members:
:undoc-members:

Batch multi-class random forest inference
----------------------------------------------

.. autoclass:: simba.model.inference_multiclass_batch.InferenceMulticlassBatch
:members:
:undoc-members:

Grid-search random forest classifiers
----------------------------------------------

.. autoclass:: simba.model.grid_search_rf.GridSearchRandomForestClassifier
:members:
:undoc-members:

Grid-search random forest multi-classifiers
----------------------------------------------

.. autoclass:: simba.model.grid_search_multiclass_rf.GridSearchMulticlassRandomForestClassifier
:members:
:undoc-members:

Random forest inference - validation
----------------------------------------------

.. autoclass:: simba.model.inference_validation.InferenceValidation
:members:
:undoc-members:

Fit random forest classifier
----------------------------------------------

.. autoclass:: simba.model.train_rf.TrainRandomForestClassifier
:members:
:undoc-members:


Fit random forest classifier - multi-class
----------------------------------------------

.. autoclass:: simba.model.train_multiclass_rf.TrainMultiClassRandomForestClassifier
:members:
:undoc-members:


Ordinal classifier methods
----------------------------------------------

.. autoclass:: simba.model.ordinal_clf.OrdinalClassifier
:members:
:undoc-members:



111 changes: 111 additions & 0 deletions simba/model/ordinal_clf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
from typing import Union, Optional, Dict
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from joblib import Parallel, delayed
from sklearn import clone
from simba.mixins.train_model_mixin import TrainModelMixin
from simba.utils.errors import SamplingError, InvalidInputError
from simba.utils.checks import check_valid_array, check_int, check_if_dir_exists, check_file_exist_and_readable
from simba.utils.enums import Formats
from simba.utils.read_write import find_core_cnt, write_pickle, read_pickle

ACCEPTED_MODELS = RandomForestClassifier

class OrdinalClassifier():
"""
This class implements a strategy for ordinal classification by fitting multiple binary classifiers to predict thresholds between classes.
It is particularly useful for problems where the target variable has an inherent order but uneven intervals between levels. Thi includes human severity scores, for example, seizures, stereotopy, convulsion, bizarre behavior scores ranging fro 0-5.
.. note::
`Modified from sklego <`https://github.com/koaning/scikit-lego/blob/main/sklego/meta/ordinal_classification.py>`__.
References
----------
.. [1] Frank, Eibe, and Mark Hall. “A Simple Approach to Ordinal Classification.” In Machine Learning: ECML 2001, edited by Luc De Raedt and Peter Flach, 2167:145–56. Lecture Notes in Computer Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2001. https://doi.org/10.1007/3-540-44795-4_13.
.. [2] Sabnis, Gautam, Leinani Hession, J. Matthew Mahoney, Arie Mobley, Marina Santos, and Vivek Kumar. “Visual Detection of Seizures in Mice Using Supervised Machine Learning,” May 31, 2024. https://doi.org/10.1101/2024.05.29.596520.
:example:
>>> X = np.random.randint(0, 500, (100, 50))
>>> y = np.random.randint(1, 6, (100))
>>> rf_mdl = TrainModelMixin().clf_define()
>>> fitted_mdl = OrdinalClassifier.fit(X, y, rf_mdl, -1)
>>> y_hat = OrdinalClassifier.predict_proba(X, fitted_mdl)
>>> y = OrdinalClassifier.predict(X, fitted_mdl)
>>> OrdinalClassifier.save(mdl=fitted_mdl, save_path=r"C:\mdl.pk")
"""

def __init__(self):
pass

@staticmethod
def fit(X: np.ndarray, y: np.ndarray, clf: Union[ACCEPTED_MODELS], core_cnt: int = -1) -> Dict[int, Union[ACCEPTED_MODELS]]:

def _fit_binary_estimator(clf, X, y, y_label):
y_bin = (y <= y_label).astype(np.int32)
return clone(clf).fit(X, y_bin)

classes_ = np.sort(np.unique(y))
check_valid_array(data=classes_, source=f'{__class__.__name__} y', accepted_ndims=(1,), accepted_dtypes=(int,))
if len(classes_) < 3:
raise InvalidInputError(msg=f'Found {len(classes_)} classes in y [{classes_}], requires at least 3', source=f'{OrdinalClassifier.__name__} fit')
intervals = [classes_[i] - classes_[i-1] for i in range(1, len(classes_))]
if len(set(intervals)) != 1:
raise InvalidInputError(msg=f'The values in y ({classes_}) are not of equal interval.', source=f'{OrdinalClassifier.__name__} fit')
check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
if not isinstance(clf, (RandomForestClassifier,)) or ('predict_proba' not in dir(clf)):
raise InvalidInputError(msg=f'clf is not of valid type: {type(clf)} (accepted: {ACCEPTED_MODELS})', source=f'{OrdinalClassifier.__name__} fit')
check_int(name='core_cnt', min_value=-1, unaccepted_vals=[0], value=core_cnt)
core_cnt = [find_core_cnt()[0] if core_cnt == -1 or core_cnt > find_core_cnt()[0] else core_cnt][0]
return dict(zip(classes_[:-1], Parallel(n_jobs=core_cnt)(delayed(_fit_binary_estimator)(clf, X, y, y_label) for y_label in classes_[:-1])))

@staticmethod
def predict_proba(X: np.ndarray, mdl: Dict[int, Union[ACCEPTED_MODELS]]) -> np.ndarray:
OrdinalClassifier._check_valid_mdl_dict(mdls=mdl)
check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
if mdl[list(mdl.keys())[0]].n_features_ != X.shape[1]:
raise InvalidInputError(msg=f'Model expects {mdl[list(mdl.keys())[0]].n_features_} features, got {X.shape[1]}.', source=f'{OrdinalClassifier.__name__} predict')
check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
raw_proba = np.array([estimator.predict_proba(X)[:, 1] for estimator in mdl.values()]).T
return np.diff(np.column_stack((np.zeros(X.shape[0]), raw_proba, np.ones(X.shape[0]))), n=1, axis=1)


@staticmethod
def predict(X: np.ndarray, mdl: Dict[int, Union[ACCEPTED_MODELS]]) -> np.ndarray:
OrdinalClassifier._check_valid_mdl_dict(mdls=mdl)
check_valid_array(data=X, source=f'{__class__.__name__} x', accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)
if mdl[list(mdl.keys())[0]].n_features_ != X.shape[1]:
raise InvalidInputError(msg=f'Model expects {mdl[list(mdl.keys())[0]].n_features_} features, got {X.shape[1]}.', source=f'{OrdinalClassifier.__name__} predict')
return np.argmax(OrdinalClassifier.predict_proba(X, mdl=mdl), axis=1)

@staticmethod
def save(mdl: Dict[int, Union[ACCEPTED_MODELS]], save_path: Union[str, os.PathLike]):
OrdinalClassifier._check_valid_mdl_dict(mdls=mdl)
check_if_dir_exists(in_dir=os.path.dirname(save_path), source=f'{OrdinalClassifier.__name__} save')
write_pickle(data=mdl, save_path=save_path)


@staticmethod
def load(file_path: Union[str, os.PathLike]) -> Dict[int, Union[ACCEPTED_MODELS]]:
check_file_exist_and_readable(file_path=file_path)
return read_pickle(data_path=file_path)


@staticmethod
def _check_valid_mdl_dict(mdls: Dict[int, Union[ACCEPTED_MODELS]]) -> None:
features_in_cnt = []
for mdl in mdls.values(): features_in_cnt.append(mdl.n_features_)
if len(set(features_in_cnt)) != 1:
raise InvalidInputError(msg=f'The models has different N features [{features_in_cnt}]')

# X = np.random.randint(0, 500, (100, 50))
# y = np.random.randint(1, 6, (100))
# rf_mdl = TrainModelMixin().clf_define()
# fitted_mdls = OrdinalClassifier.fit(X, y, rf_mdl, -1)
# y_hat = OrdinalClassifier.predict_proba(X, fitted_mdls)
# y = OrdinalClassifier.predict(X, fitted_mdls)
# OrdinalClassifier.save(mdl=fitted_mdls, save_path=r"C:\Users\sroni\OneDrive\Desktop\mdl.pk")

#predict_proba(X)
# ordinal_clf.predict(X)
6 changes: 2 additions & 4 deletions simba/utils/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,7 @@ def read_pickle(data_path: Union[str, os.PathLike], verbose: Optional[bool] = Fa
:example:
>>> data = read_pickle(data_path='/test/unsupervised/cluster_models')
"""
data = None
if os.path.isdir(data_path):
if verbose:
print(f"Reading in data directory {data_path}...")
Expand Down Expand Up @@ -1841,10 +1842,7 @@ def read_pickle(data_path: Union[str, os.PathLike], verbose: Optional[bool] = Fa
source=read_pickle.__name__,
)
else:
raise InvalidFilepathError(
msg=f"The path {data_path} is neither a valid file or directory path",
source=read_pickle.__name__,
)
raise InvalidFilepathError(msg=f"The path {data_path} is neither a valid file or directory path", source=read_pickle.__name__)

return data

Expand Down

0 comments on commit a5f0d40

Please sign in to comment.