diff --git a/src/MSDpostprocess/models.py b/src/MSDpostprocess/models.py index b4d78e8..1e553f4 100644 --- a/src/MSDpostprocess/models.py +++ b/src/MSDpostprocess/models.py @@ -8,14 +8,11 @@ # import shelve import re from functools import cache -import warnings import pickle import numpy as np import pandas as pd -with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="brainpy") - from brainpy import isotopic_variants +from brainpy import isotopic_variants import statsmodels.api as sm from lineartree import LinearTreeClassifier from sklearn.preprocessing import StandardScaler @@ -57,7 +54,6 @@ def fit(self, data): data = data.copy() data = data[np.isfinite(data['label'])] data = self.preprocess(data.copy()) - # self.classifier = LogisticRegression(solver = 'liblinear') self.classifier = CalibratedClassifierCV(GBC(), ensemble=False) self.classifier.fit(data[self.features], data['label']) self.logs.info(f'Fit {self.model}') @@ -172,7 +168,7 @@ def rt_error(self, subset): frac = self.lowess_frac, it = 3, xvals = subset['Average Rt(min)']) - rt_error = subset['Reference RT'].to_numpy() - regression + subset['rt_error'] = subset['Reference RT'].to_numpy() - regression #record regression for QC purposes file = next(f for f in subset['file']) self.rt_observed[file] = subset['Average Rt(min)'] @@ -180,7 +176,7 @@ def rt_error(self, subset): self.rt_predictions[file] = regression self.rt_calls[file] = subset['call'] self.logs.debug(f'RT regression fit for {file} used {np.sum(subset["call"])} observations') - return rt_error + return subset def correct_data(self, data): #identify high confidence subset for correction @@ -188,9 +184,11 @@ def correct_data(self, data): data['call'] = self.predict(scores) #build lowess regressions - rt_error = data.groupby('file')[data.columns].apply(self.rt_error) - data['rt_error'] = [val for file in rt_error for val in file] - data = data.drop(columns = ['call']) + subsets = [] + for file in set(data['file']): + subsets.append(self.rt_error(data[data['file'] == file])) + data = pd.concat(subsets) + data.drop(columns = ['call']) self.logs.info('RT correction has been applied.') return data