From 7b3d85a46b0357d1dd575bc8dc51385224dbeada Mon Sep 17 00:00:00 2001 From: schwabjohannes Date: Thu, 19 Dec 2024 11:06:32 +0100 Subject: [PATCH] Add files via upload add tools for rigid body version --- dynamight/inverse_deformations/rigid.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/dynamight/inverse_deformations/rigid.py b/dynamight/inverse_deformations/rigid.py index 0852681..af27242 100644 --- a/dynamight/inverse_deformations/rigid.py +++ b/dynamight/inverse_deformations/rigid.py @@ -51,6 +51,7 @@ def get_rotation_translation( data_preprocessor, masked_points, star_file_data, + body, half, ): @@ -73,11 +74,22 @@ def get_rotation_translation( for points in masked_points: - proj, pos, dis = decoder.forward( - mu, r.to(device), t.to(device), points.to(device)) + # proj, pos, dis = decoder.forward( + # mu, r.to(device), t.to(device), points.to(device)) + + # proj_all, pos_all, dis_all = decoder.forward( + # mu, r.to(device), t.to(device)) + if body != None: + proj, pos, dis = decoder.forward( + mu, r.to(device), t.to(device)) + pos = pos[:, decoder.masked_indices[body]] + dis = dis[:, decoder.masked_indices[body]] + points = decoder.model_positions[decoder.masked_indices[body]] + + else: + proj, pos, dis = decoder.forward( + mu, r.to(device), t.to(device), points.to(device)) - proj_all, pos_all, dis_all = decoder.forward( - mu, r.to(device), t.to(device)) B = pos.movedim(1, 2) A = points