Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
add tools for rigid version
  • Loading branch information
schwabjohannes authored Dec 19, 2024
1 parent 7b3d85a commit a6ff817
Showing 1 changed file with 47 additions and 20 deletions.
67 changes: 47 additions & 20 deletions dynamight/inverse_deformations/compute_rigid_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def compute_rigid_transforms(
data_loader_threads: int = Option(4),
pipeline_control=None,
mask=None,
rigid: Optional[bool] = Option(False),
):
forward_deformations_directory = output_directory / \
'forward_deformations' / 'checkpoints'
Expand Down Expand Up @@ -72,6 +73,8 @@ def compute_rigid_transforms(
decoder_half1.mask = None
decoder_half2.mask = None

body = None

n_points = decoder_half1.n_points

points = decoder_half1.model_positions.detach().cpu()
Expand All @@ -94,25 +97,42 @@ def compute_rigid_transforms(
encoder_half2.to(device)

masked_points = []
for file in sorted(os.listdir(masks_directory)):

if rigid == False:
try:
if file.startswith('mask'):
print('reading in:', file)
with mrcfile.open(masks_directory / file) as mrc:
mask = torch.tensor(mrc.data).to(device)
mask = mask.movedim(0, 2).movedim(0, 1)
m_points, inds = maskpoints(
decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size)
print(m_points.shape)
masked_points.append(m_points)
for file in sorted(os.listdir(masks_directory)):

try:
if file.startswith('mask'):
print('reading in:', file)
with mrcfile.open(masks_directory / file) as mrc:
mask = torch.tensor(mrc.data).to(device)
mask = mask.movedim(0, 2).movedim(0, 1)
m_points, inds = maskpoints(
decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size)
print(m_points.shape)
masked_points.append(m_points)
except:
print('No masks in this directory')
except:
print('No masks in this directory')
print(len(masked_points))
for file in sorted(os.listdir(output_directory)):
print(file)
try:
if 'mask' in file:
print('reading in:', file)
with mrcfile.open(output_directory / file) as mrc:
mask = torch.tensor(mrc.data).to(device)
mask = mask.movedim(0, 2).movedim(0, 1)
m_points, inds = maskpoints(
decoder_half1.model_positions, decoder_half1.ampvar, mask, decoder_half1.box_size)
print(m_points.shape)
masked_points.append(m_points)
except:
print('No masks in this directory')

latent_dim = encoder_half1.latent_dim

star_file = starfile.read(refinement_star_file)
star_directory = output_directory / 'subsets'
star_directory = output_directory / 'body_starfiles'
star_directory.mkdir(exist_ok=True, parents=True)

relion_dataset = RelionDataset(
Expand Down Expand Up @@ -161,13 +181,20 @@ def compute_rigid_transforms(
circular_mask_radius=diameter_ang / (2 * ang_pix),
circular_mask_thickness=mask_soft_edge_width / ang_pix
)

for i in range(len(masked_points)):
if rigid == True:
n_bodies = decoder_half1.n_bodies
masked_points = list(np.arange(n_bodies))
else:
n_bodies = len(masked_points)

for i in range(n_bodies):
if rigid == True:
body = i
print('Computing new star files for body ', body+1)
current_star_file = star_file.copy()
new_star = get_rotation_translation(
encoder_half1, decoder_half1, data_loader_half1, poses, data_preprocessor, [masked_points[i]], [current_star_file], half=1)
print(len(new_star))
encoder_half1, decoder_half1, data_loader_half1, poses, data_preprocessor, [masked_points[i]], [current_star_file], body, half=1)
new_star = get_rotation_translation(
encoder_half2, decoder_half2, data_loader_half2, poses, data_preprocessor, [masked_points[i]], new_star, half=2)
encoder_half2, decoder_half2, data_loader_half2, poses, data_preprocessor, [masked_points[i]], new_star, body, half=2)
starfile.write(new_star[0], star_directory /
('body_' + str(i+1) + '.star'))
('body_' + str(i+1) + '.star'), overwrite=True)

0 comments on commit a6ff817

Please sign in to comment.