Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
update rigid version
  • Loading branch information
schwabjohannes authored Dec 19, 2024
1 parent c713748 commit 1f57c0f
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions dynamight/deformations/optimize_deformations_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,6 @@ def optimize_deformations_rigid(
data_loader_half1,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
0,
Expand All @@ -756,7 +755,8 @@ def optimize_deformations_rigid(
consensus_update_pooled_particles=consensus_update_pooled_particles,
regularization_mode=regularization_mode_half1,
edge_weights=edge_weights_h1,
edge_weights_dis=edge_weights_dis_h1
edge_weights_dis=edge_weights_dis_h1,
add_corr = add_corr,
)
fits = True
torch.cuda.empty_cache()
Expand Down Expand Up @@ -849,7 +849,6 @@ def optimize_deformations_rigid(
data_loader_half1,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
current_warmup_epochs,
Expand All @@ -862,7 +861,8 @@ def optimize_deformations_rigid(
edge_weights=edge_weights_h1,
edge_weights_dis=edge_weights_dis_h1,
pos_epoch=pos_epoch,
ref_mask=ref_mask
ref_mask=ref_mask,
add_corr = add_corr,
)

ref_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
Expand Down Expand Up @@ -941,7 +941,6 @@ def optimize_deformations_rigid(
data_loader_half2,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
current_warmup_epochs,
Expand All @@ -954,7 +953,8 @@ def optimize_deformations_rigid(
edge_weights=edge_weights_h2,
edge_weights_dis=edge_weights_dis_h2,
ref_mask=ref_mask,
pos_epoch=pos_epoch
pos_epoch=pos_epoch,
add_corr = add_corr,
)
else:

Expand All @@ -968,7 +968,6 @@ def optimize_deformations_rigid(
data_loader_half2,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
current_warmup_epochs,
Expand All @@ -980,6 +979,7 @@ def optimize_deformations_rigid(
regularization_mode=regularization_mode_half2,
edge_weights=edge_weights_h2,
edge_weights_dis=edge_weights_dis_h2,
add_corr = add_corr,
)

angles_op.step()
Expand All @@ -994,14 +994,14 @@ def optimize_deformations_rigid(
data_loader_val,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
current_warmup_epochs,
data_normalization_mask,
latent_space,
latent_weight=beta,
consensus_update_pooled_particles=consensus_update_pooled_particles,
add_corr = add_corr,
)

latent_space, idix_half2_useless, sig2, Err2, err_im2 = val_epoch(
Expand All @@ -1011,14 +1011,14 @@ def optimize_deformations_rigid(
data_loader_val,
angles,
shifts,
add_corr,
data_preprocessor,
epoch,
current_warmup_epochs,
data_normalization_mask,
latent_space,
latent_weight=beta,
consensus_update_pooled_particles=consensus_update_pooled_particles,
add_corr = add_corr,
)

current_angles = angles.detach().cpu().numpy()
Expand Down Expand Up @@ -1374,7 +1374,7 @@ def optimize_deformations_rigid(
pass
epoch_t = time.time() - start_time

if epoch % 5 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1):
if epoch % 1 == 0 or (final > finalization_epochs) or (epoch == n_epochs-1):

with torch.no_grad():
V_h1 = decoder_half1.generate_consensus_volume().cpu()
Expand Down Expand Up @@ -1410,7 +1410,7 @@ def optimize_deformations_rigid(
checkpoint_file = checkpoints_directory / \
f'{epoch:03}.pth'
torch.save(checkpoint, checkpoint_file)
if epoch % 10 == 0 and epoch > 0:
if epoch % 1 == 0 and epoch > 0:
write_reconstruction_script(
output_directory, refinement_star_file, checkpoint_file, gpu_id, decoder_half1.n_bodies)
subprocess.call(output_directory /
Expand Down

0 comments on commit 1f57c0f

Please sign in to comment.