
Commit 9d047ea
clean tensorboard
Johannes Schwab (PI Sjors Scheres); committed by 22022021 on Oct 20, 2023
1 parent 71e0032 commit 9d047ea
Showing 3 changed files with 7 additions and 73 deletions.
2 binary files changed (contents not shown).
80 changes: 7 additions & 73 deletions dynamight/deformations/optimize_deformations.py
@@ -156,7 +156,7 @@ def optimize_deformations(
     angles_op = torch.optim.Adam([angles], lr=1e-3)
     shifts = torch.nn.Parameter(torch.tensor(
         shifts, requires_grad=True).to(device))
-    shifts_op = torch.optim.Adam([shifts], lr=1e-3)
+    shifts_op = torch.optim.Adam([shifts], lr=0) # 1e-3)
 
     # initialise training dataloaders
     if checkpoint_file is not None: # get subsets from checkpoint if present
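Note: a learning rate of 0 makes Adam's parameter update exactly zero, so the shifts stay fixed while the rest of the training loop, including shifts_op.step(), runs unchanged. A minimal sketch of that behaviour, with hypothetical shapes not taken from the repository:

```python
import torch

# Stand-in per-particle shifts; the real code builds these from input data.
shifts = torch.nn.Parameter(torch.full((10, 2), 0.5))
shifts_op = torch.optim.Adam([shifts], lr=0)  # lr=0: step() applies a zero update

loss = (shifts ** 2).sum()
loss.backward()                     # gradients are non-zero here
before = shifts.detach().clone()
shifts_op.step()                    # update = lr * m_hat / (sqrt(v_hat) + eps) = 0
assert torch.equal(shifts.detach(), before)  # shifts are unchanged
```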
@@ -1152,32 +1152,6 @@ def optimize_deformations(
"Loss/variance_h2_1",
decoder_half2.image_smoother.B[0].detach().cpu(), epoch)

summ.add_scalar(
"Loss/amplitude_h1",
decoder_half1.image_smoother.A[0].detach().cpu(), epoch)
summ.add_scalar(
"Loss/amplitude_h2",
decoder_half2.image_smoother.A[0].detach().cpu(), epoch)
if decoder_half1.n_classes > 1:
summ.add_scalar(
"Loss/amplitude_h1_2",
decoder_half1.image_smoother.A[1].detach().cpu(), epoch)
summ.add_scalar(
"Loss/amplitude_h2_2",
decoder_half2.image_smoother.A[1].detach().cpu(), epoch)
summ.add_scalar(
"Loss/variance_h1_2",
decoder_half1.image_smoother.B[1].detach().cpu(), epoch)
summ.add_scalar(
"Loss/variance_h2_2",
decoder_half2.image_smoother.B[1].detach().cpu(), epoch)

summ.add_figure("Data/FSC_half_maps",
tensor_plot(fourier_shell_correlation, fix=[0, 1]), epoch)
summ.add_figure("Loss/amplitudes_h1",
tensor_plot(decoder_half1.amp[0].cpu()), epoch)
summ.add_figure("Loss/amplitudes_h2",
tensor_plot(decoder_half2.amp[0].cpu()), epoch)
summ.add_scalar("Loss/N_graph_h1", N_graph_h1, epoch)
summ.add_scalar("Loss/N_graph_h2", N_graph_h2, epoch)
summ.add_scalar("Loss/reg_param_h1",
@@ -1188,29 +1162,14 @@ def optimize_deformations(
                         consensus_update_rate_h1, epoch)
         summ.add_scalar("Loss/substitute_h2",
                         consensus_update_rate_h2, epoch)
-        summ.add_scalar("Loss/pose_error", angular_error, epoch)
-        summ.add_scalar("Loss/trans_error",
+        summ.add_scalar("Loss/pose_change", angular_error, epoch)
+        summ.add_scalar("Loss/trans_change",
                         translational_error, epoch)
         summ.add_figure("Data/output", tensor_imshow(torch.fft.fftshift(
             apply_ctf(visualization_data_half1['projection_image'][0],
                       visualization_data_half1['ctf'][0].float()).squeeze().cpu(),
             dim=[-1, -2])), epoch)
 
-        if initialization_mode in (
-                ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-            summ.add_figure("Data/snr1",
-                            tensor_plot(snr1), epoch)
-            summ.add_figure("Data/snr2",
-                            tensor_plot(snr2), epoch)
-            summ.add_figure("Data/signal1",
-                            tensor_plot(signal_h1), epoch)
-            summ.add_figure("Data/signal2",
-                            tensor_plot(signal_h2), epoch)
-
         summ.add_figure(
             "Data/sig", tensor_imshow(data_normalization_mask), epoch)
-        # summ.add_figure(
-        #     "Data/errsig", tensor_imshow(data_err), epoch)
         summ.add_figure("Data/target", tensor_imshow(torch.fft.fftshift(
             apply_ctf(visualization_data_half1['target_image'][0],
                       data_normalization_mask.float()
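The renamed tags suggest these scalars track the per-epoch change in angles and shifts rather than an error against ground truth. A minimal sketch of the add_scalar pattern these calls use, with placeholder values and an assumed log directory:

```python
from torch.utils.tensorboard import SummaryWriter

summ = SummaryWriter(log_dir="runs/example")  # assumed log directory
for epoch in range(3):
    angular_error = 1.0 / (epoch + 1)         # placeholder metric values
    translational_error = 0.5 / (epoch + 1)
    summ.add_scalar("Loss/pose_change", angular_error, epoch)
    summ.add_scalar("Loss/trans_change", translational_error, epoch)
summ.close()
```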
@@ -1219,52 +1178,27 @@ def optimize_deformations(
             epoch)
 
         if mask_file == None or epoch <= n_warmup_epochs:
-            summ.add_figure("Data/cons_points_z_half1_noise",
+            summ.add_figure("Data/cons_points_z_half1_consensus",
                             tensor_scatter(decoder_half1.model_positions[:, 0],
                                            decoder_half1.model_positions[:, 1],
                                            c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
         else:
-            summ.add_figure("Data/cons_points_z_half1_noise",
+            summ.add_figure("Data/cons_points_z_half1_consensus",
                             tensor_scatter(old_pos_h1[:, 0],
                                            old_pos_h1[:, 1],
                                            c=(noise_h1/torch.max(noise_h1)).cpu(), s=3), epoch)
 
         if mask_file == None or epoch <= n_warmup_epochs:
-            summ.add_figure("Data/cons_points_z_half2_noise",
+            summ.add_figure("Data/cons_points_z_half2_consensus",
                             tensor_scatter(decoder_half2.model_positions[:, 0],
                                            decoder_half2.model_positions[:, 1],
                                            c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
         else:
-            summ.add_figure("Data/cons_points_z_half2_noise",
+            summ.add_figure("Data/cons_points_z_half2_consensus",
                             tensor_scatter(old_pos_h2[:, 0],
                                            old_pos_h2[:, 1],
                                            c=(noise_h2/torch.max(noise_h2)).cpu(), s=3), epoch)
 
-        if initialization_mode in (
-                ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-            if mask_file == None or epoch <= n_warmup_epochs:
-                summ.add_figure("Data/cons_points_z_half2_nsr",
-                                tensor_scatter(decoder_half2.model_positions[:, 0],
-                                               decoder_half2.model_positions[:, 1],
-                                               c=(snr2/torch.max(snr2)).cpu(), s=3), epoch)
-            else:
-                summ.add_figure("Data/cons_points_z_half2_nsr",
-                                tensor_scatter(old_pos_h2[:, 0],
-                                               old_pos_h2[:, 1],
-                                               c=(snr2/torch.max(snr2)).cpu(), s=3), epoch)
-        if initialization_mode in (
-                ConsensusInitializationMode.EMPTY, ConsensusInitializationMode.MAP):
-            if mask_file == None or epoch <= n_warmup_epochs:
-                summ.add_figure("Data/cons_points_z_half1_nsr",
-                                tensor_scatter(decoder_half1.model_positions[:, 0],
-                                               decoder_half1.model_positions[:, 1],
-                                               c=(snr1/torch.max(snr1)).cpu(), s=3), epoch)
-            else:
-                summ.add_figure("Data/cons_points_z_half1_nsr",
-                                tensor_scatter(old_pos_h1[:, 0],
-                                               old_pos_h1[:, 1],
-                                               c=(snr1/torch.max(snr1)).cpu(), s=3), epoch)
-
         summ.add_figure(
             "Data/deformed_points",
             tensor_scatter(visualization_data_half1['deformed_points'][0, :, 0],
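The scatter figures above are built with DynaMight's internal tensor_scatter helper. A hedged stand-in that logs an equivalent matplotlib scatter through add_figure; positions and colour values below are placeholders, not repository data:

```python
import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter

def scatter_figure(x, y, c, s=3):
    # Build a matplotlib figure roughly equivalent to tensor_scatter's output.
    fig, ax = plt.subplots()
    ax.scatter(x.cpu().numpy(), y.cpu().numpy(), c=c.cpu().numpy(), s=s)
    return fig

summ = SummaryWriter(log_dir="runs/example")
positions = torch.rand(100, 2)  # placeholder consensus model positions
noise = torch.rand(100)         # placeholder per-point colour values
summ.add_figure("Data/cons_points_z_half1_consensus",
                scatter_figure(positions[:, 0], positions[:, 1],
                               c=noise / torch.max(noise)), 0)
summ.close()
```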
