-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for MSA contexts and .aligned.pqt format (#109)
Co-authored-by: Jack Dent <[email protected]> Co-authored-by: Alex Rogozhnikov <[email protected]>
- Loading branch information
1 parent
394f93f
commit 756437e
Showing
18 changed files
with
785 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Copyright (c) 2024 Chai Discovery, Inc. | ||
# This source code is licensed under the Chai Discovery Community License | ||
# Agreement (LICENSE.md) found in the root directory of this source tree. | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
import torch | ||
|
||
from chai_lab.data.dataset.msas.msa_context import MSAContext | ||
from chai_lab.data.dataset.msas.preprocess import ( | ||
drop_duplicates, | ||
merge_main_msas_by_chain, | ||
pair_and_merge_msas, | ||
) | ||
from chai_lab.data.dataset.structure.chain import Chain | ||
from chai_lab.data.parsing.msas.a3m import tokenize_sequences_to_arrays | ||
from chai_lab.data.parsing.msas.aligned_pqt import ( | ||
expected_basename, | ||
parse_aligned_pqt_to_msa_context, | ||
) | ||
from chai_lab.data.parsing.msas.data_source import MSADataSource | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
def get_msa_contexts( | ||
chains: list[Chain], | ||
msa_directory: Path, | ||
) -> tuple[MSAContext, MSAContext]: | ||
""" | ||
Looks inside msa_directory to find .aligned.pqt files to load alignments from. | ||
Returns two contexts | ||
- First context to tokenize and give to model | ||
- Second context for computing summary statistics | ||
""" | ||
|
||
pdb_ids = set(chain.entity_data.pdb_id for chain in chains) | ||
assert len(pdb_ids) == 1, f"Found >1 pdb ids in chains: {pdb_ids=}" | ||
|
||
# MSAs are constructed based on sequence, so use the unique sequences present | ||
# in input chains to determine the MSAs that need to be loaded | ||
|
||
def get_msa_contexts_for_seq(seq) -> MSAContext: | ||
path = msa_directory / expected_basename(seq) | ||
if not path.is_file(): | ||
logger.warning(f"No MSA found for sequence: {seq}") | ||
[tokenized_seq] = tokenize_sequences_to_arrays([seq])[0] | ||
return MSAContext.create_single_seq( | ||
MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq) | ||
) | ||
msa = parse_aligned_pqt_to_msa_context(path) | ||
logger.info(f"MSA found for sequence: {seq}, {msa.depth=}") | ||
return msa | ||
|
||
# For each chain, either fetch the corresponding MSA or create an empty MSA if it is missing | ||
# + reindex to handle residues that are tokenized per-atom (this also crops if necessary) | ||
msa_contexts = [ | ||
get_msa_contexts_for_seq(chain.entity_data.sequence)[ | ||
:, chain.structure_context.token_residue_index | ||
] | ||
for chain in chains | ||
] | ||
|
||
# used later only for profile statistics | ||
profile_msa = merge_main_msas_by_chain( | ||
[drop_duplicates(msa) for msa in msa_contexts] | ||
) | ||
|
||
joined_msa = pair_and_merge_msas(msa_contexts) | ||
joined_msa = drop_duplicates(joined_msa) # rare dups after pairings | ||
|
||
logger.info(f"Prepared MSA context with {joined_msa.depth=}") | ||
return joined_msa, profile_msa |
Oops, something went wrong.