Commit: Add lasso from cuML

norabelrose committed Oct 23, 2023
1 parent 70a3290 · commit c4d23e1

Showing 4 changed files with 31 additions and 12 deletions.
34 changes: 24 additions & 10 deletions elk/training/classifier.py
@@ -68,7 +68,8 @@ def fit(
         x: Tensor,
         y: Tensor,
         *,
-        l2_penalty: float = 0.001,
+        alpha: float = 0.001,
+        lasso: bool = False,
         max_iter: int = 10_000,
     ) -> float:
         """Fits the model to the input data using L-BFGS with L2 regularization.
@@ -78,12 +79,22 @@ def fit(
                 the input dimension.
             y: Target tensor of shape (N,) for binary classification or (N, C) for
                 multiclass classification, where C is the number of classes.
-            l2_penalty: L2 regularization strength.
+            alpha: L2 regularization strength.
             max_iter: Maximum number of iterations for the L-BFGS optimizer.
         Returns:
             Final value of the loss function after optimization.
         """
+        # Use cuML backend for LASSO
+        if lasso:
+            from cuml import Lasso
+
[Check failure on line 90 in elk/training/classifier.py — GitHub Actions / run-tests (3.11, macos-latest): Import "cuml" could not be resolved (reportMissingImports)]
+            model = Lasso(alpha=alpha)
+            model.fit(x.cpu().numpy(), y.cpu().numpy())
+            self.linear.weight.data = torch.from_numpy(model.coef_.T)
+            self.linear.bias.data = torch.from_numpy(model.intercept_)
+            return float(model.loss_)
+
         optimizer = torch.optim.LBFGS(
             self.parameters(),
             line_search_fn="strong_wolfe",
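
With `lasso=True`, the L-BFGS loop is bypassed entirely in favor of cuML's `Lasso` solver, and the learned coefficients are copied back into the torch linear layer. A minimal usage sketch, assuming cuML is installed (the CI annotation above shows it missing on the macOS runner) and that `Classifier` takes the input dimension as its first positional argument, as in `supervised.py`:

```python
import torch

from elk.training.classifier import Classifier

# Dummy binary-classification data; shapes follow the fit() docstring above.
x = torch.randn(1024, 768)
y = torch.randint(0, 2, (1024,)).float()

clf = Classifier(768)
# New code path: fits cuml.Lasso on CPU copies of the tensors, then writes
# the coefficients and intercept back into clf.linear.
final_loss = clf.fit(x, y, alpha=0.01, lasso=True)
```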
@@ -104,8 +115,8 @@ def closure():
             # Calculate the loss function
             logits = self(x).squeeze(-1)
             loss = loss_fn(logits, y)
-            if l2_penalty:
-                reg_loss = loss + l2_penalty * self.linear.weight.square().sum()
+            if alpha:
+                reg_loss = loss + alpha * self.linear.weight.square().sum()
             else:
                 reg_loss = loss

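The closure's objective is `loss_fn` plus an L2 term, i.e. loss(W) + α·‖W‖²_F over the linear layer's weights, with `alpha=0` recovering the unregularized fit. A quick standalone check of just the penalty term:

```python
import torch

w = torch.tensor([[0.5, -1.0, 2.0]])  # stand-in for self.linear.weight
alpha = 0.001
penalty = alpha * w.square().sum()  # 0.001 * (0.25 + 1.0 + 4.0) = 0.00525
assert torch.isclose(penalty, torch.tensor(0.00525))
```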
@@ -122,6 +133,7 @@ def fit_cv(
         y: Tensor,
         *,
         k: int = 5,
+        lasso: bool = False,
         max_iter: int = 10_000,
         num_penalties: int = 10,
         seed: int = 42,
@@ -155,7 +167,7 @@ def fit_cv(
         indices = torch.randperm(num_samples, device=x.device, generator=rng)

         # Try a range of L2 penalties, including 0
-        l2_penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()
+        penalties = [0.0] + torch.logspace(-4, 4, num_penalties).tolist()

         num_classes = self.linear.out_features
         loss_fn = bce_with_logits if num_classes == 1 else cross_entropy
@@ -173,8 +185,10 @@ def fit_cv(
             val_x, val_y = x[val_indices], y[val_indices]

             # Regularization path with warm-starting
-            for j, l2_penalty in enumerate(l2_penalties):
-                self.fit(train_x, train_y, l2_penalty=l2_penalty, max_iter=max_iter)
+            for j, penalty in enumerate(penalties):
+                self.fit(
+                    train_x, train_y, alpha=penalty, lasso=lasso, max_iter=max_iter
+                )

                 logits = self(val_x).squeeze(-1)
                 loss = loss_fn(logits, val_y)
@@ -184,9 +198,9 @@ def fit_cv(
         best_idx = mean_losses.argmin()

         # Refit with the best penalty
-        best_penalty = l2_penalties[best_idx]
-        self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
-        return RegularizationPath(l2_penalties, mean_losses.tolist())
+        best_penalty = penalties[best_idx]
+        self.fit(x, y, alpha=best_penalty, lasso=lasso, max_iter=max_iter)
+        return RegularizationPath(penalties, mean_losses.tolist())

     @classmethod
     def inlp(
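`fit_cv` shuffles the data once, walks the warm-started penalty path within each of the `k` folds, averages the validation losses per penalty, and refits on the full data at the argmin. A hedged sketch of the call, reusing `clf`, `x`, and `y` from the earlier snippet:

```python
# Sweeps [0.0] + logspace(-4, 4, num_penalties); with lasso=True every fit
# along the path goes through the cuML branch shown above.
path = clf.fit_cv(x, y, k=5, lasso=True)
# `path` is a RegularizationPath pairing each penalty with its mean
# cross-validation loss, per the return statement above.
```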
3 changes: 2 additions & 1 deletion elk/training/supervised.py
@@ -11,6 +11,7 @@ def train_supervised(
     device: str,
     mode: str,
     erase_paraphrases: bool = False,
+    lasso: bool = False,
     max_inlp_iter: int | None = None,
 ) -> list[Classifier]:
     assert not (
@@ -48,7 +49,7 @@ def train_supervised(

     if mode == "cv":
         lr_model = Classifier(X.shape[-1], device=device, eraser=eraser)
-        lr_model.fit_cv(X, train_labels)
+        lr_model.fit_cv(X, train_labels, lasso=lasso)
         return [lr_model]
     elif mode == "inlp":
         return Classifier.inlp(
4 changes: 4 additions & 0 deletions elk/training/train.py
@@ -21,6 +21,9 @@ class Elicit(Run):

     seed: int = 42

+    lasso: bool = False
+    """Whether to use L1 regularization."""
+
     supervised: Literal["single", "inlp", "cv"] = "single"
     """Whether to train a supervised classifier, and if so, whether to use
     cross-validation. Defaults to "single", which means to train a single classifier
@@ -76,6 +79,7 @@ def apply_to_layer(
             train_dict,
             erase_paraphrases=self.erase_paraphrases,
             device=device,
+            lasso=self.lasso,
             mode=self.supervised,
             max_inlp_iter=self.max_inlp_iter,
         )
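End to end, the new flag threads from the run config down to the solver: `Elicit.lasso` → `train_supervised(..., lasso=...)` → `fit_cv(..., lasso=...)` → `fit(..., lasso=True)`. A sketch of toggling it at the top of that chain; `Elicit` has other required fields not shown in this diff, so the constructor call below is illustrative only:

```python
from elk.training.train import Elicit

# Hypothetical minimal construction: only fields touched by this commit.
run = Elicit(supervised="cv", lasso=True, seed=42)
```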
2 changes: 1 addition & 1 deletion tests/test_classifier.py
@@ -24,7 +24,7 @@ def test_classifier_roughly_same_sklearn():
     classifier.fit(
         torch.from_numpy(features),
         torch.from_numpy(truths),
-        l2_penalty=0.0,
+        alpha=0.0,
     )
     # check that the weights are roughly the same
     sklearn_coef = torch.from_numpy(model.coef_)
