diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 16a6136..e82b535 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,8 +1,16 @@ import unittest import warnings +import random import torch +from CodonTransformer.CodonData import get_amino_acid_sequence +from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + INDEX2TOKEN, + STOP_SYMBOLS, + ORGANISM2ID, +) from CodonTransformer.CodonPrediction import ( load_model, load_tokenizer, @@ -148,8 +156,6 @@ def test_predict_dna_sequence_invalid_temperature_and_top_p(self): def test_predict_dna_sequence_translation_consistency(self): """Test that the predicted DNA translates back to the original protein.""" - from CodonTransformer.CodonData import get_amino_acid_sequence - protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE" organism = "Escherichia coli general" result = predict_dna_sequence( @@ -170,6 +176,262 @@ def test_predict_dna_sequence_translation_consistency(self): "Translated protein does not match the original protein sequence", ) + def test_predict_dna_sequence_long_protein_sequence(self): + """Test the function with a very long protein sequence to check performance and correctness.""" + protein_sequence = ( + "M" + + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + * 20 + + STOP_SYMBOLS[0] + ) + organism = "Escherichia coli general" + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Check that the predicted DNA translates back to the original protein + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], + "Translated protein does not match the original long protein sequence", + ) + + def test_predict_dna_sequence_edge_case_organisms(self): + """Test the function with organism IDs at the boundaries of the mapping.""" + protein_sequence = "MWWMW" + # Assuming ORGANISM2ID has IDs starting from 0 to N + min_organism_id = min(ORGANISM2ID.values()) + max_organism_id = max(ORGANISM2ID.values()) + organisms = [min_organism_id, max_organism_id] + + for organism_id in organisms: + with self.subTest(organism_id=organism_id): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism_id, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_concurrent_calls(self): + """Test the function's behavior under concurrent execution.""" + import threading + + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + results = [] + + def call_predict(): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + results.append(result.predicted_dna) + + threads = [threading.Thread(target=call_predict) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(len(results), 10) + self.assertTrue(all(dna == results[0] for dna in results)) + + def test_predict_dna_sequence_random_seed_consistency(self): + """Test that setting a random seed results in consistent outputs in non-deterministic mode.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + temperature = 0.5 + top_p = 0.95 + torch.manual_seed(42) + + result1 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + torch.manual_seed(42) + + result2 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + self.assertEqual( + result1.predicted_dna, + result2.predicted_dna, + "Outputs should be consistent when random seed is set", + ) + + def test_predict_dna_sequence_invalid_tokenizer_and_model(self): + """Test that providing invalid tokenizer or model raises appropriate exceptions.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + with self.subTest("Invalid tokenizer"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer="invalid_tokenizer_path", + model=self.model, + ) + + with self.subTest("Invalid model"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model="invalid_model_path", + ) + + def test_predict_dna_sequence_stop_codon_handling(self): + """Test the function's handling of protein sequences ending with a non '_' or '*' stop symbol.""" + protein_sequence = "MWW/" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_ambiguous_amino_acids(self): + """Test the function's response to ambiguous or non-standard amino acids.""" + protein_sequence = "MWWBXZ" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_device_compatibility(self): + """Test that the function works correctly on both CPU and GPU devices.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + + for device in devices: + with self.subTest(device=device): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_random_proteins(self): + """Test random proteins to ensure translated DNA matches the original protein.""" + organism = "Escherichia coli general" + num_tests = 200 + + for _ in range(num_tests): + # Generate a random protein sequence of random length between 10 and 50 + protein_length = random.randint(10, 500) + protein_sequence = "M" + "".join( + random.choices(AMINO_ACIDS, k=protein_length - 1) + ) + protein_sequence += random.choice(STOP_SYMBOLS) + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + + # Translate predicted DNA back to protein + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], # Remove stop symbol + f"Translated protein does not match the original protein sequence for protein: {protein_sequence}", + ) + + def test_predict_dna_sequence_long_protein_over_max_length(self): + """Test that the model handles protein sequences longer than 2048 amino acids.""" + # Create a protein sequence longer than 2048 amino acids + base_sequence = ( + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + ) + protein_sequence = base_sequence * 100 # Length > 2048 amino acids + organism = "Escherichia coli general" + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + + # Due to potential model limitations, compare up to the model's max supported length + max_length = len(translated_protein) + self.assertEqual( + translated_protein[:max_length], + protein_sequence[:max_length], + "Translated protein does not match the original protein sequence up to the maximum length supported.", + ) + if __name__ == "__main__": unittest.main()