Skip to content

Commit

Permalink
Add extensive testing for predict_dna_sequence.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Sep 20, 2024
1 parent e6c3ed5 commit bbfe03a
Showing 1 changed file with 264 additions and 2 deletions.
266 changes: 264 additions & 2 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import unittest
import warnings
import random

import torch

from CodonTransformer.CodonData import get_amino_acid_sequence
from CodonTransformer.CodonUtils import (
AMINO_ACIDS,
INDEX2TOKEN,
STOP_SYMBOLS,
ORGANISM2ID,
)
from CodonTransformer.CodonPrediction import (
load_model,
load_tokenizer,
Expand Down Expand Up @@ -148,8 +156,6 @@ def test_predict_dna_sequence_invalid_temperature_and_top_p(self):

def test_predict_dna_sequence_translation_consistency(self):
"""Test that the predicted DNA translates back to the original protein."""
from CodonTransformer.CodonData import get_amino_acid_sequence

protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE"
organism = "Escherichia coli general"
result = predict_dna_sequence(
Expand All @@ -170,6 +176,262 @@ def test_predict_dna_sequence_translation_consistency(self):
"Translated protein does not match the original protein sequence",
)

def test_predict_dna_sequence_long_protein_sequence(self):
    """Round-trip a long protein through prediction and back-translation.

    The input is a leading "M", twenty copies of a fixed motif, and the
    first entry of STOP_SYMBOLS. Prediction runs in deterministic mode; the
    predicted DNA without its last three nucleotides must translate to the
    input without its final character.
    """
    repeated_motif = (
        "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
        * 20
    )
    protein_sequence = "M" + repeated_motif + STOP_SYMBOLS[0]

    prediction = predict_dna_sequence(
        protein=protein_sequence,
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=True,
    )

    # Drop the trailing codon from the DNA and the trailing stop symbol from
    # the protein so both sides of the comparison cover the same residues.
    coding_dna = prediction.predicted_dna[:-3]
    expected_protein = protein_sequence[:-1]
    self.assertEqual(
        get_amino_acid_sequence(coding_dna),
        expected_protein,
        "Translated protein does not match the original long protein sequence",
    )

def test_predict_dna_sequence_edge_case_organisms(self):
    """Predict with the smallest and the largest ID found in ORGANISM2ID.

    Each boundary ID is passed directly as ``organism``. For every one the
    predicted DNA must be a ``str`` made up only of A, T, C and G.
    """
    known_ids = ORGANISM2ID.values()
    boundary_ids = [min(known_ids), max(known_ids)]

    for organism_id in boundary_ids:
        with self.subTest(organism_id=organism_id):
            prediction = predict_dna_sequence(
                protein="MWWMW",
                organism=organism_id,
                device=self.device,
                tokenizer=self.tokenizer,
                model=self.model,
                deterministic=True,
            )
            self.assertIsInstance(prediction.predicted_dna, str)
            # Every character of the prediction must be a canonical nucleotide.
            self.assertTrue(set(prediction.predicted_dna).issubset("ATCG"))

def test_predict_dna_sequence_concurrent_calls(self):
    """Run ten deterministic predictions in parallel threads and compare them.

    Every worker calls ``predict_dna_sequence`` with the same protein,
    organism, tokenizer and model. The test requires that no worker raised,
    that all ten produced a result, and that all results are identical.

    Exceptions raised inside a ``threading.Thread`` target are not propagated
    to the thread that calls ``join()``, so each worker records its own
    exception and the test asserts on the collected list. Without this a
    failing worker would only show up as an unexplained result-count mismatch.
    """
    import threading

    protein_sequence = "MWWMW"
    organism = "Escherichia coli general"
    num_threads = 10
    results = []
    errors = []

    def call_predict():
        try:
            result = predict_dna_sequence(
                protein=protein_sequence,
                organism=organism,
                device=self.device,
                tokenizer=self.tokenizer,
                model=self.model,
                deterministic=True,
            )
            results.append(result.predicted_dna)
        except Exception as exc:
            # Not swallowed: reported through the assertion on `errors` below.
            errors.append(exc)

    threads = [threading.Thread(target=call_predict) for _ in range(num_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    self.assertEqual(errors, [], f"Worker threads raised: {errors!r}")
    self.assertEqual(len(results), num_threads)
    self.assertTrue(all(dna == results[0] for dna in results))

def test_predict_dna_sequence_random_seed_consistency(self):
    """Check that re-seeding torch reproduces a sampled prediction.

    Two non-deterministic predictions (temperature 0.5, top_p 0.95) are made
    for the same protein, each immediately after ``torch.manual_seed(42)``.
    Their predicted DNA strings must be equal.
    """

    def seeded_prediction():
        # Reset the torch RNG right before sampling so both runs start from
        # the same generator state.
        torch.manual_seed(42)
        return predict_dna_sequence(
            protein="MFWY",
            organism="Escherichia coli general",
            device=self.device,
            tokenizer=self.tokenizer,
            model=self.model,
            deterministic=False,
            temperature=0.5,
            top_p=0.95,
        )

    first_run = seeded_prediction()
    second_run = seeded_prediction()

    self.assertEqual(
        first_run.predicted_dna,
        second_run.predicted_dna,
        "Outputs should be consistent when random seed is set",
    )

def test_predict_dna_sequence_invalid_tokenizer_and_model(self):
    """Check that a bogus tokenizer or model argument makes prediction raise.

    Two subtests run: one replaces the tokenizer with the string
    "invalid_tokenizer_path", the other replaces the model with the string
    "invalid_model_path". Each only requires that some ``Exception`` is
    raised; the concrete exception type is not pinned down here.
    """
    # (subtest label, tokenizer argument, model argument)
    cases = (
        ("Invalid tokenizer", "invalid_tokenizer_path", self.model),
        ("Invalid model", self.tokenizer, "invalid_model_path"),
    )

    for label, tokenizer, model in cases:
        with self.subTest(label), self.assertRaises(Exception):
            predict_dna_sequence(
                protein="MWWMW",
                organism="Escherichia coli general",
                device=self.device,
                tokenizer=tokenizer,
                model=model,
            )

def test_predict_dna_sequence_stop_codon_handling(self):
    """Check that a protein ending in "/" is rejected with ValueError.

    "/" stands in for a terminator that is neither "_" nor "*".
    """
    # Callable form of assertRaises: the keyword arguments are forwarded to
    # predict_dna_sequence unchanged.
    self.assertRaises(
        ValueError,
        predict_dna_sequence,
        protein="MWW/",
        organism="Escherichia coli general",
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
    )

def test_predict_dna_sequence_ambiguous_amino_acids(self):
    """Check that a protein containing B, X and Z is rejected with ValueError.

    The input "MWWBXZ" mixes ordinary residues with ambiguous or
    non-standard amino-acid letters.
    """
    call_kwargs = {
        "protein": "MWWBXZ",
        "organism": "Escherichia coli general",
        "device": self.device,
        "tokenizer": self.tokenizer,
        "model": self.model,
    }

    with self.assertRaises(ValueError):
        predict_dna_sequence(**call_kwargs)

def test_predict_dna_sequence_device_compatibility(self):
    """Run a deterministic prediction on the CPU and, if present, on CUDA.

    For each device the predicted DNA must be a ``str`` consisting only of
    A, T, C and G.

    NOTE(review): the same ``self.model`` instance is passed for every
    device; confirm that predict_dna_sequence moves it to ``device`` itself.
    """
    cuda_devices = [torch.device("cuda")] if torch.cuda.is_available() else []
    devices = [torch.device("cpu"), *cuda_devices]

    for device in devices:
        with self.subTest(device=device):
            prediction = predict_dna_sequence(
                protein="MWWMW",
                organism="Escherichia coli general",
                device=device,
                tokenizer=self.tokenizer,
                model=self.model,
                deterministic=True,
            )
            self.assertIsInstance(prediction.predicted_dna, str)
            # Every character of the prediction must be a canonical nucleotide.
            self.assertTrue(set(prediction.predicted_dna).issubset("ATCG"))

def test_predict_dna_sequence_random_proteins(self):
    """Round-trip 200 randomly generated proteins through prediction.

    Each protein starts with "M", continues with residues drawn from
    AMINO_ACIDS, and ends with a randomly chosen entry of STOP_SYMBOLS.
    The deterministic prediction, minus its last three nucleotides, must
    translate back to the protein minus its stop symbol.
    """
    for _ in range(200):
        # Length excluding the stop symbol, drawn from 10..500 inclusive.
        protein_length = random.randint(10, 500)
        residues = random.choices(AMINO_ACIDS, k=protein_length - 1)
        stop_symbol = random.choice(STOP_SYMBOLS)
        protein_sequence = "".join(["M", *residues, stop_symbol])

        prediction = predict_dna_sequence(
            protein=protein_sequence,
            organism="Escherichia coli general",
            device=self.device,
            tokenizer=self.tokenizer,
            model=self.model,
            deterministic=True,
        )

        # Strip the final codon from the DNA and the stop symbol from the
        # protein before comparing.
        coding_dna = prediction.predicted_dna[:-3]
        expected_protein = protein_sequence[:-1]
        self.assertEqual(
            get_amino_acid_sequence(coding_dna),
            expected_protein,
            f"Translated protein does not match the original protein sequence for protein: {protein_sequence}",
        )

def test_predict_dna_sequence_long_protein_over_max_length(self):
    """Check prediction on a protein longer than 2048 amino acids.

    The input is 100 copies of a fixed motif. The predicted DNA, minus its
    last three nucleotides, is translated back, and the translation must be
    a non-empty prefix of the input. A prefix (rather than full equality) is
    accepted because the output may be shorter than the input.

    The non-empty check matters: the comparison length is taken from the
    translation itself, so an empty translation would otherwise compare
    "" with "" and pass without verifying anything.
    """
    base_sequence = (
        "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG"
    )
    protein_sequence = base_sequence * 100
    # Enforce the precondition this test is named after.
    self.assertGreater(len(protein_sequence), 2048)
    organism = "Escherichia coli general"

    result = predict_dna_sequence(
        protein=protein_sequence,
        organism=organism,
        device=self.device,
        tokenizer=self.tokenizer,
        model=self.model,
        deterministic=True,
    )

    # Remove the trailing codon from the predicted DNA.
    # NOTE(review): the input carries no stop symbol; confirm the prediction
    # still ends with a stop codon. If it does not, this slice drops the last
    # translated residue, which the prefix comparison below tolerates.
    dna_sequence = result.predicted_dna[:-3]
    translated_protein = get_amino_acid_sequence(dna_sequence)

    self.assertTrue(
        translated_protein,
        "Prediction for an over-length protein translated to an empty sequence.",
    )

    # Compare only as many residues as were actually produced.
    max_length = len(translated_protein)
    self.assertEqual(
        translated_protein,
        protein_sequence[:max_length],
        "Translated protein does not match the original protein sequence up to the maximum length supported.",
    )


# Run this module's tests through unittest's command-line entry point when
# the file is executed directly rather than imported.
if __name__ == "__main__":
unittest.main()

0 comments on commit bbfe03a

Please sign in to comment.