Empty cache before we run folding (#47)
* Empty cache before we run folding

* More aggressively move data off GPU

* Delete modules after we are finished with them

* Remove redundant call to move inputs to cpu
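
The bullets above all apply one general PyTorch memory-management pattern. The sketch below is illustrative only (stage_a and stage_b are hypothetical stand-ins, not chai_lab modules): run each stage under inference mode, drop the last reference to a module once its stage is done, call torch.cuda.empty_cache() to return unused cached blocks to the driver, and move results to the CPU before post-processing.

    import torch
    import torch.nn as nn

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Hypothetical stand-ins for the real trunk / diffusion / confidence modules.
    stage_a = nn.Linear(64, 64).to(device)
    stage_b = nn.Linear(64, 64).to(device)

    x = torch.randn(8, 64, device=device)

    with torch.inference_mode():   # no autograd graph, so deleted weights can actually be freed
        hidden = stage_a(x)
        del stage_a                # finished with this stage; drop the only reference to its weights
        torch.cuda.empty_cache()   # return unused cached blocks to the driver

        out = stage_b(hidden)
        del stage_b, hidden
        torch.cuda.empty_cache()

    out = out.cpu()                # move results off the GPU before post-processing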
wukevin authored Sep 13, 2024
1 parent ecd62ff commit 1d3e499
Showing 1 changed file with 16 additions and 1 deletion.
chai_lab/chai1.py: 16 additions & 1 deletion
@@ -327,6 +327,9 @@ def run_folding_on_context(
     if device is None:
         device = torch.device("cuda:0")

+    # Clear memory
+    torch.cuda.empty_cache()
+
     ##
     ## Validate inputs
     ##
@@ -443,6 +446,9 @@ def run_folding_on_context(
         token_single_mask=token_single_mask,
         token_pair_mask=token_pair_mask,
     )
+    # We won't be using the trunk anymore; remove it from memory
+    del trunk
+    torch.cuda.empty_cache()

     ##
     ## Denoise the trunk representation by passing it through the diffusion module
@@ -534,6 +540,10 @@ def _denoise(atom_pos: Tensor, sigma: Tensor, s: int) -> Tensor:
         d_i_prime = (atom_pos - denoised_pos) / sigma_next
         atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)

+    # We won't be running diffusion anymore
+    del diffusion_module
+    torch.cuda.empty_cache()
+
     ##
     ## Run the confidence model
     ##
@@ -610,6 +620,11 @@ def avg_per_token_1d(x):
     ##
     ## Write the outputs
     ##
+    # Move data to the CPU so we don't hit GPU memory limits
+    inputs = move_data_to_device(inputs, torch.device("cpu"))
+    atom_pos = atom_pos.cpu()
+    plddt_logits = plddt_logits.cpu()
+    pae_logits = pae_logits.cpu()

     # Plot coverage of tokens by MSA, save plot
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -671,7 +686,7 @@ def avg_per_token_1d(x):
         outputs_to_cif(
             coords=atom_pos[idx : idx + 1],
             bfactors=scaled_plddt_scores_per_atom,
-            output_batch=move_data_to_device(inputs, torch.device("cpu")),
+            output_batch=inputs,
             write_path=cif_out_path,
             entity_names={
                 c.entity_data.entity_id: c.entity_data.entity_name
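
Not part of this commit, but if you want to check the effect of changes like these on your own runs, PyTorch's allocator counters give a quick peak-memory readout; the folding call itself is elided below, so substitute your own invocation.

    import torch

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

        # ... run the folding pipeline here ...

        peak_gib = torch.cuda.max_memory_allocated() / 1024**3
        print(f"peak GPU memory allocated: {peak_gib:.2f} GiB")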