diff --git a/simba/model/regression/model.py b/simba/model/regression/model.py index fdf6542a9..3cf129dcb 100644 --- a/simba/model/regression/model.py +++ b/simba/model/regression/model.py @@ -1,10 +1,11 @@ -from typing import Dict, List, Optional, Tuple from itertools import product +from typing import Dict, List, Optional, Tuple import numpy as np import pandas as pd import xgboost as xgb from sklearn.model_selection import StratifiedKFold + from simba.model.regression.metrics import (mean_absolute_error, mean_absolute_percentage_error, mean_squared_error, r2_score, @@ -15,6 +16,7 @@ from simba.utils.enums import Formats from simba.utils.errors import DataHeaderError + def fit_xgb(x: pd.DataFrame, y: np.ndarray, xgb_reg: xgb.XGBRegressor) -> xgb.XGBRegressor: