From c4d23e1395eeae41f6184a9de8109c2fa7245d91 Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Mon, 23 Oct 2023 10:13:08 -0700
Subject: [PATCH] Add lasso from cuML

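Add a `lasso` option to `Classifier.fit` and `Classifier.fit_cv`, plumb
it through `train_supervised`, and expose it on the `Elicit` config.
When enabled, the probe is fit with an L1 penalty using cuML's
coordinate descent Lasso solver instead of L-BFGS with an L2 penalty.
Since the penalty coefficient is no longer L2-specific, `fit`'s
`l2_penalty` argument is renamed to `alpha` (matching cuML and
scikit-learn naming), and `fit_cv` sweeps the same penalty grid for
either norm.

A minimal usage sketch; the tensors and dimensions here are purely
illustrative, and a CUDA device plus a working cuML install are
assumed:

    import torch

    from elk.training.classifier import Classifier

    x = torch.randn(1024, 64, device="cuda")  # (N, D) features
    y = (x[:, 0] > 0).float()                 # (N,) binary labels

    clf = Classifier(64, device="cuda")
    clf.fit(x, y, alpha=1e-3, lasso=True)  # L1 via cuML
    clf.fit(x, y, alpha=1e-3)              # default: L2 via L-BFGS

    # Cross-validated sweep over penalty strengths with the same flag
    path = clf.fit_cv(x, y, k=5, lasso=True)

Because `lasso` is a field on `Elicit`, it should also be reachable from
the command line the same way the other `Elicit` options are (presumably
as `--lasso`).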
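For reviewers, a standalone sketch of the round-trip that the new branch
in `fit` performs for the binary case. It assumes cuML's `Lasso` mirrors
the scikit-learn estimator attributes (`coef_`, `intercept_`) and
returns host arrays when given host inputs:

    import torch
    from cuml import Lasso

    x = torch.randn(1024, 64, device="cuda")
    y = (x[:, 0] > 0).float()

    model = Lasso(alpha=1e-3)
    model.fit(x.cpu().numpy(), y.cpu().numpy())

    # The solution comes back on the host, so reshape it to the probe's
    # (1, D) weight and (1,) bias, and match the features' device and
    # dtype before overwriting the nn.Linear parameters.
    linear = torch.nn.Linear(64, 1, device="cuda")
    linear.weight.data = (
        torch.as_tensor(model.coef_).reshape(1, -1).to(x.device, x.dtype)
    )
    linear.bias.data = (
        torch.as_tensor(model.intercept_).reshape(1).to(x.device, x.dtype)
    )

Note that cuML's `Lasso` does not expose a final objective value, so
`fit` recomputes the unregularized training loss before returning to
keep the same return contract as the L-BFGS path.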
---
 elk/training/classifier.py | 52 +++++++++++++++++++++++++++++++-----------
 elk/training/supervised.py |  3 ++-
 elk/training/train.py      |  4 ++++
 tests/test_classifier.py   |  2 +-
 4 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/elk/training/classifier.py b/elk/training/classifier.py
index 7b9281ec0..7c28e7c52 100644
--- a/elk/training/classifier.py
+++ b/elk/training/classifier.py
@@ -68,7 +68,8 @@ def fit(
         x: Tensor,
         y: Tensor,
         *,
-        l2_penalty: float = 0.001,
+        alpha: float = 0.001,
+        lasso: bool = False,
         max_iter: int = 10_000,
     ) -> float:
-        """Fits the model to the input data using L-BFGS with L2 regularization.
+        """Fits the model with L-BFGS (L2 penalty) or cuML Lasso (L1 penalty).
@@ -78,12 +79,34 @@ def fit(
                 the input dimension.
             y: Target tensor of shape (N,) for binary classification or (N, C)
                 for multiclass classification, where C is the number of classes.
-            l2_penalty: L2 regularization strength.
-            max_iter: Maximum number of iterations for the L-BFGS optimizer.
+            alpha: Regularization strength; the coefficient of the L2 penalty,
+                or of the L1 penalty when lasso=True.
+            lasso: Whether to fit with an L1 penalty using cuML's coordinate
+                descent solver instead of L-BFGS with an L2 penalty.
+            max_iter: Maximum number of iterations for the solver.
 
         Returns:
             Final value of the loss function after optimization.
         """
+        # When lasso=True, fit with an L1 penalty using cuML's coordinate
+        # descent solver, then copy the solution back into our nn.Linear.
+        if lasso:
+            from cuml import Lasso
+
+            model = Lasso(alpha=alpha, max_iter=max_iter)
+            model.fit(x.cpu().numpy(), y.cpu().numpy())
+
+            w = torch.as_tensor(model.coef_).reshape(self.linear.weight.shape)
+            b = torch.as_tensor(model.intercept_).reshape(self.linear.bias.shape)
+            self.linear.weight.data = w.to(x.device, x.dtype)
+            self.linear.bias.data = b.to(x.device, x.dtype)
+
+            # cuML's Lasso doesn't expose its final objective value, so
+            # recompute the (unregularized) loss on the training set.
+            num_classes = self.linear.out_features
+            loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
+            return float(loss_fn(self(x).squeeze(-1), y))
+
         optimizer = torch.optim.LBFGS(
             self.parameters(),
             line_search_fn="strong_wolfe",
@@ -104,8 +127,8 @@ def closure():
             # Calculate the loss function
             logits = self(x).squeeze(-1)
             loss = loss_fn(logits, y)
-            if l2_penalty:
-                reg_loss = loss + l2_penalty * self.linear.weight.square().sum()
+            if alpha:
+                reg_loss = loss + alpha * self.linear.weight.square().sum()
             else:
                 reg_loss = loss
 
@@ -122,6 +145,7 @@ def fit_cv(
         y: Tensor,
         *,
         k: int = 5,
+        lasso: bool = False,
         max_iter: int = 10_000,
         num_penalties: int = 10,
         seed: int = 42,
@@ -155,7 +179,7 @@ def fit_cv(
         indices = torch.randperm(num_samples, device=x.device, generator=rng)
 
-        # Try a range of L2 penalties, including 0
-        l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
+        # Try a range of penalties, including 0
+        penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
 
         num_classes = self.linear.out_features
         loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
@@ -173,8 +197,10 @@ def fit_cv(
             val_x, val_y = x[val_indices], y[val_indices]
 
             # Regularization path with warm-starting
-            for j, l2_penalty in enumerate(l2_penalties):
-                self.fit(train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter)
+            for j, penalty in enumerate(penalties):
+                self.fit(
+                    train_x, train_y, alpha=penalty, lasso=lasso, max_iter=max_iter
+                )
                 logits = self(val_x).squeeze(-1)
                 loss = loss_fn(logits, val_y)
 
@@ -184,9 +210,9 @@ def fit_cv(
        best_idx = mean_losses.argmin()
 
         # Refit with the best penalty
-        best_penalty = l2_penalties[best_idx]
-        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
-        return RegularizationPath(l2_penalties, mean_losses.tolist())
+        best_penalty = penalties[best_idx]
+        self.fit(x, y, alpha=best_penalty, lasso=lasso, max_iter=max_iter)
+        return RegularizationPath(penalties, mean_losses.tolist())
 
     @classmethod
     def inlp(
diff --git a/elk/training/supervised.py b/elk/training/supervised.py
index b3f100646..7e5443fc2 100644
--- a/elk/training/supervised.py
+++ b/elk/training/supervised.py
@@ -11,6 +11,7 @@ def train_supervised(
     device: str,
     mode: str,
     erase_paraphrases: bool = False,
+    lasso: bool = False,
     max_inlp_iter: int | None = None,
 ) -> list[Classifier]:
     assert not (
@@ -48,7 +49,7 @@ def train_supervised(
 
     if mode == "cv":
         lr_model = Classifier(X.shape[-1], device=device, eraser=eraser)
-        lr_model.fit_cv(X, train_labels)
+        lr_model.fit_cv(X, train_labels, lasso=lasso)
         return [lr_model]
     elif mode == "inlp":
         return Classifier.inlp(
diff --git a/elk/training/train.py b/elk/training/train.py
index baa21991f..ec60d779f 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -21,6 +21,9 @@ class Elicit(Run):
 
     seed: int = 42
 
+    lasso: bool = False
+    """Whether to fit the supervised probes with L1 (LASSO) regularization."""
+
     supervised: Literal["single", "inlp", "cv"] = "single"
     """Whether to train a supervised classifier, and if so, whether to use
     cross-validation. Defaults to "single", which means to train a single classifier
@@ -76,6 +79,7 @@ def apply_to_layer(
             train_dict,
             erase_paraphrases=self.erase_paraphrases,
             device=device,
+            lasso=self.lasso,
             mode=self.supervised,
             max_inlp_iter=self.max_inlp_iter,
         )
diff --git a/tests/test_classifier.py b/tests/test_classifier.py
index bdc9023df..5e9184a6d 100644
--- a/tests/test_classifier.py
+++ b/tests/test_classifier.py
@@ -24,7 +24,7 @@ def test_classifier_roughly_same_sklearn():
     classifier.fit(
         torch.from_numpy(features),
         torch.from_numpy(truths),
-        l2_penalty=0.0,
+        alpha=0.0,
     )
     # check that the weights are roughly the same
     sklearn_coef = torch.from_numpy(model.coef_)