From 86870f30f79a22ebbb05a40220226911ec64fcda Mon Sep 17 00:00:00 2001
From: Johannes Schwab PI Sjors Scheres added 22022021
Date: Tue, 7 Nov 2023 08:58:24 +0000
Subject: [PATCH] tensorboard

---
 .../deformations/optimize_deformations.py | 236 +++++++++---------
 1 file changed, 120 insertions(+), 116 deletions(-)

diff --git a/dynamight/deformations/optimize_deformations.py b/dynamight/deformations/optimize_deformations.py
index ab93f49..6cb4f1f 100644
--- a/dynamight/deformations/optimize_deformations.py
+++ b/dynamight/deformations/optimize_deformations.py
@@ -1140,127 +1140,131 @@ def optimize_deformations(
         x = torch.fft.fft2(x, dim=[-2, -1], norm='ortho')
 
         if epoch % 1 == 0:
-            if tot_latent_dim > 2:
-                if epoch % 5 == 0 and epoch > n_warmup_epochs:
+            try:
+                if tot_latent_dim > 2:
+                    if epoch % 5 == 0 and epoch > n_warmup_epochs:
+                        summ.add_figure("Data/latent",
+                                        visualize_latent(latent_space, c=cols, s=3,
+                                                         alpha=0.2, method='pca'),
+                                        epoch)
+                        summ.add_figure("Data/latent_val",
+                                        visualize_latent(latent_space[val_indices], c=cols[val_indices], s=3,
+                                                         alpha=0.2, method='pca'),
+                                        epoch)
+
+                else:
+                    summ.add_figure("Data/latent",
+                                    visualize_latent(
+                                        latent_space,
+                                        c=cols,
+                                        s=3,
+                                        alpha=0.2),
+                                    epoch)
+
+                summ.add_scalar("Loss/kld_loss",
+                                (losses_half1['latent_loss'] + losses_half2[
+                                    'latent_loss']) / (
+                                    len(data_loader_half1) + len(
+                                        data_loader_half2)), epoch)
+                summ.add_scalar("Loss/mse_loss",
+                                (losses_half1['reconstruction_loss'] + losses_half2[
+                                    'reconstruction_loss']) / (
+                                    len(data_loader_half1) + len(
+                                        data_loader_half2)), epoch)
+                summ.add_scalars("Loss/mse_loss_halfs",
+                                 {'half1': (losses_half1['reconstruction_loss']) / (len(
+                                     data_loader_half1)),
+                                  'half2': (losses_half2['reconstruction_loss']) / (
+                                     len(data_loader_half2))}, epoch)
+                summ.add_scalar("Loss/total_loss", (
+                    losses_half1['loss'] + losses_half2['loss']) / (
+                    len(data_loader_half1) + len(
+                        data_loader_half2)), epoch)
+                summ.add_scalar("Loss/geometric_loss",
+                                (losses_half1['geometric_loss'] + losses_half2[
+                                    'geometric_loss']) / (
+                                    len(data_loader_half1) + len(
+                                        data_loader_half2)), epoch)
+
+                summ.add_scalar(
+                    "Loss/variance_h1_1",
+                    decoder_half1.image_smoother.B[0].detach().cpu(), epoch)
+                summ.add_scalar(
+                    "Loss/variance_h2_1",
+                    decoder_half2.image_smoother.B[0].detach().cpu(), epoch)
+
+                summ.add_scalar("Loss/N_graph_h1", N_graph_h1, epoch)
+                summ.add_scalar("Loss/N_graph_h2", N_graph_h2, epoch)
+                summ.add_scalar("Loss/reg_param_h1",
+                                lambda_regularization_half1, epoch)
+                summ.add_scalar("Loss/reg_param_h2",
+                                lambda_regularization_half2, epoch)
+                summ.add_scalar("Loss/substitute_h1",
+                                consensus_update_rate_h1, epoch)
+                summ.add_scalar("Loss/substitute_h2",
+                                consensus_update_rate_h2, epoch)
+                summ.add_scalar("Loss/pose_change",
+                                angular_error, epoch)
+                summ.add_scalar("Loss/trans_change",
+                                translational_error, epoch)
+                summ.add_figure("Data/output", tensor_imshow(torch.fft.fftshift(
+                    apply_ctf(visualization_data_half1['projection_image'][0],
+                              visualization_data_half1['ctf'][0].float()).squeeze().cpu(),
+                    dim=[-1, -2])), epoch)
+
+                summ.add_figure("Data/target", tensor_imshow(torch.fft.fftshift(
+                    apply_ctf(visualization_data_half1['target_image'][0],
+                              data_normalization_mask.float()
+                              ).squeeze().cpu(),
+                    dim=[-1, -2])),
+                    epoch)
+
+                if mask_file == None or epoch <= n_warmup_epochs:
+                    summ.add_figure("Data/cons_points_z_half1_consensus",
+                                    tensor_scatter(decoder_half1.model_positions[:, 0],
+                                                   decoder_half1.model_positions[:, 1],
+                                                   c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
+                else:
+                    summ.add_figure("Data/cons_points_z_half1_consensus",
+                                    tensor_scatter(old_pos_h1[:, 0],
+                                                   old_pos_h1[:, 1],
+                                                   c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
+
+                if mask_file == None or epoch <= n_warmup_epochs:
+                    summ.add_figure("Data/cons_points_z_half2_consensus",
+                                    tensor_scatter(decoder_half2.model_positions[:, 0],
+                                                   decoder_half2.model_positions[:, 1],
+                                                   c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
+                else:
+                    summ.add_figure("Data/cons_points_z_half2_consensus",
+                                    tensor_scatter(old_pos_h2[:, 0],
+                                                   old_pos_h2[:, 1],
+                                                   c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
+
+                summ.add_figure(
+                    "Data/deformed_points",
+                    tensor_scatter(visualization_data_half1['deformed_points'][0, :, 0],
+                                   visualization_data_half1['deformed_points'][0, :, 1],
+                                   c='b',
+                                   s=0.1), epoch)
+
+                summ.add_figure("Data/projection_image",
+                                tensor_imshow(torch.fft.fftshift(torch.real(
+                                    torch.fft.ifftn(
+                                        visualization_data_half1['projection_image'][0],
+                                        dim=[-1,
+                                             -2])).squeeze().detach().cpu(),
+                                    dim=[-1, -2])), epoch)
 
-            summ.add_scalar("Loss/kld_loss",
-                            (losses_half1['latent_loss'] + losses_half2[
-                                'latent_loss']) / (
-                                len(data_loader_half1) + len(
-                                    data_loader_half2)), epoch)
-            summ.add_scalar("Loss/mse_loss",
-                            (losses_half1['reconstruction_loss'] + losses_half2[
-                                'reconstruction_loss']) / (
-                                len(data_loader_half1) + len(
-                                    data_loader_half2)), epoch)
-            summ.add_scalars("Loss/mse_loss_halfs",
-                             {'half1': (losses_half1['reconstruction_loss']) / (len(
-                                 data_loader_half1)),
-                              'half2': (losses_half2['reconstruction_loss']) / (
-                                 len(data_loader_half2))}, epoch)
-            summ.add_scalar("Loss/total_loss", (
-                losses_half1['loss'] + losses_half2['loss']) / (
-                len(data_loader_half1) + len(
-                    data_loader_half2)), epoch)
-            summ.add_scalar("Loss/geometric_loss",
-                            (losses_half1['geometric_loss'] + losses_half2[
-                                'geometric_loss']) / (
-                                len(data_loader_half1) + len(
-                                    data_loader_half2)), epoch)
-
-            summ.add_scalar(
-                "Loss/variance_h1_1",
-                decoder_half1.image_smoother.B[0].detach().cpu(), epoch)
-            summ.add_scalar(
-                "Loss/variance_h2_1",
-                decoder_half2.image_smoother.B[0].detach().cpu(), epoch)
-
-            summ.add_scalar("Loss/N_graph_h1", N_graph_h1, epoch)
-            summ.add_scalar("Loss/N_graph_h2", N_graph_h2, epoch)
-            summ.add_scalar("Loss/reg_param_h1",
-                            lambda_regularization_half1, epoch)
-            summ.add_scalar("Loss/reg_param_h2",
-                            lambda_regularization_half2, epoch)
-            summ.add_scalar("Loss/substitute_h1",
-                            consensus_update_rate_h1, epoch)
-            summ.add_scalar("Loss/substitute_h2",
-                            consensus_update_rate_h2, epoch)
-            summ.add_scalar("Loss/pose_change", angular_error, epoch)
-            summ.add_scalar("Loss/trans_change",
-                            translational_error, epoch)
-            summ.add_figure("Data/output", tensor_imshow(torch.fft.fftshift(
-                apply_ctf(visualization_data_half1['projection_image'][0],
-                          visualization_data_half1['ctf'][0].float()).squeeze().cpu(),
-                dim=[-1, -2])), epoch)
-
-            summ.add_figure("Data/target", tensor_imshow(torch.fft.fftshift(
-                apply_ctf(visualization_data_half1['target_image'][0],
-                          data_normalization_mask.float()
-                          ).squeeze().cpu(),
-                dim=[-1, -2])),
-                epoch)
-
-            if mask_file == None or epoch <= n_warmup_epochs:
-                summ.add_figure("Data/cons_points_z_half1_consensus",
-                                tensor_scatter(decoder_half1.model_positions[:, 0],
-                                               decoder_half1.model_positions[:, 1],
-                                               c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
-            else:
-                summ.add_figure("Data/cons_points_z_half1_consensus",
-                                tensor_scatter(old_pos_h1[:, 0],
-                                               old_pos_h1[:, 1],
-                                               c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
-
-            if mask_file == None or epoch <= n_warmup_epochs:
-                summ.add_figure("Data/cons_points_z_half2_consensus",
-                                tensor_scatter(decoder_half2.model_positions[:, 0],
-                                               decoder_half2.model_positions[:, 1],
-                                               c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
-            else:
-                summ.add_figure("Data/cons_points_z_half2_consensus",
-                                tensor_scatter(old_pos_h2[:, 0],
-                                               old_pos_h2[:, 1],
-                                               c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
-
-            summ.add_figure(
-                "Data/deformed_points",
-                tensor_scatter(visualization_data_half1['deformed_points'][0, :, 0],
-                               visualization_data_half1['deformed_points'][0, :, 1],
-                               c='b',
-                               s=0.1), epoch)
-
-            summ.add_figure("Data/projection_image",
-                            tensor_imshow(torch.fft.fftshift(torch.real(
-                                torch.fft.ifftn(
-                                    visualization_data_half1['projection_image'][0],
-                                    dim=[-1,
-                                         -2])).squeeze().detach().cpu(),
-                                dim=[-1, -2])),
-                            epoch)
-
-            # summ.add_figure("Data/dis_var", tensor_plot(D_var), epoch)
-
-            summ.add_figure(
-                "Data/frc_h1", tensor_plot(frc_half1), epoch)
-            summ.add_figure(
-                "Data/frc_h2", tensor_plot(frc_half2), epoch)
+                # summ.add_figure("Data/dis_var", tensor_plot(D_var), epoch)
+
+                summ.add_figure(
+                    "Data/frc_h1", tensor_plot(frc_half1), epoch)
+                summ.add_figure(
+                    "Data/frc_h2", tensor_plot(frc_half2), epoch)
+            except:
+                pass
 
         epoch_t = time.time() - start_time
 
         if epoch % 5 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1):
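The patch above wraps the per-epoch TensorBoard logging (scalars via summ.add_scalar/add_scalars, figures via summ.add_figure) in a single try/except so that a failed plot cannot abort a long training run. A minimal sketch of the same idea, assuming a torch.utils.tensorboard.SummaryWriter called summ as in the patched code; the safe_log helper and the warning are illustrative additions, not part of DynaMight, and unlike the bare except/pass above they still report the failure:

    import warnings
    from torch.utils.tensorboard import SummaryWriter

    def safe_log(log_call, *args, **kwargs):
        # Run a single SummaryWriter call; keep training alive if it raises,
        # but surface the problem instead of silently discarding it.
        try:
            log_call(*args, **kwargs)
        except Exception as exc:
            warnings.warn(f"TensorBoard logging failed: {exc}")

    summ = SummaryWriter()  # hypothetical writer; the real one is created elsewhere
    safe_log(summ.add_scalar, "Loss/total_loss", 0.5, 0)  # dummy value and step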
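When the latent space has more than two dimensions, the logging falls back to a PCA projection through visualize_latent(..., method='pca'). The body of that helper is not part of this diff; the sketch below is an illustrative stand-in (the name pca_scatter and its internals are assumptions) showing how such a figure can be produced for SummaryWriter.add_figure:

    import matplotlib.pyplot as plt
    import torch

    def pca_scatter(latent: torch.Tensor, c=None, s=3, alpha=0.2):
        # Project an (N, D) latent tensor onto its first two principal components
        # and return a matplotlib Figure, which is what add_figure expects.
        z = latent.detach().cpu()
        _, _, v = torch.pca_lowrank(z, q=2)   # top-2 principal directions, shape (D, 2)
        z2 = z @ v                            # (N, 2) projected coordinates
        fig, ax = plt.subplots()
        ax.scatter(z2[:, 0], z2[:, 1], c=c, s=s, alpha=alpha)
        ax.set_xlabel("PC 1")
        ax.set_ylabel("PC 2")
        return fig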
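The per-half reconstruction errors are written with add_scalars, which records several named series under one tag so the two half-set losses appear in a single TensorBoard chart. A small self-contained sketch with made-up values:

    from torch.utils.tensorboard import SummaryWriter

    summ = SummaryWriter()  # hypothetical writer for the example
    for epoch, (mse_h1, mse_h2) in enumerate([(0.9, 0.95), (0.7, 0.72), (0.5, 0.55)]):
        # One tag, two named series -> both curves share a plot in TensorBoard.
        summ.add_scalars("Loss/mse_loss_halfs",
                         {'half1': mse_h1, 'half2': mse_h2}, epoch)
    summ.close()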
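Several of the image panels ("Data/output", "Data/target", "Data/projection_image") are produced by inverse-transforming Fourier-space images and handing the result to tensor_imshow, a DynaMight helper that is not shown in this patch. As a rough sketch under that assumption, the conversion amounts to:

    import matplotlib.pyplot as plt
    import torch

    def fourier_to_figure(f_image: torch.Tensor):
        # Inverse FFT over the last two axes, take the real part, and re-centre
        # the image with fftshift before wrapping it in a Figure for add_figure.
        real = torch.real(torch.fft.ifftn(f_image, dim=[-2, -1]))
        real = torch.fft.fftshift(real, dim=[-2, -1]).squeeze().detach().cpu()
        fig, ax = plt.subplots()
        ax.imshow(real, cmap="gray")
        ax.axis("off")
        return fig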