diff --git a/src/scvi/external/decipher/_model.py b/src/scvi/external/decipher/_model.py index ed7c8a5193..56a34bcfa5 100644 --- a/src/scvi/external/decipher/_model.py +++ b/src/scvi/external/decipher/_model.py @@ -199,7 +199,7 @@ def compute_imputed_gene_expression( mu = self.module.decoder_z_to_x(z_loc) mu = F.softmax(mu, dim=-1) library_size = x.sum(axis=-1, keepdim=True) - imputed_gene_expr = (library_size * mu).detach().cpu().numpy() + imputed_gene_expr = (library_size * mu.detach().cpu()).numpy() imputed_gene_expression_batches.append(imputed_gene_expr) imputed_gene_expression = np.concatenate(imputed_gene_expression_batches, axis=0) @@ -300,7 +300,7 @@ def compute_gene_patterns( t_points = trajectory.trajectory_latent t_times = trajectory.trajectory_time - t_points = torch.FloatTensor(t_points) + t_points = torch.FloatTensor(t_points).to(self.module.device) z_mean, z_scale = self.module.decoder_v_to_z(t_points) z_scale = F.softplus(z_scale) @@ -308,11 +308,12 @@ def compute_gene_patterns( gene_patterns = {} gene_patterns["mean"] = ( - F.softmax(self.module.decoder_z_to_x(z_mean), dim=-1).detach().numpy() * l_scale + F.softmax(self.module.decoder_z_to_x(z_mean), dim=-1).detach().cpu().numpy() * l_scale ) gene_expression_samples = ( - F.softmax(self.module.decoder_z_to_x(z_samples), dim=-1).detach().numpy() * l_scale + F.softmax(self.module.decoder_z_to_x(z_samples), dim=-1).detach().cpu().numpy() + * l_scale ) gene_patterns["q25"] = np.quantile(gene_expression_samples, 0.25, axis=0) gene_patterns["q75"] = np.quantile(gene_expression_samples, 0.75, axis=0)