Skip to content

Commit

Permalink
continuation
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Schwab PI Sjors Scheres added 22022021 committed Oct 31, 2023
1 parent 97c9eb2 commit 50e63d8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 28 deletions.
Binary file not shown.
75 changes: 47 additions & 28 deletions dynamight/deformations/optimize_deformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def optimize_deformations(
particle_dataset, val_indices)
lambda_regularization_half1 = cp['regularization_parameter_h1']
lambda_regularization_half1 = cp['regularization_parameter_h2']
n_warmup_epochs = 0
n_warmup_epochs = n_warmup_epochs
half1_indices = inds_half1
half2_indices = inds_half2
print('continuing training from a given checkpoint file')
Expand Down Expand Up @@ -283,6 +283,25 @@ def optimize_deformations(
encoder_half1, encoder_half2, decoder_half1, decoder_half2 = load_models(
checkpoint_file, device, box_size, n_classes
)
decoder_half1.model_positions = torch.nn.Parameter(
initial_points.to(device), requires_grad=True)
decoder_half2.model_positions = torch.nn.Parameter(
initial_points.to(device), requires_grad=True)
decoder_half1.amp = torch.nn.Parameter(
50 * torch.ones(n_classes, n_points).to(device), requires_grad=False
)
decoder_half1.ampvar = torch.nn.Parameter(
torch.randn(n_classes, n_points).to(device), requires_grad=True
)
decoder_half2.amp = torch.nn.Parameter(
50 * torch.ones(n_classes, n_points).to(device), requires_grad=False
)
decoder_half2.ampvar = torch.nn.Parameter(
torch.randn(n_classes, n_points).to(device), requires_grad=True
)
decoder_half1.n_points = n_points
decoder_half2.n_points = n_points

else:
encoder_half1 = HetEncoder(box_size, latent_dim, 1).to(device)
encoder_half2 = HetEncoder(box_size, latent_dim, 1).to(device)
Expand All @@ -305,33 +324,33 @@ def optimize_deformations(
decoder_half1 = DisplacementDecoder(**decoder_kwargs).to(device)
decoder_half2 = DisplacementDecoder(**decoder_kwargs).to(device)

if initialization_mode == ConsensusInitializationMode.MAP:
with mrcfile.open(initial_model) as mrc:
Ivol = torch.tensor(mrc.data)
fits = False
while fits == False:
try:
for decoder in (decoder_half1, decoder_half2):
decoder.initialize_physical_parameters(
reference_volume=Ivol)
summ.add_figure("Data/cons_points_z_half1",
tensor_scatter(decoder_half1.model_positions[:, 0],
decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1)
summ.add_figure("Data/cons_points_z_half2",
tensor_scatter(decoder_half2.model_positions[:, 0],
decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1)
fits = True
print('consensus gaussian models initialized')
torch.cuda.empty_cache()
except Exception as error:
torch.cuda.empty_cache()
print(
'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error)
Ivol = torch.nn.functional.avg_pool3d(
Ivol[None, None], (2, 2, 2))
Ivol = Ivol[0, 0]
decoder_half1.vol_box = decoder_half1.vol_box//2
decoder_half2.vol_box = decoder_half2.vol_box//2
if initialization_mode == ConsensusInitializationMode.MAP:
with mrcfile.open(initial_model) as mrc:
Ivol = torch.tensor(mrc.data)
fits = False
while fits == False:
try:
for decoder in (decoder_half1, decoder_half2):
decoder.initialize_physical_parameters(
reference_volume=Ivol)
summ.add_figure("Data/cons_points_z_half1",
tensor_scatter(decoder_half1.model_positions[:, 0],
decoder_half1.model_positions[:, 1], c=torch.ones(decoder_half1.model_positions.shape[0]), s=3), -1)
summ.add_figure("Data/cons_points_z_half2",
tensor_scatter(decoder_half2.model_positions[:, 0],
decoder_half2.model_positions[:, 1], c=torch.ones(decoder_half2.model_positions.shape[0]), s=3), -1)
fits = True
print('consensus gaussian models initialized')
torch.cuda.empty_cache()
except Exception as error:
torch.cuda.empty_cache()
print(
'volume too large: change size of output volumes. (If you want the original box size for the output volumes use a bigger gpu.', error)
Ivol = torch.nn.functional.avg_pool3d(
Ivol[None, None], (2, 2, 2))
Ivol = Ivol[0, 0]
decoder_half1.vol_box = decoder_half1.vol_box//2
decoder_half2.vol_box = decoder_half2.vol_box//2

if mask_file:
with mrcfile.open(mask_file) as mrc:
Expand Down

0 comments on commit 50e63d8

Please sign in to comment.