Skip to content

Commit

Permalink
Update decoder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
schwabjohannes authored Dec 19, 2024
1 parent 0cc9031 commit d7d75e8
Showing 1 changed file with 103 additions and 33 deletions.
136 changes: 103 additions & 33 deletions dynamight/models/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def __init__(
# 10 * torch.ones(n_classes, n_points), requires_grad=True
# )
self.amp = torch.nn.Parameter(
50 * torch.ones(n_classes, n_points), requires_grad=False
50 * torch.ones(n_classes, n_points), requires_grad=True
)
self.ampvar = torch.nn.Parameter(
torch.randn(n_classes, n_points), requires_grad=True
Expand Down Expand Up @@ -592,6 +592,7 @@ def _update_positions_unmasked(

n_expanded_positions[:, self.masked_indices[i]] = torch.matmul(
euler_angles, expanded_positions[:, self.masked_indices[i]].movedim(-1, -2)).movedim(-1, -2)+translations[:, None, :]

#print(a.shape, expanded_positions[:, self.masked_indices[i]].shape)
# expanded_positions[:, masked_indices[i]] = torch.matmul(
# euler_angles, expanded_positions[:, self.masked_indices[i]].movedim(-1, -2)).movedim(-1, -2)
Expand All @@ -600,6 +601,7 @@ def _update_positions_unmasked(
# final_positions = expanded_positions + \
# displacements # (b, n_gaussians, 3)
final_positions = n_expanded_positions

displacements = expanded_positions - n_expanded_positions
return final_positions, displacements

Expand All @@ -612,12 +614,14 @@ def forward(
):
"""Decode latent variable into coordinate model and make a projection image."""
if positions is None:

posin = False
positions = self.model_positions
amp = self.amp
ampvar = self.ampvar
else:
posin = True

# amp = torch.tensor([1.0]).to(self.device)
amp = torch.ones(
self.n_classes, positions.shape[0]).to(self.device)
Expand Down Expand Up @@ -648,10 +652,13 @@ def forward(
[displacements_in_mask, torch.zeros_like(self.masked_positions).expand(self.batch_size, self.masked_positions.shape[0], 3)], 1)

# turn points into images

projected_positions = self.projector(updated_positions, orientation)
weighted_amplitudes = torch.stack(
self.batch_size * [amp*F.softmax(ampvar, dim=0)], dim=0
) # (b, n_points, n_gaussian_widths)
)

# (b, n_points, n_gaussian_widths)
# weighted_amplitudes = torch.stack(
# self.batch_size * [amp], dim=0
# )
Expand All @@ -666,51 +673,110 @@ def forward(
else:
return projection_images, updated_positions_in_mask, displacements_in_mask

def get_pose_parameters(
    self, z, positions
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predict per-rigid-body pose parameters from latent variables.

    Runs the latent code ``z`` through the network head
    (``self.input`` -> ``self.layers`` -> ``self.output``) and splits the
    output into one 6-vector per rigid body: the first three entries are
    Euler angles, the last three a translation.

    Parameters
    ----------
    z : torch.Tensor
        Latent variables, shape ``(batch, latent_dim)``.
    positions : torch.Tensor
        Unused; kept for backward signature compatibility. (The previous
        implementation computed a positional encoding and batch-expanded
        copies of ``positions`` here, but never used the results — that
        dead work has been removed.)

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        ``(shifts, rotations)``, each of shape ``(n_bodies, batch, 3)``.
    """
    # Forward pass producing (batch, n_bodies * 6) rigid parameters.
    rigid_params = self.input(z)
    rigid_params = self.layers(rigid_params)
    rigid_params = self.output(rigid_params)
    rigid_params = rigid_params.reshape(z.shape[0], self.n_bodies, 6)

    # Split translations (last 3) from Euler angles (first 3) and move the
    # body axis to the front: (n_bodies, batch, 3) — the same layout the
    # original per-body append/stack loop produced.
    shifts = rigid_params[:, :, 3:].transpose(0, 1)
    rotations = rigid_params[:, :, :3].transpose(0, 1)
    return shifts, rotations

def initialize_physical_parameters(
self,
reference_volume: torch.Tensor,
# lr: float = 0.001,
lr: float = 0.001,
n_epochs: int = 50,
n_epochs: int = 200,
scale=False
):
reference_norm = torch.sum(reference_volume**2)
ref_amps = reference_norm/self.n_points
V_model = self.generate_consensus_volume()[0]
print(V_model.device, reference_volume.device)
print(self.image_smoother.A.requires_grad)
if scale == True:
with torch.no_grad():
scale_factor = reference_norm / \
torch.sum(V_model*reference_volume)
reference_volume = reference_volume * scale_factor
# self.image_smoother.A = torch.nn.Parameter(torch.linspace(
# 0.5*ref_amps, ref_amps, self.n_classes).to(self.device), requires_grad=True)
print('Optimizing scale only')
optimizer = torch.optim.Adam(
[self.image_smoother.A], lr=100*lr)
if reference_volume.shape[-1] > 360:
reference_volume = torch.nn.functional.avg_pool3d(
reference_volume.unsqueeze(0).unsqueeze(0), 2)
reference_volume = reference_volume.squeeze()

for i in range(n_epochs):
optimizer.zero_grad()
V = self.generate_consensus_volume()
loss = torch.nn.functional.mse_loss(
V[0].float(), reference_volume.to(V.device))
loss.backward()
optimizer.step()

optimizer = torch.optim.Adam(self.physical_parameters, lr=0.1*lr)
if scale == False:
print('Optimizing scale only')
optimizer = torch.optim.Adam(
[self.image_smoother.A], lr=100*lr)
if reference_volume.shape[-1] > 360:
reference_volume = torch.nn.functional.avg_pool3d(
reference_volume.unsqueeze(0).unsqueeze(0), 2)
reference_volume = reference_volume.squeeze()

for i in range(n_epochs):
optimizer.zero_grad()
V = self.generate_consensus_volume()
loss = torch.nn.functional.mse_loss(
V[0].float(), reference_volume.to(V.device).detach())
loss.backward()
optimizer.step()

if scale == False:
optimizer = torch.optim.Adam(self.physical_parameters, lr=lr)
elif scale == True:
#self.ampvar = torch.nn.Parameter(
# torch.randn(self.n_classes, self.n_points).to(self.device), requires_grad=True
#)
optimizer = torch.optim.Adam(
[self.image_smoother.B], lr=1*lr)
pos_optimizer = torch.optim.Adam([self.model_positions], lr=0.05*lr)
self.model_positions.requires_grad = True
# self.image_smoother.B.requires_grad = False
print(
'Initializing gaussian positions from reference')
for i in tqdm(range(n_epochs), file=sys.stdout):
optimizer.zero_grad()
V = self.generate_consensus_volume()
loss = torch.mean((V[0].float()-reference_volume.to(V.device))**2)
#loss = torch.mean((V[0].float()-reference_volume.to(V.device))**2)
loss = -torch.sum(V[0].float()*reference_volume.to(V.device))/(torch.sqrt(
torch.sum(V[0].float()**2))*torch.sqrt(torch.sum(reference_volume.to(V.device)**2)))
print(loss.item())
loss.backward()
optimizer.step()
#optimizer.step()
if scale == True:
pos_optimizer.step()

print('Final error:', loss.item())
self.image_smoother.A = torch.nn.Parameter(
self.image_smoother.A*(np.pi/np.sqrt(self.box_size)), requires_grad=True)
# self.amp.requires_grad = False
self.amp = torch.nn.Parameter(
self.amp*(np.pi/np.sqrt(self.box_size)), requires_grad=True)
self.image_smoother.B.requires_grad = True
if scale == False:
self.image_smoother.A = torch.nn.Parameter(
self.image_smoother.A*(np.pi/np.sqrt(self.box_size)), requires_grad=True)
# self.amp.requires_grad = False
self.amp = torch.nn.Parameter(
self.amp*(np.pi/np.sqrt(self.box_size)), requires_grad=True)
self.image_smoother.B.requires_grad = True
if self.mask is None:
self.n_active_points = self.model_positions.shape[0]
else:
Expand Down Expand Up @@ -809,7 +875,8 @@ def physical_parameters(self) -> torch.nn.ParameterList:
self.model_positions,
self.ampvar,
self.image_smoother.B,
self.image_smoother.A
self.image_smoother.A,
self.amp
]
return params

Expand Down Expand Up @@ -868,9 +935,10 @@ def generate_consensus_volume(self):

F = torch.exp(-(scaling_fac/(self.image_smoother.B[:, None, None,
None])**2) * R**2) # * (torch.nn.functional.softmax(self.image_smoother.A[
# FF = torch.real(torch.fft.fftn(torch.fft.fftshift(
# F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(0.01+self.image_smoother.A[:, None, None, None]**2)*scaling_fac / (self.image_smoother.B[:, None, None, None])
FF = torch.real(torch.fft.fftn(torch.fft.fftshift(
F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(0.01+self.image_smoother.A[:, None, None, None]**2)*scaling_fac / (self.image_smoother.B[:, None, None, None])

F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(self.image_smoother.A[:, None, None, None]**2)*scaling_fac
bs = 2
Filts = torch.stack(bs * [FF], 0)
Filtim = torch.sum(Filts * volume, 1)
Expand All @@ -889,7 +957,7 @@ def generate_volume(self, z, r, shift):
amplitudes = torch.stack(
2 * [self.amp*torch.nn.functional.softmax(self.ampvar, dim=0)], dim=0
)
_, pos, _ = self.forward(z, r, shift, positions=self.model_positions)
_, pos, _ = self.forward(z, r, shift)

if self.mask is None:
V = p2v(pos,
Expand Down Expand Up @@ -928,8 +996,10 @@ def generate_volume(self, z, r, shift):
# :, None,
# None,
# None]**2)
# FF = torch.real(torch.fft.fftn(torch.fft.fftshift(
# F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(0.01+self.image_smoother.A[:, None, None, None]**2)*scaling_fac/(self.image_smoother.B[:, None, None, None])
FF = torch.real(torch.fft.fftn(torch.fft.fftshift(
F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(0.01+self.image_smoother.A[:, None, None, None]**2)*scaling_fac/(self.image_smoother.B[:, None, None, None])
F, dim=[-3, -2, -1]), dim=[-3, -2, -1], norm='ortho'))*(self.image_smoother.A[:, None, None, None]**2)*scaling_fac
bs = V.shape[0]
Filts = torch.stack(bs * [FF], 0)
Filtim = torch.sum(Filts * V, 1)
Expand Down

0 comments on commit d7d75e8

Please sign in to comment.