Skip to content

Commit

Permalink
new version
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Schwab PI Sjors Scheres added 22022021 committed Feb 29, 2024
2 parents 3a30408 + 42e2c58 commit 577e9e5
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions dynamight/deformations/optimize_deformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def optimize_deformations(
mask_file: Optional[Path] = None,
checkpoint_file: Optional[Path] = None,
n_gaussians: int = 30000,
n_gaussian_widths: int = 1,
n_gaussian_widths: int = 2,
n_latent_dimensions: int = 5,
n_positional_encoding_dimensions: int = 10,
n_linear_layers: int = 8,
Expand All @@ -86,8 +86,6 @@ def optimize_deformations(
pipeline_control=None,
use_data_normalization: bool = True,
kld_factor: float = 0.01,
lr_angles: float = 0.0,
lr_shifts: float = 0.0
):

try:
Expand Down Expand Up @@ -157,14 +155,18 @@ def optimize_deformations(

angles = torch.nn.Parameter(torch.tensor(
angles, requires_grad=True).to(device))
angles_op = torch.optim.Adam([angles], lr=lr_angles)
angles_op = torch.optim.Adam([angles], lr=1e-3)
shifts = torch.nn.Parameter(torch.tensor(
shifts, requires_grad=True).to(device))
<<<<<<< HEAD
<<<<<<< HEAD
shifts_op = torch.optim.Adam([shifts], lr=1e-3) # 1e-3)
=======
shifts_op = torch.optim.Adam([shifts], lr=lr_shifts) # 1e-3)
>>>>>>> 616360b790febf56edf08aef5d4c414058194376
=======
shifts_op = torch.optim.Adam([shifts], lr=1e-3) # 1e-3)
>>>>>>> 42e2c58d5c0cd22b012f0b5e2dc8d0fb1376beda

# initialise training dataloaders
if checkpoint_file is not None: # get subsets from checkpoint if present
Expand Down Expand Up @@ -512,7 +514,11 @@ def optimize_deformations(
decoder_half2.compute_neighbour_graph()
decoder_half1.compute_radius_graph()
decoder_half2.compute_radius_graph()
<<<<<<< HEAD
print('computing noise shit')
=======

>>>>>>> 42e2c58d5c0cd22b012f0b5e2dc8d0fb1376beda
if mask_file != None and epoch > n_warmup_epochs:
noise_h1, noise_h2, signal_h1, signal_h2, snr1, snr2, w1, w2, snr_dis1, snr_dis2, snr_e1, snr_e2 = get_edge_weights_mask(
encoder_half1,
Expand All @@ -538,7 +544,11 @@ def optimize_deformations(
shifts,
data_preprocessor,
)
<<<<<<< HEAD
print('noise shit finished')
=======

>>>>>>> 42e2c58d5c0cd22b012f0b5e2dc8d0fb1376beda
w1 = 1/torch.maximum(snr_e1, torch.tensor(0.05))
w2 = 1/torch.maximum(snr_e2, torch.tensor(0.05))

Expand Down

0 comments on commit 577e9e5

Please sign in to comment.