Skip to content

Commit

Permalink
bugfix: information was inappropriately shared between files during R…
Browse files Browse the repository at this point in the history
…T correction
  • Loading branch information
stavis1 committed Jun 14, 2024
1 parent b8367a1 commit 4cbc654
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/MSDpostprocess/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,11 @@
# import shelve
import re
from functools import cache
import warnings

import pickle
import numpy as np
import pandas as pd
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="brainpy")
from brainpy import isotopic_variants
from brainpy import isotopic_variants
import statsmodels.api as sm
from lineartree import LinearTreeClassifier
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -57,7 +54,6 @@ def fit(self, data):
data = data.copy()
data = data[np.isfinite(data['label'])]
data = self.preprocess(data.copy())
# self.classifier = LogisticRegression(solver = 'liblinear')
self.classifier = CalibratedClassifierCV(GBC(), ensemble=False)
self.classifier.fit(data[self.features], data['label'])
self.logs.info(f'Fit {self.model}')
Expand Down Expand Up @@ -172,25 +168,27 @@ def rt_error(self, subset):
frac = self.lowess_frac,
it = 3,
xvals = subset['Average Rt(min)'])
rt_error = subset['Reference RT'].to_numpy() - regression
subset['rt_error'] = subset['Reference RT'].to_numpy() - regression
#record regression for QC purposes
file = next(f for f in subset['file'])
self.rt_observed[file] = subset['Average Rt(min)']
self.rt_expected[file] = subset['Reference RT']
self.rt_predictions[file] = regression
self.rt_calls[file] = subset['call']
self.logs.debug(f'RT regression fit for {file} used {np.sum(subset["call"])} observations')
return rt_error
return subset

def correct_data(self, data):
    """Add a per-file retention-time error column to ``data``.

    Rows are first scored with ``self._predict_prob`` and thresholded with
    ``self.predict``; the resulting boolean-like ``call`` column is written
    onto ``data`` so that ``self.rt_error`` can read it. ``self.rt_error``
    is then invoked once per distinct value of the ``file`` column, on that
    file's rows only, so nothing computed for one file can leak into another.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``file`` column plus whatever columns
        ``self._predict_prob`` and ``self.rt_error`` read. NOTE: the
        temporary ``call`` column is added to this object in place.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the per-file frames returned by ``self.rt_error``
        (files in order of first appearance), with the temporary ``call``
        column removed.

    Raises
    ------
    ValueError
        Propagated from ``pd.concat`` when ``data`` has no rows to group.
    """
    #identify high confidence subset for correction
    scores = self._predict_prob(data)
    data['call'] = self.predict(scores)

    #build lowess regressions, one file at a time.
    #groupby(sort=False) yields files in first-appearance order, so the
    #output row order is reproducible (iterating a set is not). Each group
    #is copied because rt_error assigns a new column onto its argument.
    subsets = [self.rt_error(subset.copy())
               for _, subset in data.groupby('file', sort=False)]
    data = pd.concat(subsets)
    #DataFrame.drop returns a new frame; without reassignment the scratch
    #'call' column would be left in the returned data.
    data = data.drop(columns=['call'])

    self.logs.info('RT correction has been applied.')
    return data
Expand Down

0 comments on commit 4cbc654

Please sign in to comment.