diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index a4d5434..74922e7 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -39,19 +39,20 @@ def predict_dna_sequence( model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", deterministic: bool = True, + temperature: float = 0.2, ) -> DNASequencePrediction: """ - Predict the DNA sequence for a given protein using CodonTransformer model. + Predict the DNA sequence for a given protein using the CodonTransformer model. This function takes a protein sequence and an organism (as ID or name) as input - and returns the predicted DNA sequence by CodonTransformer. It can use either - provided tokenizer and model objects or load them from specified paths. + and returns the predicted DNA sequence using the CodonTransformer model. It can use + either provided tokenizer and model objects or load them from specified paths. Args: - protein (str): The input protein sequence to predict the DNA sequence for. + protein (str): The input protein sequence for which to predict the DNA sequence. organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). If a string is provided, it will be converted - to the corresponding ID using ORGANISM2ID. + to the corresponding ID using `ORGANISM2ID`. device (torch.device): The device (CPU or GPU) to run the model on. tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, a pre-loaded tokenizer object, or None. If @@ -63,8 +64,18 @@ def predict_dna_sequence( model. Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". deterministic (bool, optional): Whether to use deterministic decoding (most - likely tokens). If False, samples tokens according to their probabilities. - Defaults to True. + likely tokens). If False, samples tokens according to their probabilities + adjusted by the temperature. Defaults to True. + temperature (float, optional): A value controlling the randomness of predictions + during non-deterministic decoding. Lower values (e.g., 0.2) make the model + more conservative, while higher values (e.g., 0.8) increase randomness. + Using high temperatures may result in prediction of DNA sequences that + do not translate to the input protein. + Recommended values are: + - Low randomness: 0.2 + - Medium randomness: 0.5 + - High randomness: 0.8 + The temperature must be a positive float. Defaults to 0.2. Returns: DNASequencePrediction: An object containing the prediction results: @@ -74,12 +85,13 @@ def predict_dna_sequence( - predicted_dna (str): Predicted DNA sequence. Raises: - ValueError: If the protein sequence is empty or if the organism is invalid. + ValueError: If the protein sequence is empty, if the organism is invalid, + or if the temperature is not a positive float. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from - CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their - corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to + This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from + `CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their + corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to respective codons. Example: @@ -111,27 +123,32 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with probabilistic sampling - >>> output_sampler = predict_dna_sequence( + >>> # Predict DNA sequence with medium randomness + >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, ... attention_type="original_full", - ... deterministic=False + ... deterministic=False, + ... temperature=1.0 ... ) >>> >>> print(format_model_output(output)) - >>> print(format_model_output(output_sampler)) + >>> print(format_model_output(output_random)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") - # Test that the input protein sequence contains only valid amino acids + # Ensure the protein sequence contains only valid amino acids if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") + # Validate temperature + if not isinstance(temperature, (float, int)) or temperature <= 0: + raise ValueError("Temperature must be a positive float.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -168,8 +185,10 @@ def predict_dna_sequence( predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: # Sample tokens according to their probability distribution - # Convert logits to probabilities using softmax + # Apply temperature scaling and convert logits to probabilities + logits = logits / temperature probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] predicted_indices = (