Skip to content

Commit

Permalink
Rename main msa context -> profile msa context; do not check its depth
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Oct 20, 2024
1 parent f1392e8 commit 9514b65
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
4 changes: 2 additions & 2 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def run_inference(
chains=chains,
structure_context=merged_context,
msa_context=msa_context,
main_msa_context=msa_profile_context,
profile_msa_context=msa_profile_context,
template_context=template_context,
embedding_context=embedding_context,
constraint_context=constraint_context,
Expand Down Expand Up @@ -366,7 +366,7 @@ def run_folding_on_context(
raise_if_too_many_tokens(n_actual_tokens)
raise_if_too_many_templates(feature_context.template_context.num_templates)
raise_if_msa_too_deep(feature_context.msa_context.depth)
raise_if_msa_too_deep(feature_context.main_msa_context.depth)
# NOTE: the profile MSA is used only for statistics, so its depth is not checked

##
## Prepare batch
Expand Down
12 changes: 6 additions & 6 deletions chai_lab/data/dataset/all_atom_feature_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AllAtomFeatureContext:
# Contexts: these are what get padded and batched
structure_context: AllAtomStructureContext
msa_context: MSAContext
main_msa_context: MSAContext
profile_msa_context: MSAContext
template_context: TemplateContext
embedding_context: EmbeddingContext | None
constraint_context: ConstraintContext
Expand All @@ -59,9 +59,9 @@ def pad(
max_num_tokens=n_tokens,
max_msa_depth=MAX_MSA_DEPTH,
),
main_msa_context=self.main_msa_context.pad(
profile_msa_context=self.profile_msa_context.pad(
max_num_tokens=n_tokens,
max_msa_depth=MAX_MSA_DEPTH,
# max_msa_depth=MAX_MSA_DEPTH,  # intentionally not passed: the profile MSA's depth is not checked
),
template_context=self.template_context.pad(
max_tokens=n_tokens,
Expand All @@ -82,9 +82,9 @@ def to_dict(self) -> dict[str, Any]:
msa_deletion_matrix=self.msa_context.deletion_matrix,
msa_pairkey=self.msa_context.pairing_key_hash,
msa_sequence_source=self.msa_context.sequence_source,
main_msa_tokens=self.main_msa_context.tokens,
main_msa_mask=self.main_msa_context.mask,
main_msa_deletion_matrix=self.main_msa_context.deletion_matrix,
main_msa_tokens=self.profile_msa_context.tokens,
main_msa_mask=self.profile_msa_context.mask,
main_msa_deletion_matrix=self.profile_msa_context.deletion_matrix,
)
return {
**self.structure_context.to_dict(),
Expand Down
9 changes: 6 additions & 3 deletions chai_lab/data/dataset/msas/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# This source code is licensed under the Chai Discovery Community License
# Agreement (LICENSE.md) found in the root directory of this source tree.
import logging
from collections import Counter
from typing import Iterable

Expand All @@ -10,6 +11,8 @@
MAX_PAIRED_DEPTH = 8_192
FULL_DEPTH = 16_384

logger = logging.getLogger(__name__)


def merge_main_msas_by_chain(msas: list[MSAContext]) -> MSAContext:
"""Merges MSAs across chains, concatenating over token dimension."""
Expand Down Expand Up @@ -89,9 +92,9 @@ def prepair_ukey(pairing_keys: Iterable[int]) -> dict[tuple[int, int], int]:

selected_msa = msa.take_rows_with_padding(all_rowids)

print(
f"Loaded (paired in includes query sequence):"
f"\n{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} "
logger.info(
f"Loaded (paired in includes query sequence): "
f"{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} "
)

# reorder each msa to have paired elements first; that's same # of rows
Expand Down

0 comments on commit 9514b65

Please sign in to comment.