Commit 86870f3
tensorboard
Johannes Schwab and Sjors Scheres committed Nov 7, 2023
1 parent 50e63d8 commit 86870f3
Showing 1 changed file with 120 additions and 116 deletions.
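The change wraps the per-epoch TensorBoard logging in optimize_deformations in a try/except, so that a failure while writing summaries or figures no longer aborts a training epoch.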
236 changes: 120 additions & 116 deletions dynamight/deformations/optimize_deformations.py
@@ -1140,127 +1140,131 @@ def optimize_deformations(
    x = torch.fft.fft2(x, dim=[-2, -1], norm='ortho')

    if epoch % 1 == 0:
        try:
            # latent-space scatter plots (PCA projection when dim > 2)
            if tot_latent_dim > 2:
                if epoch % 5 == 0 and epoch > n_warmup_epochs:
                    summ.add_figure("Data/latent",
                                    visualize_latent(latent_space, c=cols,
                                                     s=3, alpha=0.2,
                                                     method='pca'),
                                    epoch)
                    summ.add_figure("Data/latent_val",
                                    visualize_latent(latent_space[val_indices],
                                                     c=cols[val_indices],
                                                     s=3, alpha=0.2,
                                                     method='pca'),
                                    epoch)
            else:
                summ.add_figure("Data/latent",
                                visualize_latent(latent_space, c=cols,
                                                 s=3, alpha=0.2),
                                epoch)

            # scalar losses, averaged over both half-set data loaders
            summ.add_scalar("Loss/kld_loss",
                            (losses_half1['latent_loss'] +
                             losses_half2['latent_loss']) /
                            (len(data_loader_half1) +
                             len(data_loader_half2)), epoch)
            summ.add_scalar("Loss/mse_loss",
                            (losses_half1['reconstruction_loss'] +
                             losses_half2['reconstruction_loss']) /
                            (len(data_loader_half1) +
                             len(data_loader_half2)), epoch)
            summ.add_scalars("Loss/mse_loss_halfs",
                             {'half1': losses_half1['reconstruction_loss'] /
                              len(data_loader_half1),
                              'half2': losses_half2['reconstruction_loss'] /
                              len(data_loader_half2)}, epoch)
            summ.add_scalar("Loss/total_loss",
                            (losses_half1['loss'] + losses_half2['loss']) /
                            (len(data_loader_half1) +
                             len(data_loader_half2)), epoch)
            summ.add_scalar("Loss/geometric_loss",
                            (losses_half1['geometric_loss'] +
                             losses_half2['geometric_loss']) /
                            (len(data_loader_half1) +
                             len(data_loader_half2)), epoch)

            summ.add_scalar("Loss/variance_h1_1",
                            decoder_half1.image_smoother.B[0].detach().cpu(),
                            epoch)
            summ.add_scalar("Loss/variance_h2_1",
                            decoder_half2.image_smoother.B[0].detach().cpu(),
                            epoch)

            summ.add_scalar("Loss/N_graph_h1", N_graph_h1, epoch)
            summ.add_scalar("Loss/N_graph_h2", N_graph_h2, epoch)
            summ.add_scalar("Loss/reg_param_h1",
                            lambda_regularization_half1, epoch)
            summ.add_scalar("Loss/reg_param_h2",
                            lambda_regularization_half2, epoch)
            summ.add_scalar("Loss/substitute_h1",
                            consensus_update_rate_h1, epoch)
            summ.add_scalar("Loss/substitute_h2",
                            consensus_update_rate_h2, epoch)
            summ.add_scalar("Loss/pose_change", angular_error, epoch)
            summ.add_scalar("Loss/trans_change", translational_error, epoch)

            # CTF-filtered output and target images
            summ.add_figure("Data/output", tensor_imshow(torch.fft.fftshift(
                apply_ctf(visualization_data_half1['projection_image'][0],
                          visualization_data_half1['ctf'][0].float()
                          ).squeeze().cpu(),
                dim=[-1, -2])), epoch)
            summ.add_figure("Data/target", tensor_imshow(torch.fft.fftshift(
                apply_ctf(visualization_data_half1['target_image'][0],
                          data_normalization_mask.float()
                          ).squeeze().cpu(),
                dim=[-1, -2])), epoch)

            # consensus model positions, coloured by per-point noise level
            if mask_file == None or epoch <= n_warmup_epochs:
                summ.add_figure("Data/cons_points_z_half1_consensus",
                                tensor_scatter(
                                    decoder_half1.model_positions[:, 0],
                                    decoder_half1.model_positions[:, 1],
                                    c=(noise_h1 / torch.max(noise_h1)).cpu(),
                                    s=3), epoch)
            else:
                summ.add_figure("Data/cons_points_z_half1_consensus",
                                tensor_scatter(
                                    old_pos_h1[:, 0], old_pos_h1[:, 1],
                                    c=(noise_h1 / torch.max(noise_h1)).cpu(),
                                    s=3), epoch)

            if mask_file == None or epoch <= n_warmup_epochs:
                summ.add_figure("Data/cons_points_z_half2_consensus",
                                tensor_scatter(
                                    decoder_half2.model_positions[:, 0],
                                    decoder_half2.model_positions[:, 1],
                                    c=(noise_h2 / torch.max(noise_h2)).cpu(),
                                    s=3), epoch)
            else:
                summ.add_figure("Data/cons_points_z_half2_consensus",
                                tensor_scatter(
                                    old_pos_h2[:, 0], old_pos_h2[:, 1],
                                    c=(noise_h2 / torch.max(noise_h2)).cpu(),
                                    s=3), epoch)

            summ.add_figure("Data/deformed_points",
                            tensor_scatter(
                                visualization_data_half1[
                                    'deformed_points'][0, :, 0],
                                visualization_data_half1[
                                    'deformed_points'][0, :, 1],
                                c='b', s=0.1), epoch)

            summ.add_figure("Data/projection_image",
                            tensor_imshow(torch.fft.fftshift(torch.real(
                                torch.fft.ifftn(
                                    visualization_data_half1[
                                        'projection_image'][0],
                                    dim=[-1, -2])).squeeze().detach().cpu(),
                                dim=[-1, -2])), epoch)

            # summ.add_figure("Data/dis_var", tensor_plot(D_var), epoch)

            summ.add_figure("Data/frc_h1", tensor_plot(frc_half1), epoch)
            summ.add_figure("Data/frc_h2", tensor_plot(frc_half2), epoch)
        except:
            pass

    epoch_t = time.time() - start_time

    if epoch % 5 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1):
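
For context, the guarded-logging pattern the commit introduces can be reduced to a self-contained sketch. This is an illustration only, not DynaMight code: the loop, losses, log directory, and figure below are made-up placeholders; only SummaryWriter, add_scalar, and add_figure mirror the calls in the diff (they are the standard torch.utils.tensorboard API).

# Standalone sketch of the try/except-guarded TensorBoard logging used above.
# Everything here is a placeholder except the torch.utils.tensorboard API;
# DynaMight's real tags and tensors come from its training loop.
import matplotlib
matplotlib.use('Agg')  # render figures without a display
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter

summ = SummaryWriter(log_dir='runs/demo')  # hypothetical log directory

for epoch in range(3):
    # stand-ins for the per-half-set loss dictionaries in the diff
    losses_half1 = {'loss': torch.rand(1).item()}
    losses_half2 = {'loss': torch.rand(1).item()}
    try:
        # average a scalar over both halves, as the diff does with the
        # lengths of the two half-set data loaders (here both count as 1)
        summ.add_scalar('Loss/total_loss',
                        (losses_half1['loss'] + losses_half2['loss']) / 2,
                        epoch)
        # log a matplotlib figure; if plotting raises (backend problems,
        # NaNs, empty data), the except below swallows it and the loop
        # continues, which is the point of the commit
        fig, ax = plt.subplots()
        ax.scatter(torch.randn(100), torch.randn(100), s=3, alpha=0.2)
        summ.add_figure('Data/latent', fig, epoch)
    except Exception:  # the commit itself uses a bare `except:`
        pass

summ.close()

Run `tensorboard --logdir runs` to inspect the resulting scalars and figures. Swallowing all exceptions keeps long refinements alive at the cost of silently losing diagnostics; logging the exception before continuing would make failures visible.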