Skip to content

Commit

Permalink
Actually use lasso
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Nov 1, 2023
1 parent c4d23e1 commit 739e970
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 9 additions & 5 deletions elk/training/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,15 @@ def fit(
if lasso:
from cuml import Lasso

Check failure on line 90 in elk/training/classifier.py

View workflow job for this annotation

GitHub Actions / run-tests (3.10, macos-latest)

Import "cuml" could not be resolved (reportMissingImports)

model = Lasso(alpha=alpha)
model.fit(x.cpu().numpy(), y.cpu().numpy())
self.linear.weight.data = torch.from_numpy(model.coef_.T)
self.linear.bias.data = torch.from_numpy(model.intercept_)
return float(model.loss_)
model = Lasso(alpha=alpha, selection="random")
model.fit(x, y)

W = torch.as_tensor(model.coef_).unsqueeze(0).to(x.device)
b = torch.as_tensor(model.intercept_).to(x.device)

self.linear.weight.data = W
self.linear.bias.data = b
return 0.0

optimizer = torch.optim.LBFGS(
self.parameters(),
Expand Down
2 changes: 1 addition & 1 deletion elk/training/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train_supervised(
).classifiers
elif mode == "single":
lr_model = Classifier(X.shape[-1], device=device, eraser=eraser)
lr_model.fit(X, train_labels)
lr_model.fit(X, train_labels, lasso=lasso)
return [lr_model]
else:
raise ValueError(f"Unknown mode: {mode}")

0 comments on commit 739e970

Please sign in to comment.