diff --git a/dynamight/deformations/optimize_deformations_rigid.py b/dynamight/deformations/optimize_deformations_rigid.py index a736869..bc86e63 100644 --- a/dynamight/deformations/optimize_deformations_rigid.py +++ b/dynamight/deformations/optimize_deformations_rigid.py @@ -745,7 +745,6 @@ def optimize_deformations_rigid( data_loader_half1, angles, shifts, - add_corr, data_preprocessor, epoch, 0, @@ -756,7 +755,8 @@ def optimize_deformations_rigid( consensus_update_pooled_particles=consensus_update_pooled_particles, regularization_mode=regularization_mode_half1, edge_weights=edge_weights_h1, - edge_weights_dis=edge_weights_dis_h1 + edge_weights_dis=edge_weights_dis_h1, + add_corr = add_corr, ) fits = True torch.cuda.empty_cache() @@ -849,7 +849,6 @@ def optimize_deformations_rigid( data_loader_half1, angles, shifts, - add_corr, data_preprocessor, epoch, current_warmup_epochs, @@ -862,7 +861,8 @@ def optimize_deformations_rigid( edge_weights=edge_weights_h1, edge_weights_dis=edge_weights_dis_h1, pos_epoch=pos_epoch, - ref_mask=ref_mask + ref_mask=ref_mask, + add_corr = add_corr, ) ref_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1, @@ -941,7 +941,6 @@ def optimize_deformations_rigid( data_loader_half2, angles, shifts, - add_corr, data_preprocessor, epoch, current_warmup_epochs, @@ -954,7 +953,8 @@ def optimize_deformations_rigid( edge_weights=edge_weights_h2, edge_weights_dis=edge_weights_dis_h2, ref_mask=ref_mask, - pos_epoch=pos_epoch + pos_epoch=pos_epoch, + add_corr = add_corr, ) else: @@ -968,7 +968,6 @@ def optimize_deformations_rigid( data_loader_half2, angles, shifts, - add_corr, data_preprocessor, epoch, current_warmup_epochs, @@ -980,6 +979,7 @@ def optimize_deformations_rigid( regularization_mode=regularization_mode_half2, edge_weights=edge_weights_h2, edge_weights_dis=edge_weights_dis_h2, + add_corr = add_corr, ) angles_op.step() @@ -994,7 +994,6 @@ def optimize_deformations_rigid( data_loader_val, angles, shifts, - add_corr, data_preprocessor, epoch, current_warmup_epochs, @@ -1002,6 +1001,7 @@ def optimize_deformations_rigid( latent_space, latent_weight=beta, consensus_update_pooled_particles=consensus_update_pooled_particles, + add_corr = add_corr, ) latent_space, idix_half2_useless, sig2, Err2, err_im2 = val_epoch( @@ -1011,7 +1011,6 @@ def optimize_deformations_rigid( data_loader_val, angles, shifts, - add_corr, data_preprocessor, epoch, current_warmup_epochs, @@ -1019,6 +1018,7 @@ def optimize_deformations_rigid( latent_space, latent_weight=beta, consensus_update_pooled_particles=consensus_update_pooled_particles, + add_corr = add_corr, ) current_angles = angles.detach().cpu().numpy() @@ -1374,7 +1374,7 @@ def optimize_deformations_rigid( pass epoch_t = time.time() - start_time - if epoch % 5 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1): + if epoch % 1 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1): with torch.no_grad(): V_h1 = decoder_half1.generate_consensus_volume().cpu() @@ -1410,7 +1410,7 @@ def optimize_deformations_rigid( checkpoint_file = checkpoints_directory / \ f'{epoch:03}.pth' torch.save(checkpoint, checkpoint_file) - if epoch % 10 == 0 and epoch > 0: + if epoch % 1 == 0 and epoch > 0: write_reconstruction_script( output_directory, refinement_star_file, checkpoint_file, gpu_id, decoder_half1.n_bodies) subprocess.call(output_directory /