Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for MSA contexts and .aligned.pqt format #109

Merged
merged 36 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
86dc788
Add pandera dependency
wukevin Oct 10, 2024
b557103
Make the fasta object a named tuple
wukevin Oct 10, 2024
d7b68ef
Code to tokenize a3m style alignments
wukevin Oct 11, 2024
3b71795
Add tests
wukevin Oct 11, 2024
06749a8
Aligned parquet support
wukevin Oct 11, 2024
96d98a3
Example .a3m files
wukevin Oct 11, 2024
8fade94
Readability of MSAContext
wukevin Oct 11, 2024
ab8ccff
Support for more indexing on msa context
wukevin Oct 11, 2024
103cd06
Initial implementation of loading MSA contexts
wukevin Oct 11, 2024
324815b
Log warnings on missing MSAs
wukevin Oct 11, 2024
0e55ad5
Hooks to load MSA
wukevin Oct 11, 2024
36f817e
Shape fixes
wukevin Oct 11, 2024
c7cf70d
Cleanup
wukevin Oct 11, 2024
bf9d9af
Cleanup
wukevin Oct 11, 2024
8c930aa
logging
wukevin Oct 11, 2024
71a14d4
Minor
wukevin Oct 11, 2024
c920a43
Additional example code + docs
wukevin Oct 11, 2024
8279ea5
Fix fasta parsing giving incomplete headers
wukevin Oct 11, 2024
c371b66
Misc. fixes
wukevin Oct 11, 2024
a9f1787
Bugfix
wukevin Oct 11, 2024
964d4a8
Always include query row when we split into rows to pair or not pair
wukevin Oct 11, 2024
8165ff1
Check
wukevin Oct 12, 2024
eb5303d
Typo fix
wukevin Oct 12, 2024
a6da557
Add example aligned parquets
wukevin Oct 12, 2024
7e7d107
Remove unused files
wukevin Oct 12, 2024
64ff69b
Use Serialized MSA
wukevin Oct 14, 2024
0ef0d83
Remove dev code
wukevin Oct 14, 2024
7d708b1
Update examples/msas/README.md
wukevin Oct 16, 2024
efe12dc
get rid of species logic in MSAs, del SerializedMSAForSingleSequence,
arogozhnikov Oct 19, 2024
e6ce80e
rm logging used for debugging
arogozhnikov Oct 19, 2024
3e027ca
update readme, misc. small changes
wukevin Oct 20, 2024
f1392e8
Do not truncate MSAs used for MSA profile
wukevin Oct 20, 2024
9514b65
Rename main msa context -> profile msa context; do not check its depth
wukevin Oct 20, 2024
7cc07c2
Rename function
wukevin Oct 20, 2024
844da7e
Make ukey pairing sort by sequence identity first
wukevin Oct 21, 2024
f2c080f
Logging for pairing
wukevin Oct 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
Expand Down Expand Up @@ -249,6 +250,7 @@ def run_inference(
*,
output_dir: Path,
use_esm_embeddings: bool = True,
msa_directory: Path | None = None,
# expose some params for easy tweaking
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
Expand Down Expand Up @@ -276,14 +278,20 @@ def run_inference(
raise_if_too_many_tokens(n_actual_tokens)

# Load MSAs
msa_context = MSAContext.create_empty(
n_tokens=n_actual_tokens,
depth=MAX_MSA_DEPTH,
)
main_msa_context = MSAContext.create_empty(
n_tokens=n_actual_tokens,
depth=MAX_MSA_DEPTH,
)
if msa_directory is not None:
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_directory
)
else:
msa_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
msa_profile_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
assert (
msa_context.num_tokens == merged_context.num_tokens
), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"

# Load templates
template_context = TemplateContext.empty(
Expand All @@ -305,7 +313,7 @@ def run_inference(
chains=chains,
structure_context=merged_context,
msa_context=msa_context,
main_msa_context=main_msa_context,
profile_msa_context=msa_profile_context,
template_context=template_context,
embedding_context=embedding_context,
constraint_context=constraint_context,
Expand Down Expand Up @@ -358,7 +366,7 @@ def run_folding_on_context(
raise_if_too_many_tokens(n_actual_tokens)
raise_if_too_many_templates(feature_context.template_context.num_templates)
raise_if_msa_too_deep(feature_context.msa_context.depth)
raise_if_msa_too_deep(feature_context.main_msa_context.depth)
# NOTE profile MSA used only for statistics; no depth check

##
## Prepare batch
Expand Down
15 changes: 7 additions & 8 deletions chai_lab/data/dataset/all_atom_feature_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AllAtomFeatureContext:
# Contexts: these are what get padded and batched
structure_context: AllAtomStructureContext
msa_context: MSAContext
main_msa_context: MSAContext
profile_msa_context: MSAContext
template_context: TemplateContext
embedding_context: EmbeddingContext | None
constraint_context: ConstraintContext
Expand All @@ -59,9 +59,9 @@ def pad(
max_num_tokens=n_tokens,
max_msa_depth=MAX_MSA_DEPTH,
),
main_msa_context=self.main_msa_context.pad(
profile_msa_context=self.profile_msa_context.pad(
max_num_tokens=n_tokens,
max_msa_depth=MAX_MSA_DEPTH,
# max_msa_depth=MAX_MSA_DEPTH,
),
template_context=self.template_context.pad(
max_tokens=n_tokens,
Expand All @@ -80,12 +80,11 @@ def to_dict(self) -> dict[str, Any]:
msa_tokens=self.msa_context.tokens,
msa_mask=self.msa_context.mask,
msa_deletion_matrix=self.msa_context.deletion_matrix,
msa_species=self.msa_context.species,
msa_pairkey=self.msa_context.pairing_key_hash,
msa_sequence_source=self.msa_context.sequence_source,
main_msa_tokens=self.main_msa_context.tokens,
main_msa_mask=self.main_msa_context.mask,
main_msa_deletion_matrix=self.main_msa_context.deletion_matrix,
paired_msa_depth=self.msa_context.paired_msa_depth,
main_msa_tokens=self.profile_msa_context.tokens,
main_msa_mask=self.profile_msa_context.mask,
main_msa_deletion_matrix=self.profile_msa_context.deletion_matrix,
)
return {
**self.structure_context.to_dict(),
Expand Down
77 changes: 77 additions & 0 deletions chai_lab/data/dataset/msas/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# This source code is licensed under the Chai Discovery Community License
# Agreement (LICENSE.md) found in the root directory of this source tree.

import logging
from pathlib import Path

import torch

from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.preprocess import (
drop_duplicates,
merge_main_msas_by_chain,
pair_and_merge_msas,
)
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.msas.a3m import tokenize_sequences_to_arrays
from chai_lab.data.parsing.msas.aligned_pqt import (
expected_basename,
parse_aligned_pqt_to_msa_context,
)
from chai_lab.data.parsing.msas.data_source import MSADataSource

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_msa_contexts(
    chains: list[Chain],
    msa_directory: Path,
) -> tuple[MSAContext, MSAContext]:
    """Load per-chain alignments from .aligned.pqt files found in msa_directory.

    Returns a pair of contexts:
    - the paired-and-merged MSA that gets tokenized and fed to the model
    - the per-chain merged MSA used only to compute summary (profile) statistics
    """

    pdb_ids = {chain.entity_data.pdb_id for chain in chains}
    assert len(pdb_ids) == 1, f"Found >1 pdb ids in chains: {pdb_ids=}"

    # MSAs are constructed based on sequence, so the unique sequences present
    # in the input chains determine which alignment files need to be loaded.

    def _load_context_for_sequence(seq) -> MSAContext:
        # One .aligned.pqt file per sequence; fall back to a query-only,
        # single-row MSA when no alignment file exists for this sequence.
        path = msa_directory / expected_basename(seq)
        if path.is_file():
            msa = parse_aligned_pqt_to_msa_context(path)
            logger.info(f"MSA found for sequence: {seq}, {msa.depth=}")
            return msa
        logger.warning(f"No MSA found for sequence: {seq}")
        [tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
        return MSAContext.create_single_seq(
            MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)
        )

    # For each chain, fetch its MSA (or the query-only fallback), then reindex
    # columns so residues tokenized per-atom line up with the structure tokens
    # (this also crops if necessary).
    msa_contexts = []
    for chain in chains:
        ctx = _load_context_for_sequence(chain.entity_data.sequence)
        msa_contexts.append(ctx[:, chain.structure_context.token_residue_index])

    # used later only for profile statistics
    deduplicated = [drop_duplicates(ctx) for ctx in msa_contexts]
    profile_msa = merge_main_msas_by_chain(deduplicated)

    joined_msa = pair_and_merge_msas(msa_contexts)
    joined_msa = drop_duplicates(joined_msa)  # rare dups after pairings

    logger.info(f"Prepared MSA context with {joined_msa.depth=}")
    return joined_msa, profile_msa
Loading