Support for MSA contexts and .aligned.pqt format (#109)
Co-authored-by: Jack Dent <[email protected]>
Co-authored-by: Alex Rogozhnikov <[email protected]>
3 people authored Oct 21, 2024
1 parent 394f93f commit 756437e
Showing 18 changed files with 785 additions and 139 deletions.
28 changes: 18 additions & 10 deletions chai_lab/chai1.py
@@ -26,6 +26,7 @@
 from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
 from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
 from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
+from chai_lab.data.dataset.msas.load import get_msa_contexts
 from chai_lab.data.dataset.msas.msa_context import MSAContext
 from chai_lab.data.dataset.structure.all_atom_structure_context import (
     AllAtomStructureContext,
@@ -249,6 +250,7 @@ def run_inference(
     *,
     output_dir: Path,
     use_esm_embeddings: bool = True,
+    msa_directory: Path | None = None,
     # expose some params for easy tweaking
     num_trunk_recycles: int = 3,
     num_diffn_timesteps: int = 200,
@@ -276,14 +278,20 @@
     raise_if_too_many_tokens(n_actual_tokens)

     # Load MSAs
-    msa_context = MSAContext.create_empty(
-        n_tokens=n_actual_tokens,
-        depth=MAX_MSA_DEPTH,
-    )
-    main_msa_context = MSAContext.create_empty(
-        n_tokens=n_actual_tokens,
-        depth=MAX_MSA_DEPTH,
-    )
+    if msa_directory is not None:
+        msa_context, msa_profile_context = get_msa_contexts(
+            chains, msa_directory=msa_directory
+        )
+    else:
+        msa_context = MSAContext.create_empty(
+            n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
+        )
+        msa_profile_context = MSAContext.create_empty(
+            n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
+        )
+    assert (
+        msa_context.num_tokens == merged_context.num_tokens
+    ), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"

     # Load templates
     template_context = TemplateContext.empty(
@@ -305,7 +313,7 @@
         chains=chains,
         structure_context=merged_context,
         msa_context=msa_context,
-        main_msa_context=main_msa_context,
+        profile_msa_context=msa_profile_context,
         template_context=template_context,
         embedding_context=embedding_context,
         constraint_context=constraint_context,
@@ -358,7 +366,7 @@ def run_folding_on_context(
     raise_if_too_many_tokens(n_actual_tokens)
     raise_if_too_many_templates(feature_context.template_context.num_templates)
     raise_if_msa_too_deep(feature_context.msa_context.depth)
-    raise_if_msa_too_deep(feature_context.main_msa_context.depth)
+    # NOTE profile MSA used only for statistics; no depth check

     ##
     ## Prepare batch
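
The new msa_directory argument above is the user-facing switch for this feature. A minimal usage sketch; the positional FASTA-path argument and the concrete paths are assumptions, since neither appears in this diff:

from pathlib import Path

from chai_lab.chai1 import run_inference

# Sketch only: "input.fasta" and "msas/" are hypothetical paths; msa_directory
# should contain one .aligned.pqt file per unique query sequence.
candidates = run_inference(
    Path("input.fasta"),
    output_dir=Path("outputs"),
    use_esm_embeddings=True,
    msa_directory=Path("msas"),  # new in this commit; None keeps MSAs empty
    num_trunk_recycles=3,
    num_diffn_timesteps=200,
)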
15 changes: 7 additions & 8 deletions chai_lab/data/dataset/all_atom_feature_context.py
@@ -33,7 +33,7 @@ class AllAtomFeatureContext:
     # Contexts: these are what get padded and batched
     structure_context: AllAtomStructureContext
     msa_context: MSAContext
-    main_msa_context: MSAContext
+    profile_msa_context: MSAContext
     template_context: TemplateContext
     embedding_context: EmbeddingContext | None
     constraint_context: ConstraintContext
@@ -59,9 +59,9 @@ def pad(
                 max_num_tokens=n_tokens,
                 max_msa_depth=MAX_MSA_DEPTH,
             ),
-            main_msa_context=self.main_msa_context.pad(
+            profile_msa_context=self.profile_msa_context.pad(
                 max_num_tokens=n_tokens,
-                max_msa_depth=MAX_MSA_DEPTH,
+                # max_msa_depth=MAX_MSA_DEPTH,
             ),
             template_context=self.template_context.pad(
                 max_tokens=n_tokens,
@@ -80,12 +80,11 @@ def to_dict(self) -> dict[str, Any]:
             msa_tokens=self.msa_context.tokens,
             msa_mask=self.msa_context.mask,
             msa_deletion_matrix=self.msa_context.deletion_matrix,
-            msa_species=self.msa_context.species,
-            msa_pairkey=self.msa_context.pairing_key_hash,
             msa_sequence_source=self.msa_context.sequence_source,
-            main_msa_tokens=self.main_msa_context.tokens,
-            main_msa_mask=self.main_msa_context.mask,
-            main_msa_deletion_matrix=self.main_msa_context.deletion_matrix,
+            paired_msa_depth=self.msa_context.paired_msa_depth,
+            main_msa_tokens=self.profile_msa_context.tokens,
+            main_msa_mask=self.profile_msa_context.mask,
+            main_msa_deletion_matrix=self.profile_msa_context.deletion_matrix,
         )
         return {
             **self.structure_context.to_dict(),
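
Two details worth noting in this file: pad no longer forces the profile context to MAX_MSA_DEPTH (the kwarg is commented out rather than deleted), and to_dict keeps emitting the main_msa_* feature keys, now sourced from profile_msa_context, so downstream feature names stay stable. As a rough illustration of what a fixed-size MSA pad does dimensionally (a hypothetical helper, not the actual MSAContext.pad implementation):

import torch
import torch.nn.functional as F

def pad_msa_tokens(
    tokens: torch.Tensor,  # (depth, n_tokens)
    max_tokens: int,
    max_depth: int | None = None,  # None mirrors the commented-out kwarg above
    pad_value: int = 0,
) -> torch.Tensor:
    depth, n_tokens = tokens.shape
    target_depth = depth if max_depth is None else max_depth
    assert n_tokens <= max_tokens and depth <= target_depth
    # F.pad pads the trailing dim first: (left, right) for tokens, then depth
    return F.pad(
        tokens,
        (0, max_tokens - n_tokens, 0, target_depth - depth),
        value=pad_value,
    )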
77 changes: 77 additions & 0 deletions chai_lab/data/dataset/msas/load.py
@@ -0,0 +1,77 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# This source code is licensed under the Chai Discovery Community License
# Agreement (LICENSE.md) found in the root directory of this source tree.

import logging
from pathlib import Path

import torch

from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.preprocess import (
    drop_duplicates,
    merge_main_msas_by_chain,
    pair_and_merge_msas,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.msas.a3m import tokenize_sequences_to_arrays
from chai_lab.data.parsing.msas.aligned_pqt import (
    expected_basename,
    parse_aligned_pqt_to_msa_context,
)
from chai_lab.data.parsing.msas.data_source import MSADataSource

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_msa_contexts(
    chains: list[Chain],
    msa_directory: Path,
) -> tuple[MSAContext, MSAContext]:
    """
    Looks inside msa_directory to find .aligned.pqt files to load alignments from.
    Returns two contexts
    - First context to tokenize and give to model
    - Second context for computing summary statistics
    """

    pdb_ids = set(chain.entity_data.pdb_id for chain in chains)
    assert len(pdb_ids) == 1, f"Found >1 pdb ids in chains: {pdb_ids=}"

    # MSAs are constructed based on sequence, so use the unique sequences present
    # in input chains to determine the MSAs that need to be loaded

    def get_msa_contexts_for_seq(seq) -> MSAContext:
        path = msa_directory / expected_basename(seq)
        if not path.is_file():
            logger.warning(f"No MSA found for sequence: {seq}")
            [tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
            return MSAContext.create_single_seq(
                MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)
            )
        msa = parse_aligned_pqt_to_msa_context(path)
        logger.info(f"MSA found for sequence: {seq}, {msa.depth=}")
        return msa

    # For each chain, either fetch the corresponding MSA or create an empty MSA
    # if it is missing + reindex to handle residues that are tokenized per-atom
    # (this also crops if necessary)
    msa_contexts = [
        get_msa_contexts_for_seq(chain.entity_data.sequence)[
            :, chain.structure_context.token_residue_index
        ]
        for chain in chains
    ]

    # used later only for profile statistics
    profile_msa = merge_main_msas_by_chain(
        [drop_duplicates(msa) for msa in msa_contexts]
    )

    joined_msa = pair_and_merge_msas(msa_contexts)
    joined_msa = drop_duplicates(joined_msa)  # rare dups after pairings

    logger.info(f"Prepared MSA context with {joined_msa.depth=}")
    return joined_msa, profile_msa
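
A short usage sketch for the new loader; `chains` is assumed to be the list[Chain] produced upstream by load_chains_from_raw in chai1.py and must already exist:

from pathlib import Path

from chai_lab.data.dataset.msas.load import get_msa_contexts

# `chains` comes from load_chains_from_raw, as in run_inference above.
msa_context, msa_profile_context = get_msa_contexts(
    chains, msa_directory=Path("msas")
)
print(f"{msa_context.depth=} {msa_profile_context.depth=}")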
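
For reference, a sketch of writing an .aligned.pqt file that get_msa_contexts can discover. The file name comes from expected_basename (imported in load.py above); the column layout used here (sequence, source_database, pairing_key, comment, with the query as the first row) is an assumption about the format and should be checked against chai_lab.data.parsing.msas.aligned_pqt:

from pathlib import Path

import pandas as pd

from chai_lab.data.parsing.msas.aligned_pqt import expected_basename

query = "MKTAYIAKQRQISFVKSHFSRQ"  # hypothetical query sequence

# Assumed schema: first row is the query itself; pairing_key (e.g. a taxonomy
# id) would be used when pairing MSAs across chains.
df = pd.DataFrame(
    {
        "sequence": [query, "MKTAYIAKQRQISFVKSHFSRA"],
        "source_database": ["query", "uniref90"],
        "pairing_key": ["", "9606"],
        "comment": ["", ""],
    }
)

msa_dir = Path("msas")
msa_dir.mkdir(exist_ok=True)
df.to_parquet(msa_dir / expected_basename(query))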