diff --git a/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc b/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc index 5d60935..7817e22 100644 Binary files a/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc and b/dynamight/deformations/__pycache__/optimize_deformations.cpython-310.pyc differ diff --git a/dynamight/deformations/optimize_deformations.py b/dynamight/deformations/optimize_deformations.py index b4d7fc5..ab93f49 100644 --- a/dynamight/deformations/optimize_deformations.py +++ b/dynamight/deformations/optimize_deformations.py @@ -176,7 +176,7 @@ def optimize_deformations( particle_dataset, val_indices) lambda_regularization_half1 = cp['regularization_parameter_h1'] lambda_regularization_half1 = cp['regularization_parameter_h2'] - n_warmup_epochs = 0 + n_warmup_epochs = n_warmup_epochs half1_indices = inds_half1 half2_indices = inds_half2 print('continuing training from a given checkpoint file') @@ -283,6 +283,25 @@ def optimize_deformations( encoder_half1, encoder_half2, decoder_half1, decoder_half2 = load_models( checkpoint_file, device, box_size, n_classes ) + decoder_half1.model_positions = torch.nn.Parameter( + initial_points.to(device), requires_grad=True) + decoder_half2.model_positions = torch.nn.Parameter( + initial_points.to(device), requires_grad=True) + decoder_half1.amp = torch.nn.Parameter( + 50 * torch.ones(n_classes, n_points).to(device), requires_grad=False + ) + decoder_half1.ampvar = torch.nn.Parameter( + torch.randn(n_classes, n_points).to(device), requires_grad=True + ) + decoder_half2.amp = torch.nn.Parameter( + 50 * torch.ones(n_classes, n_points).to(device), requires_grad=False + ) + decoder_half2.ampvar = torch.nn.Parameter( + torch.randn(n_classes, n_points).to(device), requires_grad=True + ) + decoder_half1.n_points = n_points + decoder_half2.n_points = n_points + else: encoder_half1 = HetEncoder(box_size, latent_dim, 1).to(device) encoder_half2 = HetEncoder(box_size, latent_dim, 1).to(device) @@ -305,33 +324,33 @@ def optimize_deformations( decoder_half1 = DisplacementDecoder(**decoder_kwargs).to(device) decoder_half2 = DisplacementDecoder(**decoder_kwargs).to(device) - if initialization_mode == ConsensusInitializationMode.MAP: - with mrcfile.open(initial_model) as mrc: - Ivol = torch.tensor(mrc.data) - fits = False - while fits == False: - try: - for decoder in (decoder_half1, decoder_half2): - decoder.initialize_physical_parameters( - reference_volume=Ivol) - summ.add_figure("Data/cons_points_z_half1", - tensor_scatter(decoder_half1.model_positions[:, 0], - decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1) - summ.add_figure("Data/cons_points_z_half2", - tensor_scatter(decoder_half2.model_positions[:, 0], - decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1) - fits = True - print('consensus gaussian models initialized') - torch.cuda.empty_cache() - except Exception as error: - torch.cuda.empty_cache() - print( - 'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error) - Ivol = torch.nn.functional.avg_pool3d( - Ivol[None, None], (2, 2, 2)) - Ivol = Ivol[0, 0] - decoder_half1.vol_box = decoder_half1.vol_box//2 - decoder_half2.vol_box = decoder_half2.vol_box//2 + if initialization_mode == ConsensusInitializationMode.MAP: + with mrcfile.open(initial_model) as mrc: + Ivol = torch.tensor(mrc.data) + fits = False + while fits == False: + try: + for decoder in (decoder_half1, decoder_half2): + decoder.initialize_physical_parameters( + reference_volume=Ivol) + summ.add_figure("Data/cons_points_z_half1", + tensor_scatter(decoder_half1.model_positions[:, 0], + decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1) + summ.add_figure("Data/cons_points_z_half2", + tensor_scatter(decoder_half2.model_positions[:, 0], + decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1) + fits = True + print('consensus gaussian models initialized') + torch.cuda.empty_cache() + except Exception as error: + torch.cuda.empty_cache() + print( + 'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error) + Ivol = torch.nn.functional.avg_pool3d( + Ivol[None, None], (2, 2, 2)) + Ivol = Ivol[0, 0] + decoder_half1.vol_box = decoder_half1.vol_box//2 + decoder_half2.vol_box = decoder_half2.vol_box//2 if mask_file: with mrcfile.open(mask_file) as mrc: