Skip to content

Commit

Permalink
Merge pull request #15 from Adibvafa/dev
Browse files Browse the repository at this point in the history
Add support for protein sequence matching
  • Loading branch information
Adibvafa authored Oct 29, 2024
2 parents 1356ebb + 89958df commit a315109
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 54 deletions.
22 changes: 13 additions & 9 deletions CodonTransformer/CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,30 @@ def preprocess_protein_sequence(protein: str) -> str:

# Handle ambiguous amino acids based on the specified behavior
config = ProteinConfig()
ambiguous_aminoacid_map_override = config.get('ambiguous_aminoacid_map_override')
ambiguous_aminoacid_behavior = config.get('ambiguous_aminoacid_behavior')
ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()

for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
ambiguous_aminoacid_map[aminoacid] = standard_aminoacids

if ambiguous_aminoacid_behavior == 'raise_error':
if ambiguous_aminoacid_behavior == "raise_error":
if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
raise ValueError("Ambiguous amino acids found in protein sequence.")
elif ambiguous_aminoacid_behavior == 'standardize_deterministic':
elif ambiguous_aminoacid_behavior == "standardize_deterministic":
protein = "".join(
ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0] for aminoacid in protein
ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
for aminoacid in protein
)
elif ambiguous_aminoacid_behavior == 'standardize_random':
elif ambiguous_aminoacid_behavior == "standardize_random":
protein = "".join(
random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid])) for aminoacid in protein
random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
for aminoacid in protein
)
else:
raise ValueError(f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}.")
raise ValueError(
f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
)

# Check for sequence validity
if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
Expand Down
20 changes: 19 additions & 1 deletion CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACID_TO_INDEX,
INDEX2TOKEN,
NUM_ORGANISMS,
ORGANISM2ID,
Expand All @@ -41,6 +42,7 @@ def predict_dna_sequence(
temperature: float = 0.2,
top_p: float = 0.95,
num_sequences: int = 1,
match_protein: bool = False,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
"""
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
Expand Down Expand Up @@ -83,6 +85,9 @@ def predict_dna_sequence(
The value must be a float between 0 and 1. Defaults to 0.95.
num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
when deterministic is False. Defaults to 1.
match_protein (bool, optional): Ensures the predicted DNA sequence is translated
to the input protein sequence by sampling from only the respective codons of
given amino acids. Defaults to False.
Returns:
Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
Expand Down Expand Up @@ -198,6 +203,19 @@ def predict_dna_sequence(
# Get the model predictions
output_dict = model(**tokenized_input, return_dict=True)
logits = output_dict.logits.detach().cpu()
logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens

# Mask the logits of codons that do not correspond to the input protein sequence
if match_protein:
possible_tokens_per_position = [
AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
]
mask = torch.full_like(logits, float("-inf"))

for pos, possible_tokens in enumerate(possible_tokens_per_position):
mask[:, pos, possible_tokens] = 0

logits = mask + logits

predictions = []
for _ in range(num_sequences):
Expand All @@ -211,7 +229,7 @@ def predict_dna_sequence(

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
predicted_dna = (
"".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
"".join([token[-3:] for token in predicted_dna]).strip().upper()
)

predictions.append(
Expand Down
52 changes: 36 additions & 16 deletions CodonTransformer/CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@
# Index-to-token mapping, reverse of TOKEN2INDEX
INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()}

# Dictionary mapping each amino acid and stop symbol to indices of codon tokens that translate to it
AMINO_ACID_TO_INDEX = {
aa: sorted(
[i for t, i in TOKEN2INDEX.items() if t[0].upper() == aa and t[-3:] != "unk"]
)
for aa in (AMINO_ACIDS + STOP_SYMBOLS)
}


# Mask token mapping
TOKEN2MASK: Dict[int, int] = {
0: 0,
Expand Down Expand Up @@ -550,14 +559,15 @@ class ConfigManager(ABC):
"""
Abstract base class for managing configuration settings.
"""

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
if exc_type is not None:
print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}")
self.reset_config()

@abstractmethod
def reset_config(self) -> None:
"""Reset the configuration to default values."""
Expand Down Expand Up @@ -601,7 +611,8 @@ def update(self, config_dict: dict) -> None:
def validate_inputs(self, key: str, value: Any) -> None:
"""Validate the inputs for the configuration."""
pass



class ProteinConfig(ConfigManager):
"""
A class to manage configuration settings for protein sequences.
Expand All @@ -613,6 +624,7 @@ class ProteinConfig(ConfigManager):
_instance (Optional[ConfigManager]): The singleton instance of the ConfigManager.
_config (Dict[str, Any]): The configuration dictionary.
"""

_instance = None

def __new__(cls):
Expand All @@ -626,36 +638,44 @@ def __new__(cls):
cls._instance = super(ProteinConfig, cls).__new__(cls)
cls._instance.reset_config()
return cls._instance

def validate_inputs(self, key: str, value: Any) -> None:
"""
Validate the inputs for the configuration.
Args:
key (str): The key to validate.
value (Any): The value to validate.
Raises:
ValueError: If the value is invalid.
TypeError: If the value is of the wrong type.
"""
if key == 'ambiguous_aminoacid_behavior':
if key == "ambiguous_aminoacid_behavior":
if value not in [
'raise_error',
'standardize_deterministic',
'standardize_random'
"raise_error",
"standardize_deterministic",
"standardize_random",
]:
raise ValueError(f"Invalid value for ambiguous_aminoacid_behavior: {value}.")
elif key == 'ambiguous_aminoacid_map_override':
raise ValueError(
f"Invalid value for ambiguous_aminoacid_behavior: {value}."
)
elif key == "ambiguous_aminoacid_map_override":
if not isinstance(value, dict):
raise TypeError(f"Invalid type for ambiguous_aminoacid_map_override: {value}.")
raise TypeError(
f"Invalid type for ambiguous_aminoacid_map_override: {value}."
)
for ambiguous_aminoacid, aminoacids in value.items():
if not isinstance(aminoacids, list):
raise TypeError(f"Invalid type for aminoacids: {aminoacids}.")
if not aminoacids:
raise ValueError(f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list.")
raise ValueError(
f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list."
)
if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP:
raise ValueError(f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}")
raise ValueError(
f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}"
)
else:
raise ValueError(f"Invalid configuration key: {key}")

Expand All @@ -664,8 +684,8 @@ def reset_config(self) -> None:
Reset the configuration to the default values.
"""
self._config = {
'ambiguous_aminoacid_behavior': 'standardize_random',
'ambiguous_aminoacid_map_override': {}
"ambiguous_aminoacid_behavior": "standardize_random",
"ambiguous_aminoacid_map_override": {},
}


Expand Down
3 changes: 2 additions & 1 deletion tests/test_CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
build_amino2codon_skeleton,
get_amino_acid_sequence,
is_correct_seq,
read_fasta_file,
preprocess_protein_sequence,
read_fasta_file,
)
from CodonTransformer.CodonUtils import ProteinConfig


class TestCodonData(unittest.TestCase):
def test_preprocess_protein_sequence(self):
with ProteinConfig() as config:
Expand Down
97 changes: 97 additions & 0 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,103 @@ def test_predict_dna_sequence_multi_diversity(self):
translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
self.assertEqual(translated_protein, protein_sequence)

def test_predict_dna_sequence_match_protein_repetitive(self):
"""Test that match_protein=True correctly handles highly repetitive and unconventional sequences."""
test_sequences = (
"QQQQQQQQQQQQQQQQ_",
"KRKRKRKRKRKRKRKR_",
"PGPGPGPGPGPGPGPG_",
"DEDEDEDEDEDEDEDEDE_",
"M_M_M_M_M_",
"MMMMMMMMMM_",
"WWWWWWWWWW_",
"CCCCCCCCCC_",
"MWCHMWCHMWCH_",
"Q_QQ_QQQ_QQQQ_",
"MWMWMWMWMWMW_",
"CCCHHHMMMWWW_",
"_",
"M_",
"MGWC_",
)

organism = "Homo sapiens"

for protein_sequence in test_sequences:
# Generate sequence with match_protein=True
result = predict_dna_sequence(
protein=protein_sequence,
organism=organism,
device=self.device,
tokenizer=self.tokenizer,
model=self.model,
deterministic=False,
temperature=20, # High temperature to test protein matching
match_protein=True,
)

dna_sequence = result.predicted_dna
translated_protein = get_amino_acid_sequence(dna_sequence)

self.assertEqual(
translated_protein,
protein_sequence,
f"Translated protein must match original when match_protein=True. Failed for sequence: {protein_sequence}",
)

def test_predict_dna_sequence_match_protein_rare_amino_acids(self):
"""Test match_protein with rare amino acids that have limited codon options."""
# Methionine (M) and Tryptophan (W) have only one codon each
# While Leucine (L) has 6 codons - testing contrast
protein_sequence = "MWLLLMWLLL"
organism = "Escherichia coli general"

# Run multiple predictions
results = []
num_iterations = 10

for _ in range(num_iterations):
result = predict_dna_sequence(
protein=protein_sequence,
organism=organism,
device=self.device,
tokenizer=self.tokenizer,
model=self.model,
deterministic=False,
temperature=20, # High temperature to test protein matching
match_protein=True,
)
results.append(result.predicted_dna)

# Check all sequences
for dna_sequence in results:
# Verify M always uses ATG
m_positions = [0, 5] # Known positions of M in sequence
for pos in m_positions:
self.assertEqual(
dna_sequence[pos * 3 : (pos + 1) * 3],
"ATG",
"Methionine must use ATG codon.",
)

# Verify W always uses TGG
w_positions = [1, 6] # Known positions of W in sequence
for pos in w_positions:
self.assertEqual(
dna_sequence[pos * 3 : (pos + 1) * 3],
"TGG",
"Tryptophan must use TGG codon.",
)

# Verify all L codons are valid
l_positions = [2, 3, 4, 7, 8, 9] # Known positions of L in sequence
l_codons = [dna_sequence[pos * 3 : (pos + 1) * 3] for pos in l_positions]
valid_l_codons = {"TTA", "TTG", "CTT", "CTC", "CTA", "CTG"}
self.assertTrue(
all(codon in valid_l_codons for codon in l_codons),
"All Leucine codons must be valid",
)


if __name__ == "__main__":
unittest.main()
40 changes: 13 additions & 27 deletions tests/test_CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,23 @@
class TestCodonUtils(unittest.TestCase):
def test_config_manager(self):
with ProteinConfig() as config:
config.set(
"ambiguous_aminoacid_behavior",
"standardize_deterministic"
)
config.set("ambiguous_aminoacid_behavior", "standardize_deterministic")
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"standardize_deterministic"
)
config.set(
"ambiguous_aminoacid_map_override",
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_behavior"), "standardize_deterministic"
)
config.set("ambiguous_aminoacid_map_override", {"X": ["A", "G"]})
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]}
)
config.update({
"ambiguous_aminoacid_behavior": "raise_error",
"ambiguous_aminoacid_map_override": {"X": ["A", "G"]},
})
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"raise_error"
config.update(
{
"ambiguous_aminoacid_behavior": "raise_error",
"ambiguous_aminoacid_map_override": {"X": ["A", "G"]},
}
)
self.assertEqual(config.get("ambiguous_aminoacid_behavior"), "raise_error")
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]}
)
try:
config.set("invalid_key", "invalid_value")
Expand All @@ -53,13 +43,9 @@ def test_config_manager(self):
pass
with ProteinConfig() as config:
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"standardize_random"
)
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{}
config.get("ambiguous_aminoacid_behavior"), "standardize_random"
)
self.assertEqual(config.get("ambiguous_aminoacid_map_override"), {})

def test_load_python_object_from_disk(self):
test_obj = {"key1": "value1", "key2": 2}
Expand Down

0 comments on commit a315109

Please sign in to comment.