diff --git a/dynamight/deformable_backprojection/__pycache__/deformable_backprojection.cpython-310.pyc b/dynamight/deformable_backprojection/__pycache__/deformable_backprojection.cpython-310.pyc
index 1c3d613..494fcdc 100644
Binary files a/dynamight/deformable_backprojection/__pycache__/deformable_backprojection.cpython-310.pyc and b/dynamight/deformable_backprojection/__pycache__/deformable_backprojection.cpython-310.pyc differ
diff --git a/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc b/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc
index 981524f..264d694 100644
Binary files a/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc and b/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc differ
diff --git a/dynamight/deformations/optimize_deformations.py b/dynamight/deformations/optimize_deformations.py
index 0948824..d6e80bf 100644
--- a/dynamight/deformations/optimize_deformations.py
+++ b/dynamight/deformations/optimize_deformations.py
@@ -156,7 +156,7 @@ def optimize_deformations(
     angles_op = torch.optim.Adam([angles], lr=1e-3)
     shifts = torch.nn.Parameter(torch.tensor(
         shifts, requires_grad=True).to(device))
-    shifts_op = torch.optim.Adam([shifts], lr=1e-3)
+    shifts_op = torch.optim.Adam([shifts], lr=0)  # 1e-3)
 
     # initialise training dataloaders
     if checkpoint_file is not None:  # get subsets from checkpoint if present
@@ -1152,32 +1152,6 @@ def optimize_deformations(
                 summ.add_scalar(
                     "Loss/variance_h2_1",
                     decoder_half2.image_smoother.B[0].detach().cpu(), epoch)
-                summ.add_scalar(
-                    "Loss/amplitude_h1",
-                    decoder_half1.image_smoother.A[0].detach().cpu(), epoch)
-                summ.add_scalar(
-                    "Loss/amplitude_h2",
-                    decoder_half2.image_smoother.A[0].detach().cpu(), epoch)
-                if decoder_half1.n_classes > 1:
-                    summ.add_scalar(
-                        "Loss/amplitude_h1_2",
-                        decoder_half1.image_smoother.A[1].detach().cpu(), epoch)
-                    summ.add_scalar(
-                        "Loss/amplitude_h2_2",
-                        decoder_half2.image_smoother.A[1].detach().cpu(), epoch)
-                    summ.add_scalar(
-                        "Loss/variance_h1_2",
-                        decoder_half1.image_smoother.B[1].detach().cpu(), epoch)
-                    summ.add_scalar(
-                        "Loss/variance_h2_2",
-                        decoder_half2.image_smoother.B[1].detach().cpu(), epoch)
-
-                summ.add_figure("Data/FSC_half_maps",
-                                tensor_plot(fourier_shell_correlation, fix=[0, 1]), epoch)
-                summ.add_figure("Loss/amplitudes_h1",
-                                tensor_plot(decoder_half1.amp[0].cpu()), epoch)
-                summ.add_figure("Loss/amplitudes_h2",
-                                tensor_plot(decoder_half2.amp[0].cpu()), epoch)
                 summ.add_scalar("Loss/N_graph_h1", N_graph_h1, epoch)
                 summ.add_scalar("Loss/N_graph_h2", N_graph_h2, epoch)
                 summ.add_scalar("Loss/reg_param_h1",
@@ -1188,29 +1162,14 @@ def optimize_deformations(
                                 consensus_update_rate_h1, epoch)
                 summ.add_scalar("Loss/substitute_h2",
                                 consensus_update_rate_h2, epoch)
-                summ.add_scalar("Loss/pose_error", angular_error, epoch)
-                summ.add_scalar("Loss/trans_error",
+                summ.add_scalar("Loss/pose_change", angular_error, epoch)
+                summ.add_scalar("Loss/trans_change",
                                 translational_error, epoch)
                 summ.add_figure("Data/output",
                                 tensor_imshow(torch.fft.fftshift(
                                     apply_ctf(visualization_data_half1['projection_image'][0],
                                               visualization_data_half1['ctf'][0].float()).squeeze().cpu(),
                                     dim=[-1, -2])), epoch)
-                if initialization_mode in (
-                        ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-                    summ.add_figure("Data/snr1",
-                                    tensor_plot(snr1), epoch)
-                    summ.add_figure("Data/snr2",
-                                    tensor_plot(snr2), epoch)
-                    summ.add_figure("Data/signal1",
-                                    tensor_plot(signal_h1), epoch)
-                    summ.add_figure("Data/signal2",
-                                    tensor_plot(signal_h2), epoch)
-
-                summ.add_figure(
-                    "Data/sig", tensor_imshow(data_normalization_mask), epoch)
-                # summ.add_figure(
-                #     "Data/errsig", tensor_imshow(data_err), epoch)
                 summ.add_figure("Data/target",
                                 tensor_imshow(torch.fft.fftshift(
                                     apply_ctf(visualization_data_half1['target_image'][0],
                                               data_normalization_mask.float()
@@ -1219,52 +1178,27 @@ def optimize_deformations(
                                               ).squeeze().cpu(), dim=[-1, -2])),
                                 epoch)
 
                 if mask_file == None or epoch <= n_warmup_epochs:
-                    summ.add_figure("Data/cons_points_z_half1_noise",
+                    summ.add_figure("Data/cons_points_z_half1_consensus",
                                     tensor_scatter(decoder_half1.model_positions[:, 0],
                                                    decoder_half1.model_positions[:, 1],
                                                    c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
                 else:
-                    summ.add_figure("Data/cons_points_z_half1_noise",
+                    summ.add_figure("Data/cons_points_z_half1_consensus",
                                     tensor_scatter(old_pos_h1[:, 0],
                                                    old_pos_h1[:, 1],
                                                    c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
 
                 if mask_file == None or epoch <= n_warmup_epochs:
-                    summ.add_figure("Data/cons_points_z_half2_noise",
+                    summ.add_figure("Data/cons_points_z_half2_consensus",
                                     tensor_scatter(decoder_half2.model_positions[:, 0],
                                                    decoder_half2.model_positions[:, 1],
                                                    c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
                 else:
-                    summ.add_figure("Data/cons_points_z_half2_noise",
+                    summ.add_figure("Data/cons_points_z_half2_consensus",
                                     tensor_scatter(old_pos_h2[:, 0],
                                                    old_pos_h2[:, 1],
                                                    c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
-                if initialization_mode in (
-                        ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-                    if mask_file == None or epoch <= n_warmup_epochs:
-                        summ.add_figure("Data/cons_points_z_half2_nsr",
-                                        tensor_scatter(decoder_half2.model_positions[:, 0],
-                                                       decoder_half2.model_positions[:, 1],
-                                                       c=(snr2/torch.max(snr2)).cpu(), s=3), epoch)
-                    else:
-                        summ.add_figure("Data/cons_points_z_half2_nsr",
-                                        tensor_scatter(old_pos_h2[:, 0],
-                                                       old_pos_h2[:, 1],
-                                                       c=(snr2/torch.max(snr2)).cpu(), s=3), epoch)
-                if initialization_mode in (
-                        ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-                    if mask_file == None or epoch <= n_warmup_epochs:
-                        summ.add_figure("Data/cons_points_z_half1_nsr",
-                                        tensor_scatter(decoder_half1.model_positions[:, 0],
-                                                       decoder_half1.model_positions[:, 1],
-                                                       c=(snr1/torch.max(snr1)).cpu(), s=3), epoch)
-                    else:
-                        summ.add_figure("Data/cons_points_z_half1_nsr",
-                                        tensor_scatter(old_pos_h1[:, 0],
-                                                       old_pos_h1[:, 1],
-                                                       c=(snr1/torch.max(snr1)).cpu(), s=3), epoch)
-
                 summ.add_figure(
                     "Data/deformed_points",
                     tensor_scatter(visualization_data_half1['deformed_points'][0, :, 0],