Skip to content

Commit

Permalink
cleaned up
Browse files Browse the repository at this point in the history
  • Loading branch information
alitinet committed Dec 19, 2024
1 parent 6767660 commit bd56d27
Showing 1 changed file with 0 additions and 11 deletions.
11 changes: 0 additions & 11 deletions src/multigrate/model/_multivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,31 +261,20 @@ def get_model_output(self, adata=None, batch_size=256, save_unimodal_params=Fals
inference_inputs = self.module._get_inference_input(tensors)
outputs = self.module.inference(**inference_inputs)
z = outputs["z_joint"]
# print('z_joint')
# print(outputs["z_joint"].shape)
if save_unimodal_latent is True:
# print('z_marginal')
# print(len(outputs["z_marginal"]))
# print(outputs["z_marginal"].shape)
z_marginal += [outputs["z_marginal"].cpu()]
if save_unimodal_params is True:
# print('params marginal')
# print(outputs["mu_marginal"].shape)
# print(outputs["logvar_marginal"].shape)
mu_marginal += [outputs["mu_marginal"].cpu()]
logvar_marginal += [outputs["logvar_marginal"].cpu()]
latent += [z.cpu()]

if save_unimodal_latent is True:
z_marginal = torch.cat(z_marginal)
print(z_marginal.shape)
for i in range(z_marginal.shape[1]):
adata.obsm[f"X_unimodal_{i}"] = z_marginal[:, i, :].squeeze(1).numpy()
if save_unimodal_params is True:
mu_marginal = torch.cat(mu_marginal)
logvar_marginal = torch.cat(logvar_marginal)
print(mu_marginal.shape)
print(logvar_marginal.shape)
for i in range(mu_marginal.shape[1]):
adata.obsm[f"mu_unimodal_{i}"] = mu_marginal[:, i, :].squeeze(1).numpy()
adata.obsm[f"logvar_unimodal_{i}"] = logvar_marginal[:, i, :].squeeze(1).numpy()
Expand Down

0 comments on commit bd56d27

Please sign in to comment.