Support for MSA contexts and .aligned.pqt format #109

Merged · 36 commits · Oct 21, 2024

Commits:
86dc788  Add pandera dependency (wukevin, Oct 10, 2024)
b557103  Make the fasta object a named tuple (wukevin, Oct 10, 2024)
d7b68ef  Code to tokenize a3m style alignments (wukevin, Oct 11, 2024)
3b71795  Add tests (wukevin, Oct 11, 2024)
06749a8  Aligned parquet support (wukevin, Oct 11, 2024)
96d98a3  Example .a3m files (wukevin, Oct 11, 2024)
8fade94  Readability of MSAContext (wukevin, Oct 11, 2024)
ab8ccff  Support for more indexing on msa context (wukevin, Oct 11, 2024)
103cd06  Initial implementation of loading MSA contexts (wukevin, Oct 11, 2024)
324815b  Log warnings on missing MSAs (wukevin, Oct 11, 2024)
0e55ad5  Hooks to load MSA (wukevin, Oct 11, 2024)
36f817e  Shape fixes (wukevin, Oct 11, 2024)
c7cf70d  Cleanup (wukevin, Oct 11, 2024)
bf9d9af  Cleanup (wukevin, Oct 11, 2024)
8c930aa  logging (wukevin, Oct 11, 2024)
71a14d4  Minor (wukevin, Oct 11, 2024)
c920a43  Additional example code + docs (wukevin, Oct 11, 2024)
8279ea5  Fix fasta parsing giving incomplete headers (wukevin, Oct 11, 2024)
c371b66  Misc. fixes (wukevin, Oct 11, 2024)
a9f1787  Bugfix (wukevin, Oct 11, 2024)
964d4a8  Always include query row when we split into rows to pair or not pair (wukevin, Oct 11, 2024)
8165ff1  Check (wukevin, Oct 12, 2024)
eb5303d  Typo fix (wukevin, Oct 12, 2024)
a6da557  Add example aligned parquets (wukevin, Oct 12, 2024)
7e7d107  Remove unused files (wukevin, Oct 12, 2024)
64ff69b  Use Serialized MSA (wukevin, Oct 14, 2024)
0ef0d83  Remove dev code (wukevin, Oct 14, 2024)
7d708b1  Update examples/msas/README.md (wukevin, Oct 16, 2024)
efe12dc  get rid of species logic in MSAs, del SerializedMSAForSingleSequence, (arogozhnikov, Oct 19, 2024)
e6ce80e  rm login used for debugging (arogozhnikov, Oct 19, 2024)
3e027ca  update readme, misc. small changes (wukevin, Oct 20, 2024)
f1392e8  Do not truncate MSAs used for MSA profile (wukevin, Oct 20, 2024)
9514b65  Rename main msa context -> profile msa context; do not check its depth (wukevin, Oct 20, 2024)
7cc07c2  Rename function (wukevin, Oct 20, 2024)
844da7e  Make ukey pairing sort by sequence identity first (wukevin, Oct 21, 2024)
f2c080f  Logging for pairing (wukevin, Oct 21, 2024)
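
The new .aligned.pqt input format is easiest to see by example. Below is a minimal sketch of hand-assembling such a file for a single query sequence. The column names follow the examples/msas/README.md updated in this PR; treat the exact schema (enforced with pandera in the PR) and the hashed file-naming convention as assumptions to verify against that README, not a spec.

    # Sketch: hand-assembling a minimal .aligned.pqt file for one query.
    # Column names are taken from examples/msas/README.md in this PR;
    # verify the schema there before relying on this.
    from pathlib import Path

    import pandas as pd

    from chai_lab.data.parsing.msas.aligned_pqt import expected_basename

    query = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"
    df = pd.DataFrame(
        {
            # Row 0 is the query itself; later rows are a3m-style aligned hits.
            "sequence": [query, "MKTAYIEKQRQISFVKSHFSRQLEERLGLIEVQ"],
            "source_database": ["query", "uniref90"],
            # Rows sharing a pairing key (e.g. a species identifier) across
            # chains can be paired into one MSA row; blank means unpaired.
            "pairing_key": ["", "9606"],
            "comment": ["", ""],
        }
    )

    msa_dir = Path("msas")
    msa_dir.mkdir(exist_ok=True)
    # The filename is derived from a hash of the query sequence, so that
    # get_msa_contexts can locate it via expected_basename(seq).
    df.to_parquet(msa_dir / expected_basename(query))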
12 changes: 5 additions & 7 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
@@ -279,17 +279,15 @@ def run_inference(
 
     # Load MSAs
     if msa_directory is not None:
-        msa_context, main_msa_context = get_msa_contexts(
+        msa_context, msa_profile_context = get_msa_contexts(
             chains, msa_directory=msa_directory
         )
     else:
         msa_context = MSAContext.create_empty(
-            n_tokens=n_actual_tokens,
-            depth=MAX_MSA_DEPTH,
+            n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
        )
-        main_msa_context = MSAContext.create_empty(
-            n_tokens=n_actual_tokens,
-            depth=MAX_MSA_DEPTH,
+        msa_profile_context = MSAContext.create_empty(
+            n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
         )
 
     # Load templates
@@ -312,7 +310,7 @@
             chains=chains,
             structure_context=merged_context,
             msa_context=msa_context,
-            main_msa_context=main_msa_context,
+            main_msa_context=msa_profile_context,
             template_context=template_context,
             embedding_context=embedding_context,
             constraint_context=constraint_context,
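
For context, msa_profile_context is constructed inside run_inference; callers only point it at a directory of alignments. A hedged invocation sketch follows: only the msa_directory argument is confirmed by the hunk above, and the remaining argument names are assumptions about the rest of run_inference's signature.

    # Hedged usage sketch. msa_directory is confirmed by the diff above;
    # fasta_file, output_dir, and device are assumed argument names.
    from pathlib import Path

    from chai_lab.chai1 import run_inference

    run_inference(
        fasta_file=Path("input.fasta"),
        output_dir=Path("outputs"),
        # Directory of <hash>.aligned.pqt files; when omitted, empty MSA
        # contexts of depth MAX_MSA_DEPTH are created, as in the diff above.
        msa_directory=Path("msas"),
        device="cuda:0",
    )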
chai_lab/data/dataset/all_atom_feature_context.py: 3 changes (1 addition, 2 deletions)
@@ -80,12 +80,11 @@ def to_dict(self) -> dict[str, Any]:
             msa_tokens=self.msa_context.tokens,
             msa_mask=self.msa_context.mask,
             msa_deletion_matrix=self.msa_context.deletion_matrix,
-            msa_species=self.msa_context.species,
+            msa_pairkey=self.msa_context.pairing_key_hash,
             msa_sequence_source=self.msa_context.sequence_source,
             main_msa_tokens=self.main_msa_context.tokens,
             main_msa_mask=self.main_msa_context.mask,
             main_msa_deletion_matrix=self.main_msa_context.deletion_matrix,
-            paired_msa_depth=self.msa_context.paired_msa_depth,
         )
         return {
             **self.structure_context.to_dict(),
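
The feature dict swaps the species tensor for msa_pairkey, backed by MSAContext.pairing_key_hash, so each MSA row carries a hash of an arbitrary pairing-key string rather than a species code. The toy sketch below illustrates the general idea of hashing string keys into tensor-friendly integers; it is not the actual pairing_key_hash implementation, which this diff does not show.

    # Toy illustration only: reducing arbitrary pairing-key strings to
    # integers that fit in a feature tensor. Not the MSAContext code.
    import hashlib

    def toy_pairing_key_hash(key: str, num_buckets: int = 2**16) -> int:
        digest = hashlib.sha256(key.encode()).digest()
        return int.from_bytes(digest[:4], "little") % num_buckets

    # Rows from different chains pair up when their key hashes match.
    print(toy_pairing_key_hash("9606") == toy_pairing_key_hash("9606"))  # True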
chai_lab/data/dataset/msas/load.py: 103 changes (36 additions, 67 deletions)
@@ -1,111 +1,80 @@
 # Copyright (c) 2024 Chai Discovery, Inc.
 # This source code is licensed under the Chai Discovery Community License
 # Agreement (LICENSE.md) found in the root directory of this source tree.
-"""
-Code for loading MSAs given
-"""
 
 import logging
-import pickle
 from pathlib import Path
 
 import torch
 
 from chai_lab.data.dataset.msas.msa_context import MSAContext
 from chai_lab.data.dataset.msas.preprocess import (
-    concatenate_paired_and_main_msas,
+    FULL_DEPTH,
     drop_duplicates,
     merge_main_msas_by_chain,
-    merge_msas_by_datasource,
-    pair_msas_by_chain_with_species_matching,
-    partition_msa_by_pairing_key,
+    pair_and_merge_msas,
 )
 from chai_lab.data.dataset.structure.chain import Chain
 from chai_lab.data.parsing.msas.a3m import tokenize_sequences_to_arrays
 from chai_lab.data.parsing.msas.aligned_pqt import (
     expected_basename,
-    parse_aligned_pqt_to_msa_set,
+    parse_aligned_pqt_to_serialized_msa,
 )
 from chai_lab.data.parsing.msas.data_source import MSADataSource
-from chai_lab.data.parsing.msas.serialized_msa import SerializedMSAForSingleSequence
 from chai_lab.data.parsing.structure.entity_type import EntityType
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
 def get_msa_contexts(
     chains: list[Chain],
     msa_directory: Path,
 ) -> tuple[MSAContext, MSAContext]:
-    """Returns two contexts
+    """
+    Looks inside msa_directory to find .aligned.pqt files to load alignments from.
 
-    - First context to tokenize and give to model
-    - Second context for computing summary statistics.
+    Returns two contexts
 
-    Looks inside msa_directory to find .aligned.pqt files to load alignments from.
+    - First context to tokenize and give to model
+    - Second context for computing summary statistics
     """
-    with open("/tmp/chains.pkl", "wb") as f:
-        pickle.dump(chains, f)
 
     pdb_ids = set(chain.entity_data.pdb_id for chain in chains)
     assert len(pdb_ids) == 1, f"Found >1 pdb ids in chains: {pdb_ids=}"
 
     # MSAs are constructed based on sequence, so use the unique sequences present
     # in input chains to determine the MSAs that need to be loaded
-    msas_to_load = {
-        (chain.entity_data.sequence, chain.entity_data.entity_id)
-        for chain in chains
-        if chain.entity_data.entity_type == EntityType.PROTEIN
-    }
 
-    # Load up the MSAs for each chain; do this by checking for MSA sequences
-    msa_contexts_for_entities: dict[tuple[str, int], SerializedMSAForSingleSequence] = (
-        dict()
-    )
-    for seq, entity_id in msas_to_load:
+    def get_msa_contexts_for_seq(seq) -> MSAContext:
         path = msa_directory / expected_basename(seq)
         if not path.is_file():
-            logging.warning(f"No MSA found for sequence: {seq}")
-            continue
-        msa_contexts_for_entities[(seq, entity_id)] = parse_aligned_pqt_to_msa_set(path)
+            logger.warning(f"No MSA found for sequence: {seq}")
+            [tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
+            return MSAContext.create_single_seq(
+                MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)
+            )
+        msa = parse_aligned_pqt_to_serialized_msa(path)
+        logger.info(f"MSA found for sequence: {seq}, {msa.depth=}")
+        return msa
 
-    # For each chain, either fetch the corresponding MSA or create an empty MSA if it is missing
-    msa_sets: list[dict[MSADataSource, MSAContext]] = [
-        (
-            MSAContext.from_serialized(msa_contexts_for_entities[k])
-            if (k := (chain.entity_data.sequence, chain.entity_data.entity_id))
-            in msa_contexts_for_entities
-            else {
-                MSADataSource.NONE: MSAContext.create(
-                    MSADataSource.NONE,
-                    tokens=torch.from_numpy(
-                        tokenize_sequences_to_arrays([chain.entity_data.sequence])[
-                            0
-                        ].squeeze(0)
-                    ),
-                )
-            }
-        )
+    # + reindex to handle residues that are tokenized per-atom (this also crops if necessary)
+    msa_contexts = [
+        get_msa_contexts_for_seq(chain.entity_data.sequence)[
+            :, chain.structure_context.token_residue_index
+        ]
         for chain in chains
     ]
-    assert len(msa_sets) == len(chains)
 
-    # Stack the MSA for each chain together across MSA sources
-    msa_contexts = [merge_msas_by_datasource(msa_set) for msa_set in msa_sets]
-
-    # Re-index to handle residues that are tokenized per-atom
-    msa_sets_exploded = [
-        msa_ctx[..., chain.structure_context.token_residue_index]
-        for chain, msa_ctx in zip(chains, msa_contexts, strict=True)
-    ]
-
-    # Pair up the MSAs that have a pairing key (typically species) provided
-    divided = [partition_msa_by_pairing_key(m) for m in msa_sets_exploded]
-    pairing_contexts = [d for d, _ in divided]
-    paired_msa = pair_msas_by_chain_with_species_matching(pairing_contexts)
-
-    # Process main MSA - deduplicate and merge across chains
-    main_contexts = [d for _, d in divided]
-    main_msa_deduped = [drop_duplicates(msa) for msa in main_contexts]
-    main_msa = merge_main_msas_by_chain(main_msa_deduped)
+    profile_msa = merge_main_msas_by_chain(
+        [drop_duplicates(msa)[:FULL_DEPTH, :] for msa in msa_contexts]
+    )  # used later only for profile statistics
 
-    # Combine the paired and main MSAs
-    merged_msa = concatenate_paired_and_main_msas(paired_msa, main_msa)
-    merged_dedup_msa = drop_duplicates(merged_msa)
+    joined_msa = pair_and_merge_msas(msa_contexts)
+    joined_msa = drop_duplicates(joined_msa)  # rare dups after pairings
 
-    logging.info(f"Loaded MSA context with {merged_dedup_msa.depth=}")
-    return merged_dedup_msa, main_msa
+    print(f"Prepared MSA context with {joined_msa.depth=}")
+    return joined_msa, profile_msa
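
One subtle step above is the indexing msa_ctx[:, chain.structure_context.token_residue_index], which repeats each residue's MSA column once per token so that residues tokenized per-atom stay aligned with the token sequence. A toy numpy illustration of that fancy-indexing behavior:

    # Toy illustration of the re-indexing above: a residue expanded into
    # three tokens has its MSA column repeated once per token.
    import numpy as np

    msa_tokens = np.array([[10, 11, 12, 13]])           # 1 MSA row, 4 residues
    token_residue_index = np.array([0, 1, 1, 1, 2, 3])  # residue 1 -> 3 tokens
    print(msa_tokens[:, token_residue_index])           # [[10 11 11 11 12 13]]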