Skip to content

Commit

Permalink
Merge pull request #147 from cnellington/dev
Browse files Browse the repository at this point in the history
Bugfixes for the easy modules' predict_params; added early-stopping kwargs to the easy modules.
  • Loading branch information
cnellington authored Nov 20, 2022
2 parents d821522 + 6b96852 commit 00dc129
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Many people have helped. Check out [ACKNOWLEDGEMENTS.md](https://github.com/cnel


## Videos
- [Cold Spring Harbor Laboratory: Contextualized Graphical Models Reveal Sample-Specific Transcriptional Networks for 7000 Tumors](https://www.youtube.com/watch?v=MTcjFK-YwCw)
- [Sample-Specific Models for Interpretable Analysis with Applications to Disease Subtyping](http://www.birs.ca/events/2022/5-day-workshops/22w5055/videos/watch/202205051559-Lengerich.html)

## Contact Us
Expand Down
26 changes: 16 additions & 10 deletions contextualized/easy/wrappers/SKLearnWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def __init__(
self.default_train_batch_size = 1
self.default_test_batch_size = 16
self.default_val_split = 0.2
self.default_encoder_width = 25
self.default_encoder_layers = 3
self.default_encoder_link_fn = LINK_FUNCTIONS["identity"]
self.n_bootstraps = 1
self.models = None
self.trainers = None
Expand Down Expand Up @@ -69,6 +72,11 @@ def __init__(
"fit": [],
"wrapper": [
"n_bootstraps",
"es_patience",
"es_monitor",
"es_mode",
"es_min_delta",
"es_verbose",
],
}
self._update_acceptable_kwargs("model", extra_model_kwargs)
Expand Down Expand Up @@ -96,7 +104,7 @@ def __init__(
"layers", self.constructor_kwargs["encoder_kwargs"]["layers"]
)
self.constructor_kwargs["encoder_kwargs"]["link_fn"] = kwargs.get(
"encoder_link_fn", self.constructor_kwargs["encoder_kwargs"]["link_fn"]
"encoder_link_fn", self.constructor_kwargs["encoder_kwargs"].get("link_fn", self.default_encoder_link_fn)
)
self.not_constructor_kwargs = {
k: v
Expand Down Expand Up @@ -165,6 +173,7 @@ def maybe_add_kwarg(category, kwarg, default_val):
)
],
)
print(organized_kwargs["trainer"]["callback_constructors"])
organized_kwargs["trainer"]["callback_constructors"].append(
lambda i: ModelCheckpoint(
monitor=kwargs.get("es_monitor", "val_loss"),
Expand Down Expand Up @@ -260,13 +269,13 @@ def _build_dataloaders(self, model, train_data, val_data, **kwargs):
:param **kwargs:
"""
train_dataloader = self._build_dataloader(
model, kwargs.get("train_batch_size", 1), *train_data
model, kwargs.get("train_batch_size", self.default_train_batch_size), *train_data
)
if val_data is None:
val_dataloader = None
else:
val_dataloader = self._build_dataloader(
model, kwargs.get("val_batch_size", 16), *val_data
model, kwargs.get("val_batch_size", self.default_val_batch_size), *val_data
)

return train_dataloader, val_dataloader
Expand All @@ -287,7 +296,9 @@ def maybe_add_constructor_kwarg(kwarg, default_val):
maybe_add_constructor_kwarg("loss_fn", LOSSES["mse"])
maybe_add_constructor_kwarg(
"encoder_kwargs",
{"width": 25, "layers": 2, "link_fn": LINK_FUNCTIONS["identity"]},
{"width": kwargs.get("encoder_width", self.default_encoder_width),
"layers": kwargs.get("encoder_layers", self.default_encoder_layers),
"link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn)},
)
if kwargs.get("subtype_probabilities", False):
constructor_kwargs["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"]
Expand All @@ -308,11 +319,9 @@ def maybe_add_constructor_kwarg(kwarg, default_val):

def predict(self, C, X, individual_preds=False, **kwargs):
"""
:param C:
:param X:
:param individual_preds: (Default value = False)
"""
if not hasattr(self, "models") or self.models is None:
raise ValueError(
Expand All @@ -336,10 +345,8 @@ def predict_params(
self, C, individual_preds=False, model_includes_mus=True, **kwargs
):
"""
:param C:
:param individual_preds: (Default value = False)
"""
# Returns betas, mus
if kwargs.get("uses_y", True):
Expand All @@ -363,15 +370,14 @@ def predict_params(
else:
return np.mean(betas, axis=0), np.mean(mus, axis=0)
betas = np.array(preds)
if individual_preds:
if not individual_preds:
return np.mean(betas, axis=0)
return betas

def fit(self, *args, **kwargs):
"""
:param *args: C, X, Y (optional)
:param **kwargs:
"""
self.models = []
self.trainers = []
Expand Down

0 comments on commit 00dc129

Please sign in to comment.