
Commit c3bd09b
lint
noahnovsak authored and markotoplak committed Aug 17, 2023
1 parent b661c78 commit c3bd09b
Showing 2 changed files with 18 additions and 11 deletions.
28 changes: 17 additions & 11 deletions Orange/regression/linear.py
@@ -6,6 +6,13 @@
 import sklearn.linear_model as skl_linear_model
 import sklearn.preprocessing as skl_preprocessing
 
+try:
+    import dask_ml.linear_model as dask_linear_model
+    from dask_glm.regularizers import ElasticNet
+except ImportError:
+    dask_linear_model = skl_linear_model
+    ElasticNet = ...
+
 from Orange.data import Variable, ContinuousVariable
 from Orange.preprocess import Normalize
 from Orange.preprocess.score import LearnerScorer
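
Note: this hunk makes dask_ml an optional dependency. A minimal sketch (not part of the commit) of how the fallback behaves when dask_ml is missing: dask_linear_model then aliases the sklearn module, which _initialize_wrapped below detects with an identity check, and ElasticNet is bound to the Ellipsis placeholder so the module still imports cleanly.

import sklearn.linear_model as skl_linear_model

try:
    import dask_ml.linear_model as dask_linear_model
except ImportError:
    dask_linear_model = skl_linear_model
    ElasticNet = ...  # placeholder so the name exists; never called in this case

if dask_linear_model is skl_linear_model:
    print("dask_ml unavailable; LinearRegression will come from sklearn")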
@@ -39,25 +46,24 @@ def __init__(self, preprocessors=None, fit_intercept=True):

     def _initialize_wrapped(self, X=None, Y=None):
         if isinstance(X, da.Array) or isinstance(Y, da.Array):
-            try:
-                import dask_ml.linear_model
-
+            if dask_linear_model is skl_linear_model:
+                warnings.warn("dask_ml is not installed, using sklearn instead.")
+            else:
                 params = self.params.copy()
+                penalty = self.__penalty__
                 params["solver"] = "gradient_descent"
-                params["penalty"] = self.__penalty__
-                if self.__penalty__ is not None:
+
+                if penalty is not None:
+                    if penalty == "elasticnet":
+                        penalty = ElasticNet(weight=params.pop("l1_ratio"))
+                    params["penalty"] = penalty
                     params["solver"] = "admm"
                     params["C"] = 1 / params.pop("alpha")
                     params["max_iter"] = params["max_iter"] or 100
                     for key in ["copy_X", "precompute", "positive"]:
                         params.pop(key, None)
-                    if self.__penalty__ == "elasticnet":
-                        from dask_glm.regularizers import ElasticNet
-                        params["penalty"] = ElasticNet(weight=params.pop("l1_ratio"))
 
-                return dask_ml.linear_model.LinearRegression(**params)
-            except ImportError:
-                warnings.warn("dask_ml is not installed, using sklearn instead.")
+                return dask_linear_model.LinearRegression(**params)
         return self.__wraps__(**self.params)
 
     def fit(self, X, Y, W=None):
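For context, a minimal usage sketch (not part of the commit) of what the translated parameters mean to dask-ml. It assumes dask-ml and dask-glm are installed; the data and the alpha value are illustrative. A ridge learner's sklearn-style alpha becomes dask-glm's C = 1 / alpha, the solver switches from "gradient_descent" to "admm" whenever a penalty is set, and the "elasticnet" string is replaced by a dask_glm regularizer object carrying the former l1_ratio as its weight.

import dask.array as da
import numpy as np
from dask_glm.regularizers import ElasticNet
from dask_ml.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = da.from_array(rng.random((100, 3)), chunks=(50, 3))
y = da.from_array(rng.random(100), chunks=(50,))

# Ridge: sklearn's alpha=2.0 arrives as C=1/2.0 with the "admm" solver.
ridge = LinearRegression(penalty="l2", solver="admm", C=0.5, max_iter=100)
ridge.fit(X, y)
print(ridge.coef_)

# Elastic net: penalty is a regularizer instance instead of a string.
enet = LinearRegression(
    penalty=ElasticNet(weight=0.5), solver="admm", C=1.0, max_iter=100
)
enet.fit(X, y)
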
1 change: 1 addition & 0 deletions Orange/tests/test_linear_regression.py
@@ -124,6 +124,7 @@ def test_linear_regression_repr(self):
         self.assertIsInstance(learner2, LinearRegressionLearner)
 
 
+# pylint: disable=invalid-name
 class TestLinearRegressionLearnerOnDask(TestLinearRegressionLearner):
     learners = [
         RidgeRegressionLearner(),
