From 59e033f23e81dab53de16d81390d69c0ffd84dca Mon Sep 17 00:00:00 2001
From: schwabjohannes
Date: Tue, 10 Dec 2024 14:15:34 +0100
Subject: [PATCH] Add files via upload

update of rigid deformation estimation

---
 .../optimize_deformations_rigid.py | 277 ++++++++++++++----
 1 file changed, 224 insertions(+), 53 deletions(-)

diff --git a/dynamight/deformations/optimize_deformations_rigid.py b/dynamight/deformations/optimize_deformations_rigid.py
index 9e806fa..a736869 100644
--- a/dynamight/deformations/optimize_deformations_rigid.py
+++ b/dynamight/deformations/optimize_deformations_rigid.py
@@ -35,10 +35,11 @@ from ..models.utils import initialize_points_from_volume
 from ..utils.utils_new import compute_threshold, load_models, add_weight_decay_to_named_parameters, graph2bild, generate_form_factor, \
     visualize_latent, tensor_plot, tensor_imshow, tensor_scatter, \
-    apply_ctf, write_xyz, calculate_grid_oversampling_factor, generate_data_normalization_mask, FSC, radial_index_mask, radial_index_mask3
+    apply_ctf, write_xyz, calculate_grid_oversampling_factor, generate_data_normalization_mask, FSC, radial_index_mask, radial_index_mask3, write_reconstruction_script, combine_maps
 from ._train_single_epoch_half import train_epoch, val_epoch, get_edge_weights, get_edge_weights_mask
 from ._update_model import update_model_positions
 from ..utils.coarse_grain import optimize_coarsegraining
+import subprocess
 
 
 # TODO: add coarse graining to GitHub
 
@@ -72,12 +73,12 @@ def optimize_deformations_rigid(
     weight_decay: float = 0,
     consensus_update_rate: float = 1,
     consensus_update_decay: float = 0.95,
-    consensus_update_pooled_particles: int = 500,
+    consensus_update_pooled_particles: int = 1,
     regularization_factor: float = 0.9,
     apply_bfactor: float = 0,
     particle_diameter: Optional[float] = None,
     deformation_masks: Optional[Path] = None,
-    soft_edge_width: float = 20,
+    soft_edge_width: float = 5,
     batch_size: int = 128,
     gpu_id: Optional[int] = 0,
     n_epochs: int = Option(150),
@@ -164,6 +165,7 @@ def optimize_deformations_rigid(
     diameter_ang = relion_dataset.particle_diameter
     box_size = relion_dataset.box_size
     ang_pix = relion_dataset.pixel_spacing_angstroms
+    update_set = 1
 
     print('Number of particles:', len(particle_dataset))
 
@@ -181,6 +183,8 @@ def optimize_deformations_rigid(
         shifts, requires_grad=True).to(device))
     shifts_op = torch.optim.Adam([shifts], lr=0)  # 1e-3)
 
+    add_corr = torch.zeros(len(particle_dataset)).float().to(device)
+
     # initialise training dataloaders
     if checkpoint_file is not None:  # get subsets from checkpoint if present
         cp = torch.load(checkpoint_file, map_location=device)
@@ -199,6 +203,7 @@ def optimize_deformations_rigid(
         lambda_regularization_half1 = cp['regularization_parameter_h1']
         lambda_regularization_half1 = cp['regularization_parameter_h2']
         n_warmup_epochs = n_warmup_epochs
+        current_warmup_epochs = n_warmup_epochs
         half1_indices = inds_half1
         half2_indices = inds_half2
         print('continuing training from a given checkpoint file')
@@ -211,6 +216,7 @@ def optimize_deformations_rigid(
             train_dataset, [len(train_dataset) // 2,
                             len(train_dataset) - len(
                                 train_dataset) // 2])
+        current_warmup_epochs = n_warmup_epochs
 
     data_loader_half1 = DataLoader(
         dataset=dataset_half1,
@@ -502,10 +508,22 @@ def optimize_deformations_rigid(
 
     # the actual training loop
     for epoch in range(2*n_epochs):
-        if epoch % 20 < 0 and epoch > n_warmup_epochs:
+        if epoch % 20 < 3 and epoch > n_warmup_epochs:
             pos_epoch = True
+            # decoder_half1.amp.requires_grad = False
+            # decoder_half2.amp.requires_grad = False
+            # decoder_half1.model_positions.requires_grad = True
+            # decoder_half2.model_positions.requires_grad = True
+            # if regularization_mode_half1 != RegularizationMode.MODEL:
+            #     physical_parameter_optimizer_half1 = torch.optim.Adam(
+            #         decoder_half1.physical_parameters, lr=0.1*posLR)
+            # if regularization_mode_half2 != RegularizationMode.MODEL:
+            #     physical_parameter_optimizer_half2 = torch.optim.Adam(
+            #         decoder_half2.physical_parameters, lr=0.1*posLR)
         else:
             pos_epoch = False
+            # decoder_half1.model_positions.requires_grad = False
+            # decoder_half2.model_positions.requires_grad = False
         print(pos_epoch)
         abort_if_relion_abort(output_directory)
         # first, recompute the graphs
@@ -622,9 +640,9 @@ def optimize_deformations_rigid(
         dec_half1_optimizer = torch.optim.Adam(dec_half1_params, lr=LR)
         dec_half2_optimizer = torch.optim.Adam(dec_half2_params, lr=LR)
         physical_parameter_optimizer_half1 = torch.optim.Adam(
-            decoder_half1.physical_parameters, lr=0.1*posLR)
+            decoder_half1.physical_parameters, lr=0.01*posLR)
         physical_parameter_optimizer_half2 = torch.optim.Adam(
-            decoder_half2.physical_parameters, lr=0.1*posLR)
+            decoder_half2.physical_parameters, lr=0.01*posLR)
 
         if initialization_mode in (ConsensusInitializationMode.EMPTY,
                                    ConsensusInitializationMode.MAP):
@@ -717,7 +735,7 @@ def optimize_deformations_rigid(
             decoder_half1.amp.requires_grad = True
             decoder_half1.image_smoother.B.requires_grad = True
 
-            latent_space, losses_half1, displacement_statistics_half1, idix_half1, visualization_data_half1 = train_epoch(
+            latent_space, losses_half1, displacement_statistics_half1, idix_half1, visualization_data_half1, a_opt1 = train_epoch(
                 encoder_half1,
                 enc_half1_optimizer,
                 decoder_half1,
@@ -727,6 +745,7 @@ def optimize_deformations_rigid(
                 data_loader_half1,
                 angles,
                 shifts,
+                add_corr,
                 data_preprocessor,
                 epoch,
                 0,
@@ -778,7 +797,49 @@ def optimize_deformations_rigid(
                 pin_memory=True
             )
 
-            latent_space, losses_half1, displacement_statistics_half1, idix_half1, visualization_data_half1 = train_epoch(
+            if (epoch % 10 == 0 and update_set == 1 and epoch % 50 != 0) or final > 1:
+
+                with torch.no_grad():
+                    print('compute mask for alignment from half2')
+                    ref_indices = torch.round(
+                        (decoder_half2.model_positions+0.5)*(decoder_half2.box_size-1)).long()
+                    ref_indices = torch.clamp(
+                        ref_indices, 0, decoder_half2.box_size-1)
+                    ref_mask = torch.zeros(
+                        decoder_half2.box_size, decoder_half2.box_size, decoder_half2.box_size).to(decoder_half2.device)
+                    ref_mask[ref_indices[:, 0],
+                             ref_indices[:, 1], ref_indices[:, 2]] = 1
+                    ref_mask = ref_mask.movedim(0, 2).movedim(0, 1)
+
+                    filter_width = int(np.round(combine_resolution/ang_pix))
+                    if filter_width % 2 == 0:
+                        filter_width += 1
+
+                    mean_filter = torch.ones(
+                        1, 1, filter_width, filter_width, filter_width)
+                    x = torch.arange(-filter_width//2+1, filter_width//2+1)
+                    X, Y, Z = torch.meshgrid(x, x, x, indexing='ij')
+                    R = torch.sqrt(X**2+Y**2+Z**2) <= filter_width//2
+                    mean_filter = mean_filter*R[None, None, :, :, :]
+                    mean_filter /= torch.sum(mean_filter)
+                    # print(mean_filter.shape)
+                    smooth_ref_mask = torch.nn.functional.conv3d(
+                        ref_mask[None, None, :, :, :].float(), mean_filter.to(device), padding=filter_width//2)
+                    if box_size > 400:
+                        smooth_ref_mask = torch.nn.functional.upsample(
+                            smooth_ref_mask, scale_factor=2)
+
+                    ref_mask = smooth_ref_mask > 5/(filter_width**3)
+
+                    ref_mask = ref_mask[0, 0]
+
+                    with mrcfile.new(volumes_directory / ('mask_half2_' + f'{epoch:03}.mrc'), overwrite=True) as mrc:
+                        mrc.set_data(ref_mask.cpu().float().numpy())
+                        mrc.voxel_size = ang_pix
+            else:
+                ref_mask = None
+
+            latent_space, losses_half1, displacement_statistics_half1, idix_half1, visualization_data_half1, a_opt2 = train_epoch(
                 encoder_half1,
                 enc_half1_optimizer,
                 decoder_half1,
@@ -788,9 +849,10 @@ def optimize_deformations_rigid(
                 data_loader_half1,
                 angles,
                 shifts,
+                add_corr,
                 data_preprocessor,
                 epoch,
-                n_warmup_epochs,
+                current_warmup_epochs,
                 data_normalization_mask,
                 latent_space,
                 latent_weight=beta,
@@ -799,31 +861,44 @@ def optimize_deformations_rigid(
                 regularization_mode=regularization_mode_half1,
                 edge_weights=edge_weights_h1,
                 edge_weights_dis=edge_weights_dis_h1,
-                pos_epoch=pos_epoch
+                pos_epoch=pos_epoch,
+                ref_mask=ref_mask
             )
 
             ref_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
                                                 decoder_half1, shifts, angles, idix_half1, consensus_update_pooled_particles, batch_size)
 
         if epoch > (n_warmup_epochs - 1) and consensus_update_rate != 0:
-            if losses_half1['reconstruction_loss'] < (old_loss_half1+old2_loss_half1)/2 and consensus_update_rate_h1 != 0 and (initialization_mode in (ConsensusInitializationMode.EMPTY,
-                                                                                                                               ConsensusInitializationMode.MAP)) and epoch % update_frequency == 0:
-                # decoder_half1.model_positions = torch.nn.Parameter(
-                #     ref_pos_h1, requires_grad=True)
-                #print('updated consensus model of half 1')
+            # and consensus_update_rate_h1 != 0 and (initialization_mode in (ConsensusInitializationMode.EMPTY,
+            if (epoch % 10 == 0 and update_set == 1 and epoch % 50 != 0) or final > 1:
+                # (losses_half1['reconstruction_loss'] < (old_loss_half1+old2_loss_half1)/2)
+                # ConsensusInitializationMode.MAP)) and epoch % update_frequency == 0:
+                decoder_half1.model_positions.data = torch.nn.Parameter(
+                    ref_pos_h1, requires_grad=True)
+                print('updated consensus model of half 1')
                 old2_loss_half1 = old_loss_half1
                 old_loss_half1 = losses_half1['reconstruction_loss']
                 nosub_ind_h1 = 0
+                update_set = 2
 
-            if epoch % 20 == 0 or (consensus_update_rate_h1 == 0 and consensus_update_rate_h2 == 0) and (initialization_mode in (ConsensusInitializationMode.EMPTY,
-                                                                                                         ConsensusInitializationMode.MAP)):
+            if (epoch % 10 == 0 and update_set == 2 and epoch % 50 != 0) or final > 1:
                 with torch.no_grad():
-                    ref_vol = decoder_half1.generate_consensus_volume().detach()
-                    ref_threshold = compute_threshold(
-                        ref_vol[0], percentage=99)
-                    ref_mask = ref_vol[0] > ref_threshold
+                    print('compute mask for alignment from half 1')
+                    ref_indices = torch.round(
+                        (decoder_half1.model_positions+0.5)*(decoder_half1.box_size-1)).long()
+                    ref_indices = torch.clamp(
+                        ref_indices, 0, decoder_half1.box_size-1)
+                    ref_mask = torch.zeros(
+                        decoder_half1.box_size, decoder_half1.box_size, decoder_half1.box_size).to(decoder_half1.device)
+                    ref_mask[ref_indices[:, 0],
+                             ref_indices[:, 1], ref_indices[:, 2]] = 1
+                    ref_mask = ref_mask.movedim(0, 2).movedim(0, 1)
+                    # ref_vol = decoder_half1.generate_consensus_volume().detach()
+                    # ref_threshold = compute_threshold(
+                    #     ref_vol[0], percentage=99)
+                    # ref_mask = ref_vol[0] > ref_threshold
 
                     filter_width = int(np.round(combine_resolution/ang_pix))
                     if filter_width % 2 == 0:
@@ -839,11 +914,11 @@ def optimize_deformations_rigid(
                     # print(mean_filter.shape)
                     smooth_ref_mask = torch.nn.functional.conv3d(
                         ref_mask[None, None, :, :, :].float(), mean_filter.to(device), padding=filter_width//2)
-                    if box_size > 360:
+                    if box_size > 400:
                         smooth_ref_mask = torch.nn.functional.upsample(
                             smooth_ref_mask, scale_factor=2)
 
-                    ref_mask = smooth_ref_mask > 2/filter_width**3
+                    ref_mask = smooth_ref_mask > 5/(filter_width**3)
 
                     ref_mask = ref_mask[0, 0]
 
@@ -856,7 +931,7 @@ def optimize_deformations_rigid(
         abort_if_relion_abort(output_directory)
         if initialization_mode in (ConsensusInitializationMode.EMPTY,
                                    ConsensusInitializationMode.MAP):
-            latent_space, losses_half2, displacement_statistics_half2, idix_half2, visualization_data_half2 = train_epoch(
+            latent_space, losses_half2, displacement_statistics_half2, idix_half2, visualization_data_half2, a_opt2 = train_epoch(
                 encoder_half2,
                 enc_half2_optimizer,
                 decoder_half2,
@@ -866,9 +941,10 @@ def optimize_deformations_rigid(
                 data_loader_half2,
                 angles,
                 shifts,
+                add_corr,
                 data_preprocessor,
                 epoch,
-                n_warmup_epochs,
+                current_warmup_epochs,
                 data_normalization_mask,
                 latent_space,
                 latent_weight=beta,
@@ -882,7 +958,7 @@ def optimize_deformations_rigid(
             )
 
         else:
-            latent_space, losses_half2, displacement_statistics_half2, idix_half2, visualization_data_half2 = train_epoch(
+            latent_space, losses_half2, displacement_statistics_half2, idix_half2, visualization_data_half2, a_opt2 = train_epoch(
                 encoder_half2,
                 enc_half2_optimizer,
                 decoder_half2,
@@ -892,9 +968,10 @@ def optimize_deformations_rigid(
                 data_loader_half2,
                 angles,
                 shifts,
+                add_corr,
                 data_preprocessor,
                 epoch,
-                n_warmup_epochs,
+                current_warmup_epochs,
                 data_normalization_mask,
                 latent_space,
                 latent_weight=beta,
@@ -910,32 +987,34 @@ def optimize_deformations_rigid(
 
         abort_if_relion_abort(output_directory)
 
-        latent_space, idix_half1_useless, sig1, Err1 = val_epoch(
+        latent_space, idix_half1_useless, sig1, Err1, err_im1 = val_epoch(
            encoder_half1,
            enc_half1_optimizer,
            decoder_half1,
            data_loader_val,
            angles,
            shifts,
+           add_corr,
            data_preprocessor,
            epoch,
-           n_warmup_epochs,
+           current_warmup_epochs,
            data_normalization_mask,
            latent_space,
            latent_weight=beta,
           consensus_update_pooled_particles=consensus_update_pooled_particles,
        )
 
-        latent_space, idix_half2_useless, sig2, Err2 = val_epoch(
+        latent_space, idix_half2_useless, sig2, Err2, err_im2 = val_epoch(
            encoder_half2,
            enc_half2_optimizer,
            decoder_half2,
            data_loader_val,
            angles,
            shifts,
+           add_corr,
            data_preprocessor,
            epoch,
-           n_warmup_epochs,
+           current_warmup_epochs,
            data_normalization_mask,
            latent_space,
            latent_weight=beta,
@@ -956,28 +1035,32 @@ def optimize_deformations_rigid(
         # update consensus model
         if epoch > (n_warmup_epochs - 1) and (initialization_mode in (ConsensusInitializationMode.EMPTY,
                                                                       ConsensusInitializationMode.MAP)):
-
-            if consensus_update_rate_h1 != 0 and epoch % update_frequency == 0:
-                new_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
-                                                    decoder_half1, shifts, angles, idix_half1, consensus_update_pooled_particles, batch_size)
-            elif epoch % 20 != 0:
-                new_pos_h1 = decoder_half1.model_positions
+            # new_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
+            #                                     decoder_half1, shifts, angles, idix_half1, consensus_update_pooled_particles, batch_size)
+            # if consensus_update_rate_h1 != 0 and epoch % update_frequency == 0:
+            #     new_pos_h1 = update_model_positions(particle_dataset, data_preprocessor, encoder_half1,
+            #                                         decoder_half1, shifts, angles, idix_half1, consensus_update_pooled_particles, batch_size)
+            # elif epoch % 20 != 0:
+            #     new_pos_h1 = decoder_half1.model_positions
 
             new_pos_h2 = update_model_positions(particle_dataset, data_preprocessor, encoder_half2,
                                                 decoder_half2, shifts, angles, idix_half2, consensus_update_pooled_particles, batch_size)
 
-            if (losses_half2['reconstruction_loss'] < (old_loss_half2+old2_loss_half2)/2 and consensus_update_rate_h2 != 0 and epoch % update_frequency == 0) or epoch % 20 == 0:
+            # and consensus_update_rate_h2 != 0 and epoch % update_frequency == 0) or epoch % 20 == 0:
+            if (epoch % 10 == 0 and update_set == 2 and epoch % 50 != 0) or final > 1:
+                # (losses_half2['reconstruction_loss'] < (old_loss_half2+old2_loss_half2)/2)
                 # decoder_half2.model_positions = torch.nn.Parameter((
                 #     1 - consensus_update_rate_h2) * decoder_half2.model_positions + consensus_update_rate_h2 * new_pos_h2, requires_grad=True)
                 # decoder_half2.model_positions = torch.nn.Parameter(
                 #     new_pos_h2, requires_grad=True)
-                # decoder_half2.model_positions = torch.nn.Parameter(
-                #     new_pos_h2, requires_grad=True)
-                #print('updated consensus model of half 2')
+                decoder_half2.model_positions.data = torch.nn.Parameter(
+                    new_pos_h2, requires_grad=True)
+                print('updated consensus model of half 2')
                 old2_loss_half2 = old_loss_half2
                 old_loss_half2 = losses_half2['reconstruction_loss']
                 nosub_ind_h2 = 0
+                update_set = 1
 
             if consensus_update_rate_h1 == 0:
@@ -1054,21 +1137,49 @@ def optimize_deformations_rigid(
                 None] ** 2
 
         # and epoch < (n_warmup_epochs + 60):
-        if epoch > (n_warmup_epochs+50):
-            if use_data_normalization == True:
-                Sig = 0.9*Sig + 0.1*(sig1+sig2)/2
+        if epoch <= n_warmup_epochs + 50:
+            R2, r_mask = radial_index_mask(decoder_half1.box_size)
+            start_resolution = 0.5*decoder_half1.box_size
+            data_normalization_frequency_mask = R2 < (
+                start_resolution + epoch/(n_warmup_epochs+50)*start_resolution)
+            data_normalization_frequency_mask = data_normalization_frequency_mask.to(
+                decoder_half1.device)
+        else:
+            r2, data_normalization_frequency_mask = radial_index_mask(
+                decoder_half1.box_size)
+
+        if use_data_normalization == True:
+            Sig = 0.9*Sig + 0.1*(Err1+Err2)/2
 
-                data_normalization_mask = 1 / \
-                    Sig
-                data_normalization_mask /= torch.max(
-                    data_normalization_mask)
-                R2, r_mask = radial_index_mask(decoder_half1.box_size)
+            data_normalization_mask = 1 / \
+                Sig
+            data_normalization_mask /= torch.max(
+                data_normalization_mask)
+            R2, r_mask = radial_index_mask(decoder_half1.box_size)
+            # r_mask2 = R2 > 5
+            # r_mask *= r_mask2
 
-                data_normalization_mask *= torch.fft.fftshift(
-                    r_mask.to(decoder_half1.device), dim=[-1, -2])
+            data_normalization_mask *= torch.fft.fftshift(
+                data_normalization_frequency_mask.to(decoder_half1.device), dim=[-1, -2])
 
-                data_normalization_mask = (data_normalization_mask /
-                                           torch.sum(data_normalization_mask**2))*(box_size**2)
+            data_normalization_mask = (
+                data_normalization_mask / torch.sum(data_normalization_mask**2))*(box_size**2)
+
+        # if epoch > (n_warmup_epochs+50):
+        #     if use_data_normalization == True:
+        #         Sig = 0.9*Sig + 0.1*(sig1+sig2)/2
+
+        #         data_normalization_mask = 1 / \
+        #             Sig
+        #         data_normalization_mask /= torch.max(
+        #             data_normalization_mask)
+        #         R2, r_mask = radial_index_mask(decoder_half1.box_size)
+
+        #         data_normalization_mask *= torch.fft.fftshift(
+        #             r_mask.to(decoder_half1.device), dim=[-1, -2])
+
+        #         data_normalization_mask = (data_normalization_mask /
+        #                                    torch.sum(data_normalization_mask**2))*(box_size**2)
 
         displacement_variance_half1 = displacement_statistics_half1[
             'displacement_variances']
@@ -1257,6 +1368,8 @@ def optimize_deformations_rigid(
                     "Data/frc_h1", tensor_plot(frc_half1), epoch)
                 summ.add_figure(
                     "Data/frc_h2", tensor_plot(frc_half2), epoch)
+                summ.add_figure(
+                    "Data/addcorr", tensor_plot(add_corr), epoch)
             except:
                 pass
         epoch_t = time.time() - start_time
@@ -1297,6 +1410,44 @@ def optimize_deformations_rigid(
             checkpoint_file = checkpoints_directory / \
                 f'{epoch:03}.pth'
             torch.save(checkpoint, checkpoint_file)
+            if epoch % 10 == 0 and epoch > 0:
+                write_reconstruction_script(
+                    output_directory, refinement_star_file, checkpoint_file, gpu_id, decoder_half1.n_bodies)
+                subprocess.call(output_directory /
+                                'relion_reconstructions/reconstruct.sh')
+                Ivol_1, Ivol_2 = combine_maps(output_directory/'relion_reconstructions', new_mask_list,
+                                              decoder_half1.n_bodies, decoder_half1.box_size, decoder_half1.ang_pix)
+
+
+
+                initial_threshold_1 = compute_threshold(
+                    Ivol_1, percentage=99)
+                initial_threshold_2 = compute_threshold(
+                    Ivol_2, percentage=99)
+
+                print(
+                    "Initialize new consensus model from the reconstructions")
+                initial_points_1 = initialize_points_from_volume(
+                    Ivol_1.movedim(0, 2).movedim(0, 1),
+                    threshold=initial_threshold_1,
+                    n_points=n_points,
+                )
+                initial_points_2 = initialize_points_from_volume(
+                    Ivol_2.movedim(0, 2).movedim(0, 1),
+                    threshold=initial_threshold_2,
+                    n_points=n_points,
+                )
+
+                Ivol_1, Ivol_2 = Ivol_1.to(device), Ivol_2.to(device)
+
+                decoder_half1.model_positions.data = torch.nn.Parameter(
+                    initial_points_2.to(device), requires_grad=True)
+                decoder_half2.model_positions.data = torch.nn.Parameter(
+                    initial_points_2.to(device), requires_grad=True)
+
+                decoder_half1.mask_positions()
+                decoder_half2.mask_positions()
+
             xyz_file = graphs_directory / ('points'+f'{epoch:03}.xyz')
             write_xyz(
                 decoder_half1.model_positions,
@@ -1314,6 +1465,26 @@ def optimize_deformations_rigid(
                 mrc.set_data(
                     (V_h2[0] / torch.mean(V_h2[0])).float().numpy())
                 mrc.voxel_size = ang_pix
+
+            if epoch % 10 == 0 and epoch > 0:
+                print("re-optimize physical parameters")
+
+                decoder_half1.initialize_physical_parameters(
+                    reference_volume=Ivol_1.to(device), scale=True, n_epochs=1000)
+                decoder_half2.initialize_physical_parameters(
+                    reference_volume=Ivol_2.to(device), scale=True, n_epochs=1000)
+                physical_parameter_optimizer_half1 = torch.optim.Adam(
+                    decoder_half1.physical_parameters, lr=0.01*posLR)
+                physical_parameter_optimizer_half2 = torch.optim.Adam(
+                    decoder_half2.physical_parameters, lr=0.01*posLR)
+                # current_warmup_epochs = epoch + n_warmup_epochs
+                current_warmup_epochs = n_warmup_epochs
+                with torch.no_grad():
+                    V_h1 = decoder_half1.generate_consensus_volume().cpu()
+                    with mrcfile.new(volumes_directory / ('initialization_' + f'{epoch:03}.mrc'), overwrite=True) as mrc:
+                        mrc.set_data((V_h1[0]/torch.mean(V_h1[0])).float().numpy())
+                        mrc.voxel_size = ang_pix
+
         abort_if_relion_abort(output_directory)
         if epoch > n_epochs:
             consensus_update_rate_h1 = 0
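
For reference, both half-set branches above construct the alignment mask the same way: the current Gaussian-point model is rasterized into a binary occupancy grid, smoothed with a spherical mean filter whose width is combine_resolution expressed in pixels, and thresholded. The self-contained sketch below mirrors that logic; the function name, the plain-argument interface, and the example values at the bottom are illustrative only (they are not part of the patched module), and the box_size > 400 upsampling branch is omitted:

    import torch
    import numpy as np

    def alignment_mask_from_positions(model_positions, box_size,
                                      combine_resolution, ang_pix):
        # Rasterize normalized positions in [-0.5, 0.5] onto voxel indices.
        idx = torch.round((model_positions + 0.5) * (box_size - 1)).long()
        idx = torch.clamp(idx, 0, box_size - 1)

        # Binary occupancy grid, axes permuted (x, y, z) -> (z, y, x) as above.
        mask = torch.zeros(box_size, box_size, box_size)
        mask[idx[:, 0], idx[:, 1], idx[:, 2]] = 1
        mask = mask.movedim(0, 2).movedim(0, 1)

        # Spherical mean filter; its diameter is combine_resolution in pixels,
        # forced to an odd width so the convolution stays centered.
        filter_width = int(np.round(combine_resolution / ang_pix))
        if filter_width % 2 == 0:
            filter_width += 1
        x = torch.arange(-(filter_width // 2), filter_width // 2 + 1).float()
        X, Y, Z = torch.meshgrid(x, x, x, indexing='ij')
        ball = (torch.sqrt(X**2 + Y**2 + Z**2) <= filter_width // 2).float()
        kernel = (ball / ball.sum())[None, None]

        # Smooth the grid and keep voxels with enough nearby model points.
        smooth = torch.nn.functional.conv3d(
            mask[None, None], kernel, padding=filter_width // 2)
        return (smooth > 5 / filter_width**3)[0, 0]

    # Hypothetical usage: 1000 random points in a 64-pixel box at
    # 1.5 A/pixel with a 10 A combine resolution.
    points = torch.rand(1000, 3) - 0.5
    mask = alignment_mask_from_positions(points, 64, 10.0, 1.5)
    print(mask.shape, int(mask.sum()))

Because the kernel is normalized by the ball volume (about 0.52*filter_width**3 voxels), the 5/filter_width**3 cutoff keeps a voxel when roughly three or more model points lie within half a filter width of it.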