diff --git a/src/scvi/external/decipher/_components.py b/src/scvi/external/decipher/_components.py
index 2951528aca..5439190fb9 100644
--- a/src/scvi/external/decipher/_components.py
+++ b/src/scvi/external/decipher/_components.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from collections.abc import Sequence
 
 import numpy as np
 import torch
@@ -52,7 +52,7 @@ def __init__(
 
         # The multiple outputs are computed as a single output layer, and then split
         indices = np.concatenate(([0], np.cumsum(self.output_dims)))
-        self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:])]
+        self.output_slices = [slice(s, e) for s, e in zip(indices[:-1], indices[1:], strict=False)]
 
         # Create masked layers
         deep_context_dim = self.context_dim if self.deep_context_injection else 0
@@ -63,21 +63,15 @@ def __init__(
             batch_norms.append(nn.BatchNorm1d(hidden_dims[0]))
             for i in range(1, len(hidden_dims)):
                 layers.append(
-                    torch.nn.Linear(
-                        hidden_dims[i - 1] + deep_context_dim, hidden_dims[i]
-                    )
+                    torch.nn.Linear(hidden_dims[i - 1] + deep_context_dim, hidden_dims[i])
                 )
                 batch_norms.append(nn.BatchNorm1d(hidden_dims[i]))
 
             layers.append(
-                torch.nn.Linear(
-                    hidden_dims[-1] + deep_context_dim, self.output_total_dim
-                )
+                torch.nn.Linear(hidden_dims[-1] + deep_context_dim, self.output_total_dim)
             )
         else:
-            layers.append(
-                torch.nn.Linear(input_dim + context_dim, self.output_total_dim)
-            )
+            layers.append(torch.nn.Linear(input_dim + context_dim, self.output_total_dim))
 
         self.layers = torch.nn.ModuleList(layers)
 
diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py
index 75a8db88de..48f730f2eb 100644
--- a/src/scvi/external/decipher/_model.py
+++ b/src/scvi/external/decipher/_model.py
@@ -62,9 +62,7 @@ def setup_anndata(
         anndata_fields = [
             LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
         ]
-        adata_manager = AnnDataManager(
-            fields=anndata_fields, setup_method_args=setup_method_args
-        )
+        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
         adata_manager.register_fields(adata, **kwargs)
         cls.register_manager(adata_manager)
 
@@ -113,9 +111,7 @@ def get_latent_representation(
 
         self._check_if_trained(warn=False)
         adata = self._validate_anndata(adata)
-        scdl = self._make_data_loader(
-            adata=adata, indices=indices, batch_size=batch_size
-        )
+        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
         latent_locs = []
         for tensors in scdl:
             x = tensors[REGISTRY_KEYS.X_KEY]
diff --git a/src/scvi/external/decipher/_module.py b/src/scvi/external/decipher/_module.py
index 433bc83706..c327654527 100644
--- a/src/scvi/external/decipher/_module.py
+++ b/src/scvi/external/decipher/_module.py
@@ -82,9 +82,7 @@ def device(self):
         return self._dummy_param.device
 
     @staticmethod
-    def _get_fn_args_from_batch(
-        tensor_dict: dict[str, torch.Tensor]
-    ) -> Iterable | dict:
+    def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
         x = tensor_dict[REGISTRY_KEYS.X_KEY]
         return (x,), {}
 
@@ -125,9 +123,7 @@ def model(self, x: torch.Tensor):
                 self.theta + self._epsilon
             )
             # noinspection PyUnresolvedReferences
-            x_dist = dist.NegativeBinomial(
-                total_count=self.theta + self._epsilon, logits=logit
-            )
+            x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
             pyro.sample("x", x_dist.to_event(1), obs=x)
 
     @auto_move_data
@@ -188,9 +184,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
                 model_trace = poutine.trace(
                     poutine.replay(self.model, trace=guide_trace)
                 ).get_trace(x)
-                log_weights.append(
-                    model_trace.log_prob_sum() - guide_trace.log_prob_sum()
-                )
+                log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())
 
         finally:
             self.beta = old_beta
diff --git a/src/scvi/external/decipher/_trainingplan.py b/src/scvi/external/decipher/_trainingplan.py
index 992ad164a8..73e49f3022 100644
--- a/src/scvi/external/decipher/_trainingplan.py
+++ b/src/scvi/external/decipher/_trainingplan.py
@@ -41,9 +41,7 @@ def __init__(
             optim_kwargs.update({"lr": 5e-3})
         if "weight_decay" not in optim_kwargs.keys():
             optim_kwargs.update({"weight_decay": 1e-4})
-        self.optim = (
-            pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
-        )
+        self.optim = pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim
         # We let SVI take care of all optimization
         self.automatic_optimization = False
 
diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py
index 5067d9cdbf..fd8542914c 100644
--- a/src/scvi/train/_trainingplans.py
+++ b/src/scvi/train/_trainingplans.py
@@ -182,9 +182,7 @@ def __init__(
         self.optimizer_creator = optimizer_creator
 
         if self.optimizer_name == "Custom" and self.optimizer_creator is None:
-            raise ValueError(
-                "If optimizer is 'Custom', `optimizer_creator` must be provided."
-            )
+            raise ValueError("If optimizer is 'Custom', `optimizer_creator` must be provided.")
 
         self._n_obs_training = None
         self._n_obs_validation = None
@@ -221,9 +219,7 @@ def initialize_train_metrics(self):
             self.kl_local_train,
             self.kl_global_train,
             self.train_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="train", n_total=self.n_obs_training
-        )
+        ) = self._create_elbo_metric_components(mode="train", n_total=self.n_obs_training)
         self.elbo_train.reset()
 
     def initialize_val_metrics(self):
@@ -234,9 +230,7 @@ def initialize_val_metrics(self):
             self.kl_local_val,
             self.kl_global_val,
             self.val_metrics,
-        ) = self._create_elbo_metric_components(
-            mode="validation", n_total=self.n_obs_validation
-        )
+        ) = self._create_elbo_metric_components(mode="validation", n_total=self.n_obs_validation)
         self.elbo_val.reset()
 
     @property
