From a6ff817fd9fc868ec7081045b861f134844c3c3e Mon Sep 17 00:00:00 2001 From: schwabjohannes Date: Thu, 19 Dec 2024 11:07:38 +0100 Subject: [PATCH] Add files via upload add tools for rigid version --- .../compute_rigid_transforms.py | 67 +++++++++++++------ 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/dynamight/inverse_deformations/compute_rigid_transforms.py b/dynamight/inverse_deformations/compute_rigid_transforms.py index c295ec7..41b45d6 100644 --- a/dynamight/inverse_deformations/compute_rigid_transforms.py +++ b/dynamight/inverse_deformations/compute_rigid_transforms.py @@ -42,6 +42,7 @@ def compute_rigid_transforms( data_loader_threads: int = Option(4), pipeline_control=None, mask=None, + rigid: Optional[bool] = Option(False), ): forward_deformations_directory = output_directory / \ 'forward_deformations' / 'checkpoints' @@ -72,6 +73,8 @@ def compute_rigid_transforms( decoder_half1.mask = None decoder_half2.mask = None + body = None + n_points = decoder_half1.n_points points = decoder_half1.model_positions.detach().cpu() @@ -94,25 +97,42 @@ def compute_rigid_transforms( encoder_half2.to(device) masked_points = [] - for file in sorted(os.listdir(masks_directory)): - + if rigid == False: try: - if file.startswith('mask'): - print('reading in:', file) - with mrcfile.open(masks_directory / file) as mrc: - mask = torch.tensor(mrc.data).to(device) - mask = mask.movedim(0, 2).movedim(0, 1) - m_points, inds = maskpoints( - decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size) - print(m_points.shape) - masked_points.append(m_points) + for file in sorted(os.listdir(masks_directory)): + + try: + if file.startswith('mask'): + print('reading in:', file) + with mrcfile.open(masks_directory / file) as mrc: + mask = torch.tensor(mrc.data).to(device) + mask = mask.movedim(0, 2).movedim(0, 1) + m_points, inds = maskpoints( + decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size) + print(m_points.shape) + masked_points.append(m_points) + except: + print('No masks in this directory') except: - print('No masks in this directory') - print(len(masked_points)) + for file in sorted(os.listdir(output_directory)): + print(file) + try: + if 'mask' in file: + print('reading in:', file) + with mrcfile.open(output_directory / file) as mrc: + mask = torch.tensor(mrc.data).to(device) + mask = mask.movedim(0, 2).movedim(0, 1) + m_points, inds = maskpoints( + decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size) + print(m_points.shape) + masked_points.append(m_points) + except: + print('No masks in this directory') + latent_dim = encoder_half1.latent_dim star_file = starfile.read(refinement_star_file) - star_directory = output_directory / 'subsets' + star_directory = output_directory / 'body_starfiles' star_directory.mkdir(exist_ok=True, parents=True) relion_dataset = RelionDataset( @@ -161,13 +181,20 @@ def compute_rigid_transforms( circular_mask_radius=diameter_ang / (2 * ang_pix), circular_mask_thickness=mask_soft_edge_width / ang_pix ) - - for i in range(len(masked_points)): + if rigid == True: + n_bodies = decoder_half1.n_bodies + masked_points = list(np.arange(n_bodies)) + else: + n_bodies = len(masked_points) + + for i in range(n_bodies): + if rigid == True: + body = i + print('Computing new star files for body ', body+1) current_star_file = star_file.copy() new_star = get_rotation_translation( - encoder_half1, decoder_half1, data_loader_half1, poses, data_preprocessor, [masked_points[i]], [current_star_file], half=1) - print(len(new_star)) + encoder_half1, decoder_half1, data_loader_half1, poses, data_preprocessor, [masked_points[i]], [current_star_file], body, half=1) new_star = get_rotation_translation( - encoder_half2, decoder_half2, data_loader_half2, poses, data_preprocessor, [masked_points[i]], new_star, half=2) + encoder_half2, decoder_half2, data_loader_half2, poses, data_preprocessor, [masked_points[i]], new_star, body, half=2) starfile.write(new_star[0], star_directory / - ('body_' + str(i+1) + '.star')) + ('body_' + str(i+1) + '.star'), overwrite=True)