Skip to content

Commit

Permalink
Add support for variable randomness in DNA prediction.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Sep 20, 2024
1 parent c5d396b commit 59e4c5c
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ def predict_dna_sequence(
model: Union[str, torch.nn.Module] = None,
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
) -> DNASequencePrediction:
"""
Predict the DNA sequence for a given protein using CodonTransformer model.
Predict the DNA sequence for a given protein using the CodonTransformer model.
This function takes a protein sequence and an organism (as ID or name) as input
and returns the predicted DNA sequence by CodonTransformer. It can use either
provided tokenizer and model objects or load them from specified paths.
and returns the predicted DNA sequence using the CodonTransformer model. It can use
either provided tokenizer and model objects or load them from specified paths.
Args:
protein (str): The input protein sequence to predict the DNA sequence for.
protein (str): The input protein sequence for which to predict the DNA sequence.
organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
"Escherichia coli general"). If a string is provided, it will be converted
to the corresponding ID using ORGANISM2ID.
to the corresponding ID using `ORGANISM2ID`.
device (torch.device): The device (CPU or GPU) to run the model on.
tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
Expand All @@ -63,8 +64,18 @@ def predict_dna_sequence(
model. Can be either 'block_sparse' or 'original_full'. Defaults to
"original_full".
deterministic (bool, optional): Whether to use deterministic decoding (most
likely tokens). If False, samples tokens according to their probabilities.
Defaults to True.
likely tokens). If False, samples tokens according to their probabilities
adjusted by the temperature. Defaults to True.
temperature (float, optional): A value controlling the randomness of predictions
during non-deterministic decoding. Lower values (e.g., 0.2) make the model
more conservative, while higher values (e.g., 0.8) increase randomness.
Using high temperatures may result in prediction of DNA sequences that
do not translate to the input protein.
Recommended values are:
- Low randomness: 0.2
- Medium randomness: 0.5
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.
Returns:
DNASequencePrediction: An object containing the prediction results:
Expand All @@ -74,12 +85,13 @@ def predict_dna_sequence(
- predicted_dna (str): Predicted DNA sequence.
Raises:
ValueError: If the protein sequence is empty or if the organism is invalid.
ValueError: If the protein sequence is empty, if the organism is invalid,
or if the temperature is not a positive float.
Note:
This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from
CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to
This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from
`CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their
corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to
respective codons.
Example:
Expand Down Expand Up @@ -111,27 +123,32 @@ def predict_dna_sequence(
... deterministic=True
... )
>>>
>>> # Predict DNA sequence with probabilistic sampling
>>> output_sampler = predict_dna_sequence(
>>> # Predict DNA sequence with medium randomness
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
... device=device,
... tokenizer=tokenizer,
... model=model,
... attention_type="original_full",
... deterministic=False
... deterministic=False,
... temperature=1.0
... )
>>>
>>> print(format_model_output(output))
>>> print(format_model_output(output_sampler))
>>> print(format_model_output(output_random))
"""
if not protein:
raise ValueError("Protein sequence cannot be empty.")

# Test that the input protein sequence contains only valid amino acids
# Ensure the protein sequence contains only valid amino acids
if not all(aminoacid in AMINO_ACIDS for aminoacid in protein):
raise ValueError("Invalid amino acid found in protein sequence.")

# Validate temperature
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down Expand Up @@ -168,8 +185,10 @@ def predict_dna_sequence(
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
# Sample tokens according to their probability distribution
# Convert logits to probabilities using softmax
# Apply temperature scaling and convert logits to probabilities
logits = logits / temperature
probabilities = torch.softmax(logits, dim=-1)

# Sample from the probability distribution at each position
probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size]
predicted_indices = (
Expand Down

0 comments on commit 59e4c5c

Please sign in to comment.