Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for variable randomness in DNA prediction. #10

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ def predict_dna_sequence(
model: Union[str, torch.nn.Module] = None,
attention_type: str = "original_full",
deterministic: bool = True,
temperature: float = 0.2,
) -> DNASequencePrediction:
"""
Predict the DNA sequence for a given protein using CodonTransformer model.
Predict the DNA sequence for a given protein using the CodonTransformer model.

This function takes a protein sequence and an organism (as ID or name) as input
and returns the predicted DNA sequence by CodonTransformer. It can use either
provided tokenizer and model objects or load them from specified paths.
and returns the predicted DNA sequence using the CodonTransformer model. It can use
either provided tokenizer and model objects or load them from specified paths.

Args:
protein (str): The input protein sequence to predict the DNA sequence for.
protein (str): The input protein sequence for which to predict the DNA sequence.
organism (Union[int, str]): Either the ID of the organism or its name (e.g.,
"Escherichia coli general"). If a string is provided, it will be converted
to the corresponding ID using ORGANISM2ID.
to the corresponding ID using `ORGANISM2ID`.
device (torch.device): The device (CPU or GPU) to run the model on.
tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file
path to load the tokenizer from, a pre-loaded tokenizer object, or None. If
Expand All @@ -63,8 +64,18 @@ def predict_dna_sequence(
model. Can be either 'block_sparse' or 'original_full'. Defaults to
"original_full".
deterministic (bool, optional): Whether to use deterministic decoding (most
likely tokens). If False, samples tokens according to their probabilities.
Defaults to True.
likely tokens). If False, samples tokens according to their probabilities
adjusted by the temperature. Defaults to True.
temperature (float, optional): A value controlling the randomness of predictions
during non-deterministic decoding. Lower values (e.g., 0.2) make the model
more conservative, while higher values (e.g., 0.8) increase randomness.
Using high temperatures may result in prediction of DNA sequences that
do not translate to the input protein.
Recommended values are:
- Low randomness: 0.2
- Medium randomness: 0.5
- High randomness: 0.8
The temperature must be a positive float. Defaults to 0.2.

Returns:
DNASequencePrediction: An object containing the prediction results:
Expand All @@ -74,12 +85,13 @@ def predict_dna_sequence(
- predicted_dna (str): Predicted DNA sequence.

Raises:
ValueError: If the protein sequence is empty or if the organism is invalid.
ValueError: If the protein sequence is empty, if the organism is invalid,
or if the temperature is not a positive float.

Note:
This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from
CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their
corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to
This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from
`CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their
corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to
respective codons.

Example:
Expand Down Expand Up @@ -111,27 +123,32 @@ def predict_dna_sequence(
... deterministic=True
... )
>>>
>>> # Predict DNA sequence with probabilistic sampling
>>> output_sampler = predict_dna_sequence(
>>> # Predict DNA sequence with low randomness
>>> output_random = predict_dna_sequence(
... protein=protein,
... organism=organism,
... device=device,
... tokenizer=tokenizer,
... model=model,
... attention_type="original_full",
... deterministic=False
... deterministic=False,
... temperature=0.2
... )
>>>
>>> print(format_model_output(output))
>>> print(format_model_output(output_sampler))
>>> print(format_model_output(output_random))
"""
if not protein:
raise ValueError("Protein sequence cannot be empty.")

# Test that the input protein sequence contains only valid amino acids
# Ensure the protein sequence contains only valid amino acids
if not all(aminoacid in AMINO_ACIDS for aminoacid in protein):
raise ValueError("Invalid amino acid found in protein sequence.")

# Validate temperature
if not isinstance(temperature, (float, int)) or temperature <= 0:
raise ValueError("Temperature must be a positive float.")

# Load tokenizer
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = load_tokenizer(tokenizer)
Expand Down Expand Up @@ -168,8 +185,10 @@ def predict_dna_sequence(
predicted_indices = logits.argmax(dim=-1).squeeze().tolist()
else:
# Sample tokens according to their probability distribution
# Convert logits to probabilities using softmax
# Apply temperature scaling and convert logits to probabilities
logits = logits / temperature
probabilities = torch.softmax(logits, dim=-1)

# Sample from the probability distribution at each position
probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size]
predicted_indices = (
Expand Down
Loading