@@ -372,9 +366,7 @@ def validation_step(self, batch, batch_idx):
         )
         self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")
 
-    def _optimizer_creator_fn(
-        self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW
-    ):
+    def _optimizer_creator_fn(self, optimizer_cls: torch.optim.Adam | torch.optim.AdamW):
         """Create optimizer for the model.
 
         This type of function can be passed as the `optimizer_creator`
@@ -552,9 +544,7 @@ def loss_adversarial_classifier(self, z, batch_index, predict_true_class=True):
         if predict_true_class:
             cls_target = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
         else:
-            one_hot_batch = torch.nn.functional.one_hot(
-                batch_index.squeeze(-1), n_classes
-            )
+            one_hot_batch = torch.nn.functional.one_hot(batch_index.squeeze(-1), n_classes)
             # place zeroes where true label is
             cls_target = (~one_hot_batch.bool()).float()
             cls_target = cls_target / (n_classes - 1)
@@ -582,9 +572,7 @@ def training_step(self, batch, batch_idx):
             else:
                 opt1, opt2 = opts
 
-        inference_outputs, _, scvi_loss = self.forward(
-            batch, loss_kwargs=self.loss_kwargs
-        )
+        inference_outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         z = inference_outputs["z"]
         loss = scvi_loss.loss
         # fool classifier if doing adversarial training
@@ -617,10 +605,7 @@ def on_train_epoch_end(self):
 
     def on_validation_epoch_end(self) -> None:
         """Update the learning rate via scheduler steps."""
-        if (
-            not self.reduce_lr_on_plateau
-            or "validation" not in self.lr_scheduler_metric
-        ):
+        if not self.reduce_lr_on_plateau or "validation" not in self.lr_scheduler_metric:
             return
         else:
             sch = self.lr_schedulers()
@@ -651,9 +636,7 @@ def configure_optimizers(self):
             )
 
         if self.adversarial_classifier is not False:
-            params2 = filter(
-                lambda p: p.requires_grad, self.adversarial_classifier.parameters()
-            )
+            params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
             optimizer2 = torch.optim.Adam(
                 params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
             )
@@ -919,9 +902,7 @@ def __init__(
         self.n_epochs_kl_warmup = n_epochs_kl_warmup
         self.use_kl_weight = False
         if isinstance(self.module.model, PyroModule):
-            self.use_kl_weight = (
-                "kl_weight" in signature(self.module.model.forward).parameters
-            )
+            self.use_kl_weight = "kl_weight" in signature(self.module.model.forward).parameters
         elif callable(self.module.model):
             self.use_kl_weight = "kl_weight" in signature(self.module.model).parameters
         self.scale_elbo = scale_elbo
@@ -1102,9 +1083,7 @@ def __init__(
         optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else {}
         if "lr" not in optim_kwargs.keys():
             optim_kwargs.update({"lr": 1e-3})
-        self.optim = (
-            pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
-        )
+        self.optim = pyro.optim.Adam(optim_args=optim_kwargs) if optim is None else optim
         # We let SVI take care of all optimization
         self.automatic_optimization = False
 
@@ -1200,9 +1179,7 @@ def __init__(
         self.loss_fn = loss()
 
         if self.module.logits is False and loss == torch.nn.CrossEntropyLoss:
-            raise UserWarning(
-                "classifier should return logits when using CrossEntropyLoss."
-            )
+            raise UserWarning("classifier should return logits when using CrossEntropyLoss.")
 
     def forward(self, *args, **kwargs):
         """Passthrough to the module's forward function."""
@@ -1232,9 +1209,7 @@ def configure_optimizers(self):
             optim_cls = torch.optim.AdamW
         else:
             raise ValueError("Optimizer not understood.")
-        optimizer = optim_cls(
-            params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay
-        )
+        optimizer = optim_cls(params, lr=self.lr, eps=self.eps, weight_decay=self.weight_decay)
 
         return optimizer
 
@@ -1300,11 +1275,7 @@ def __init__(
 
     def get_optimizer_creator(self) -> JaxOptimizerCreator:
         """Get optimizer creator for the model."""
-        clip_by = (
-            optax.clip_by_global_norm(self.max_norm)
-            if self.max_norm
-            else optax.identity()
-        )
+        clip_by = optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity()
         if self.optimizer_name == "Adam":
             # Replicates PyTorch Adam defaults
             optim = optax.chain(
@@ -1358,9 +1329,9 @@ def loss_fn(params):
             loss = loss_output.loss
             return loss, (loss_output, new_model_state)
 
-        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(
-            loss_fn, has_aux=True
-        )(state.params)
+        (loss, (loss_output, new_model_state)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+            state.params
+        )
         new_state = state.apply_gradients(grads=grads, state=new_model_state)
         return new_state, loss, loss_output
 