Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Schwab PI Sjors Scheres added 22022021 committed Oct 21, 2023
1 parent 24db3df commit 5935e87
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions dynamight/deformations/optimize_deformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def optimize_deformations(
if consensus_update_rate_h1 != 0:
new_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
decoder_half1, shifts, angles, idix_half1, consensus_update_pooled_particles, batch_size)
else:
elif epoch % 20 != 0:
new_pos_h1 = decoder_half1.model_positions
new_pos_h2 = update_model_positions(particle_dataset, data_preprocessor, encoder_half2,
decoder_half2, shifts, angles, idix_half2, consensus_update_pooled_particles, batch_size)
Expand Down Expand Up @@ -958,7 +958,7 @@ def optimize_deformations(
if losses_half1['reconstruction_loss'] > (old_loss_half1+old2_loss_half1)/2 and consensus_update_rate_h1 != 0:
nosub_ind_h1 += 1

if nosub_ind_h1 == 1:
if nosub_ind_h1 > 1:
consensus_update_rate_h1 *= consensus_update_decay
if consensus_update_rate_h1 < 0.1:
consensus_update_rate_h1 = 0
Expand All @@ -969,7 +969,7 @@ def optimize_deformations(
# for g in dec_half2_optimizer.param_groups:
# g['lr'] *= 0.9
# print('new learning rate for half 2 is', g['lr'])
if nosub_ind_h2 == 1:
if nosub_ind_h2 > 1:
consensus_update_rate_h2 *= consensus_update_decay
if consensus_update_rate_h2 < 0.1:
consensus_update_rate_h2 = 0
Expand Down
Binary file not shown.
Binary file modified dynamight/evaluation/__pycache__/visualizer.cpython-310.pyc
Binary file not shown.
Binary file modified dynamight/models/__pycache__/decoder.cpython-310.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions dynamight/models/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ def baseline_parameters(self) -> torch.nn.ParameterList:
"""Parameters which make up a coordinate model."""
params = [
self.image_smoother.A,
# self.amp
# self.image_smoother.B
self.ampvar,
self.image_smoother.B
]
return params

Expand Down

0 comments on commit 5935e87

Please sign in to comment.