diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index a4d5434..8c5a4f6 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -39,19 +39,20 @@ def predict_dna_sequence( model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", deterministic: bool = True, + temperature: float = 0.2, ) -> DNASequencePrediction: """ - Predict the DNA sequence for a given protein using CodonTransformer model. + Predict the DNA sequence for a given protein using the CodonTransformer model. This function takes a protein sequence and an organism (as ID or name) as input - and returns the predicted DNA sequence by CodonTransformer. It can use either - provided tokenizer and model objects or load them from specified paths. + and returns the predicted DNA sequence using the CodonTransformer model. It can use + either provided tokenizer and model objects or load them from specified paths. Args: - protein (str): The input protein sequence to predict the DNA sequence for. + protein (str): The input protein sequence for which to predict the DNA sequence. organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). If a string is provided, it will be converted - to the corresponding ID using ORGANISM2ID. + to the corresponding ID using `ORGANISM2ID`. device (torch.device): The device (CPU or GPU) to run the model on. tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, a pre-loaded tokenizer object, or None. If @@ -63,8 +64,18 @@ def predict_dna_sequence( model. Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". deterministic (bool, optional): Whether to use deterministic decoding (most - likely tokens). If False, samples tokens according to their probabilities. - Defaults to True. + likely tokens). If False, samples tokens according to their probabilities + adjusted by the temperature. Defaults to True. + temperature (float, optional): A value controlling the randomness of predictions + during non-deterministic decoding. Lower values (e.g., 0.2) make the model + more conservative, while higher values (e.g., 0.8) increase randomness. + Using high temperatures may result in prediction of DNA sequences that + do not translate to the input protein. + Recommended values are: + - Low randomness: 0.2 + - Medium randomness: 0.5 + - High randomness: 0.8 + The temperature must be a positive float. Defaults to 0.2. Returns: DNASequencePrediction: An object containing the prediction results: @@ -74,12 +85,13 @@ def predict_dna_sequence( - predicted_dna (str): Predicted DNA sequence. Raises: - ValueError: If the protein sequence is empty or if the organism is invalid. + ValueError: If the protein sequence is empty, if the organism is invalid, + or if the temperature is not a positive float. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from - CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their - corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to + This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from + `CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their + corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to respective codons. Example: @@ -111,27 +123,32 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with probabilistic sampling - >>> output_sampler = predict_dna_sequence( + >>> # Predict DNA sequence with low randomness + >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, ... attention_type="original_full", - ... deterministic=False + ... deterministic=False, + ... temperature=0.2 ... ) >>> >>> print(format_model_output(output)) - >>> print(format_model_output(output_sampler)) + >>> print(format_model_output(output_random)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") - # Test that the input protein sequence contains only valid amino acids + # Ensure the protein sequence contains only valid amino acids if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") + # Validate temperature + if not isinstance(temperature, (float, int)) or temperature <= 0: + raise ValueError("Temperature must be a positive float.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -168,8 +185,10 @@ def predict_dna_sequence( predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: # Sample tokens according to their probability distribution - # Convert logits to probabilities using softmax + # Apply temperature scaling and convert logits to probabilities + logits = logits / temperature probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] predicted_indices = ( diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 3617c2b..28cb49e 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -44,7 +44,8 @@ def test_predict_dna_sequence_valid_input(self): def test_predict_dna_sequence_non_deterministic(self): protein_sequence = "MFWY" organism = "Escherichia coli general" - num_iterations = 64 + num_iterations = 100 + temperatures = [0.2, 0.5, 0.8] possible_outputs = set() possible_encodings_wo_stop = { "ATGTTTTGGTAT", @@ -54,15 +55,17 @@ def test_predict_dna_sequence_non_deterministic(self): } for _ in range(num_iterations): - result = predict_dna_sequence( - protein=protein_sequence, - organism=organism, - device=self.device, - tokenizer=self.tokenizer, - model=self.model, - deterministic=False, - ) - possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon + for temperature in temperatures: + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + ) + possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon self.assertEqual(possible_outputs, possible_encodings_wo_stop)