From c0252995ddc382bf321dd5b39d234ae48e7f869c Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Fri, 20 Sep 2024 15:37:17 +0200 Subject: [PATCH 01/36] Fix issue #5 and add ConfigManager --- CodonTransformer/CodonData.py | 30 ++++++-- CodonTransformer/CodonPrediction.py | 5 -- CodonTransformer/CodonUtils.py | 110 ++++++++++++++++++++++++++-- 3 files changed, 128 insertions(+), 17 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index b6f6d86..ea576ce 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -7,6 +7,7 @@ import json import os +import random from typing import Dict, List, Optional, Tuple, Union import pandas as pd @@ -24,6 +25,7 @@ START_CODONS, STOP_CODONS, STOP_SYMBOL, + ConfigManager, find_pattern_in_fasta, get_taxonomy_id, sort_amino2codon_skeleton, @@ -161,7 +163,7 @@ def preprocess_protein_sequence(protein: str) -> str: str: The preprocessed protein sequence. Raises: - ValueError: If the protein sequence is invalid. + ValueError: If the protein sequence is invalid or if the configuration is invalid. """ if not protein: raise ValueError("Protein sequence is empty.") @@ -171,10 +173,28 @@ def preprocess_protein_sequence(protein: str) -> str: protein.upper().strip().replace("\n", "").replace(" ", "").replace("\t", "") ) - # Replace ambiguous amino acids with standard 20 amino acids - protein = "".join( - AMBIGUOUS_AMINOACID_MAP.get(aminoacid, aminoacid) for aminoacid in protein - ) + # Handle ambiguous amino acids based on the specified behavior + config = ConfigManager() + ambiguous_aminoacid_map_override = config.get('ambiguous_aminoacid_map_override') + ambiguous_aminoacid_behavior = config.get('ambiguous_aminoacid_behavior') + ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy() + + for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items(): + ambiguous_aminoacid_map[aminoacid] = standard_aminoacids + + if ambiguous_aminoacid_behavior == 'raise_error': + if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein): + raise ValueError("Ambiguous amino acids found in protein sequence.") + elif ambiguous_aminoacid_behavior == 'standardize_deterministic': + protein = "".join( + ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0] for aminoacid in protein + ) + elif ambiguous_aminoacid_behavior == 'standardize_random': + protein = "".join( + random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid])) for aminoacid in protein + ) + else: + raise ValueError(f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}.") # Check for sequence validity if any( diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index a4d5434..2ee618d 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -22,7 +22,6 @@ from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( - AMINO_ACIDS, INDEX2TOKEN, NUM_ORGANISMS, ORGANISM2ID, @@ -128,10 +127,6 @@ def predict_dna_sequence( if not protein: raise ValueError("Protein sequence cannot be empty.") - # Test that the input protein sequence contains only valid amino acids - if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): - raise ValueError("Invalid amino acid found in protein sequence.") - # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index dd91e3f..2dedd0a 100644 --- 
a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -40,13 +40,13 @@ ] # Dictionary ambiguous amino acids to standard amino acids -AMBIGUOUS_AMINOACID_MAP: Dict[str, str] = { - "B": "N", # Aspartic acid (D) or Asparagine (N) - "Z": "Q", # Glutamic acid (E) or Glutamine (Q) - "X": "A", # Any amino acid (typically replaced with Alanine) - "J": "L", # Leucine (L) or Isoleucine (I) - "U": "C", # Selenocysteine (typically replaced with Cysteine) - "O": "K", # Pyrrolysine (typically replaced with Lysine) +AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = { + "B": ["N", "D"], # Asparagine (N) or Aspartic acid (D) + "Z": ["Q", "E"], # Glutamine (Q) or Glutamic acid (E) + "X": ["A"], # Any amino acid (typically replaced with Alanine) + "J": ["L", "I"], # Leucine (L) or Isoleucine (I) + "U": ["C"], # Selenocysteine (typically replaced with Cysteine) + "O": ["K"], # Pyrrolysine (typically replaced with Lysine) } # List of all possible start and stop codons @@ -544,6 +544,102 @@ def __init__(self, data_path: str, train: bool = True, **kwargs): self.train = train +class ConfigManager: + """ + A class to manage configuration settings. + + This class ensures that the configuration is a singleton. + It provides methods to get, set, and update configuration values. + + Attributes: + _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager. + _config (Dict[str, Any]): The configuration dictionary. + """ + _instance = None + + def __new__(cls): + """ + Create a new instance of the ConfigManager class. + + Returns: + ConfigManager: The singleton instance of the ConfigManager. + """ + if cls._instance is None: + cls._instance = super(ConfigManager, cls).__new__(cls) + cls._instance._config = { + 'ambiguous_aminoacid_behavior': 'raise_error', + 'ambiguous_aminoacid_map_override': {} + } + return cls._instance + + def get(self, key: str) -> Any: + """ + Get the value of a configuration key. + + Args: + key (str): The key to retrieve the value for. + + Returns: + Any: The value of the configuration key. + """ + return self._config.get(key) + + def set(self, key: str, value: Any) -> None: + """ + Set the value of a configuration key. + + Args: + key (str): The key to set the value for. + value (Any): The value to set for the key. + """ + self.validate_inputs(key, value) + self._config[key] = value + + def update(self, config_dict: dict) -> None: + """ + Update the configuration with a dictionary of key-value pairs after validating them. + + Args: + config_dict (dict): A dictionary of key-value pairs to update the configuration. + """ + for key, value in config_dict.items(): + self.validate_inputs(key, value) + for key, value in config_dict.items(): + self.set(key, value) + + def validate_inputs(self, key: str, value: Any) -> None: + """ + Validate the inputs for the configuration. + + Args: + key (str): The key to validate. + value (Any): The value to validate. + + Raises: + ValueError: If the value is invalid. + TypeError: If the value is of the wrong type. 
+ """ + if key == 'ambiguous_aminoacid_behavior': + if value not in [ + 'raise_error', + 'standardize_deterministic', + 'standardize_random' + ]: + raise ValueError(f"Invalid value for ambiguous_aminoacid_behavior: {value}.") + elif key == 'ambiguous_aminoacid_map_override': + if not isinstance(value, dict): + raise TypeError(f"Invalid type for ambiguous_aminoacid_map_override: {value}.") + for ambiguous_aminoacid, aminoacids in value.items(): + if not isinstance(aminoacids, list): + raise TypeError(f"Invalid type for aminoacids: {aminoacids}.") + if not aminoacids: + raise ValueError(f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list.") + if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP: + raise ValueError(f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}") + else: + raise ValueError(f"Invalid configuration key: {key}") + + def load_python_object_from_disk(file_path: str) -> Any: """ Load a Pickle object from disk and return it as a Python object. From 16bd563aaa5e567b819c7177c0b035f45079c2fc Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Mon, 23 Sep 2024 21:47:46 +0200 Subject: [PATCH 02/36] add testcases for fix issue #5 --- CodonTransformer/CodonUtils.py | 21 +++++++++++++--- tests/test_CodonData.py | 24 +++++++++++++++++- tests/test_CodonUtils.py | 46 ++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 5 deletions(-) diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 2dedd0a..29711de 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -566,12 +566,17 @@ def __new__(cls): """ if cls._instance is None: cls._instance = super(ConfigManager, cls).__new__(cls) - cls._instance._config = { - 'ambiguous_aminoacid_behavior': 'raise_error', - 'ambiguous_aminoacid_map_override': {} - } + cls._instance.reset_config() return cls._instance + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is not None: + print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}") + self.reset_config() + def get(self, key: str) -> Any: """ Get the value of a configuration key. @@ -638,6 +643,14 @@ def validate_inputs(self, key: str, value: Any) -> None: raise ValueError(f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}") else: raise ValueError(f"Invalid configuration key: {key}") + def reset_config(self) -> None: + """ + Reset the configuration to the default values. 
+ """ + self._config = { + 'ambiguous_aminoacid_behavior': 'raise_error', + 'ambiguous_aminoacid_map_override': {} + } def load_python_object_from_disk(file_path: str) -> Any: diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py index 42342c9..51efdd3 100644 --- a/tests/test_CodonData.py +++ b/tests/test_CodonData.py @@ -9,10 +9,32 @@ get_amino_acid_sequence, is_correct_seq, read_fasta_file, + preprocess_protein_sequence, ) - +from CodonTransformer.CodonUtils import ConfigManager class TestCodonData(unittest.TestCase): + def test_preprocess_protein_sequence(self): + with ConfigManager() as config: + protein = "Z_" + try: + preprocess_protein_sequence(protein) + self.fail("Expected ValueError") + except ValueError: + pass + config.set("ambiguous_aminoacid_behavior", "standardize_deterministic") + for _ in range(10): + preprocessed_protein = preprocess_protein_sequence(protein) + self.assertEqual(preprocessed_protein, "Q_") + config.set("ambiguous_aminoacid_behavior", "standardize_random") + random_results = set() + # The probability of getting the same result 30 times in a row is + # 1 in 1.073741824*10^9 if there are only two possible results. + for _ in range(30): + preprocessed_protein = preprocess_protein_sequence(protein) + random_results.add(preprocessed_protein) + self.assertGreater(len(random_results), 1) + def test_read_fasta_file(self): fasta_content = ">sequence1\n" "ATGATGATGATGATG\n" ">sequence2\n" "TGATGATGATGA" diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index 1d9a94c..832f816 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -4,6 +4,7 @@ import unittest from CodonTransformer.CodonUtils import ( + ConfigManager, find_pattern_in_fasta, get_organism2id_dict, get_taxonomy_id, @@ -15,6 +16,51 @@ class TestCodonUtils(unittest.TestCase): + def test_config_manager(self): + with ConfigManager() as config: + config.set( + "ambiguous_aminoacid_behavior", + "standardize_deterministic" + ) + self.assertEqual( + config.get("ambiguous_aminoacid_behavior"), + "standardize_deterministic" + ) + config.set( + "ambiguous_aminoacid_map_override", + {"R": ["A", "G"]} + ) + self.assertEqual( + config.get("ambiguous_aminoacid_map_override"), + {"R": ["A", "G"]} + ) + config.update({ + "ambiguous_aminoacid_behavior": "raise_error", + "ambiguous_aminoacid_map_override": {"X": ["A", "G"]}, + }) + self.assertEqual( + config.get("ambiguous_aminoacid_behavior"), + "raise_error" + ) + self.assertEqual( + config.get("ambiguous_aminoacid_map_override"), + {"X": ["A", "G"]} + ) + try: + config.set("invalid_key", "invalid_value") + self.fail("Expected ValueError") + except ValueError: + pass + with ConfigManager() as config: + self.assertEqual( + config.get("ambiguous_aminoacid_behavior"), + "raise_error" + ) + self.assertEqual( + config.get("ambiguous_aminoacid_map_override"), + {} + ) + def test_load_python_object_from_disk(self): test_obj = {"key1": "value1", "key2": 2} with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: From a829c373bb66e944174994dc053c94c3038cf721 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 16:35:09 -0400 Subject: [PATCH 03/36] Bump version to 1.6.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6ca0e55..c8cd8f7 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def read_readme(): setup( name="CodonTransformer", - version="1.5.2", + version="1.6.0", packages=find_packages(), install_requires=read_requirements(), author="Adibvafa 
Fallahpour", From c17c5a09d90daea5b0d642fa4ae08ed1507a9c74 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 11:43:12 -0400 Subject: [PATCH 04/36] Add support for variable randomness in DNA prediction. --- CodonTransformer/CodonPrediction.py | 55 ++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 2ee618d..6adeef7 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -38,19 +38,20 @@ def predict_dna_sequence( model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", deterministic: bool = True, + temperature: float = 0.2, ) -> DNASequencePrediction: """ - Predict the DNA sequence for a given protein using CodonTransformer model. + Predict the DNA sequence for a given protein using the CodonTransformer model. This function takes a protein sequence and an organism (as ID or name) as input - and returns the predicted DNA sequence by CodonTransformer. It can use either - provided tokenizer and model objects or load them from specified paths. + and returns the predicted DNA sequence using the CodonTransformer model. It can use + either provided tokenizer and model objects or load them from specified paths. Args: - protein (str): The input protein sequence to predict the DNA sequence for. + protein (str): The input protein sequence for which to predict the DNA sequence. organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). If a string is provided, it will be converted - to the corresponding ID using ORGANISM2ID. + to the corresponding ID using `ORGANISM2ID`. device (torch.device): The device (CPU or GPU) to run the model on. tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, a pre-loaded tokenizer object, or None. If @@ -62,8 +63,18 @@ def predict_dna_sequence( model. Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". deterministic (bool, optional): Whether to use deterministic decoding (most - likely tokens). If False, samples tokens according to their probabilities. - Defaults to True. + likely tokens). If False, samples tokens according to their probabilities + adjusted by the temperature. Defaults to True. + temperature (float, optional): A value controlling the randomness of predictions + during non-deterministic decoding. Lower values (e.g., 0.2) make the model + more conservative, while higher values (e.g., 0.8) increase randomness. + Using high temperatures may result in prediction of DNA sequences that + do not translate to the input protein. + Recommended values are: + - Low randomness: 0.2 + - Medium randomness: 0.5 + - High randomness: 0.8 + The temperature must be a positive float. Defaults to 0.2. Returns: DNASequencePrediction: An object containing the prediction results: @@ -73,12 +84,13 @@ def predict_dna_sequence( - predicted_dna (str): Predicted DNA sequence. Raises: - ValueError: If the protein sequence is empty or if the organism is invalid. + ValueError: If the protein sequence is empty, if the organism is invalid, + or if the temperature is not a positive float. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from - CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their - corresponding IDs. 
INDEX2TOKEN maps model output indices (token ids) to + This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from + `CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their + corresponding IDs. `INDEX2TOKEN` maps model output indices (token IDs) to respective codons. Example: @@ -110,23 +122,32 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with probabilistic sampling - >>> output_sampler = predict_dna_sequence( + >>> # Predict DNA sequence with medium randomness + >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, ... attention_type="original_full", - ... deterministic=False + ... deterministic=False, + ... temperature=1.0 ... ) >>> >>> print(format_model_output(output)) - >>> print(format_model_output(output_sampler)) + >>> print(format_model_output(output_random)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") + # Ensure the protein sequence contains only valid amino acids + if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): + raise ValueError("Invalid amino acid found in protein sequence.") + + # Validate temperature + if not isinstance(temperature, (float, int)) or temperature <= 0: + raise ValueError("Temperature must be a positive float.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -163,8 +184,10 @@ def predict_dna_sequence( predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: # Sample tokens according to their probability distribution - # Convert logits to probabilities using softmax + # Apply temperature scaling and convert logits to probabilities + logits = logits / temperature probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] predicted_indices = ( From acbd4a0cc96f9b064b4c78fc705985906fcf815b Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 11:45:14 -0400 Subject: [PATCH 05/36] Add random use case to docstring of predict_dna_sequence. --- CodonTransformer/CodonPrediction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 6adeef7..27b90e6 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -122,7 +122,7 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with medium randomness + >>> # Predict DNA sequence with low randomness >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, @@ -131,7 +131,7 @@ def predict_dna_sequence( ... model=model, ... attention_type="original_full", ... deterministic=False, - ... temperature=1.0 + ... temperature=0.2 ... ) >>> >>> print(format_model_output(output)) From a0962d2039b09cb233159deaaf9488fbaeb090ec Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 16:28:59 -0400 Subject: [PATCH 06/36] Add support for non-deterministic test with temperature. 
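This test sweeps the temperature values documented in the preceding patches. As a self-contained illustration (toy scores, not real model output) of why those values matter — lower temperatures sharpen the distribution that is sampled when deterministic=False, higher ones flatten it:

    import torch

    # Toy per-codon scores at a single position (stand-ins, not model output).
    logits = torch.tensor([2.0, 1.0, 0.5])
    for temperature in (0.2, 0.5, 0.8):
        # Lower temperatures sharpen the softmax; higher ones flatten it,
        # which is why very high values can break back-translation to the protein.
        probs = torch.softmax(logits / temperature, dim=-1)
        print(temperature, [round(p, 3) for p in probs.tolist()])
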
--- tests/test_CodonPrediction.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 3617c2b..28cb49e 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -44,7 +44,8 @@ def test_predict_dna_sequence_valid_input(self): def test_predict_dna_sequence_non_deterministic(self): protein_sequence = "MFWY" organism = "Escherichia coli general" - num_iterations = 64 + num_iterations = 100 + temperatures = [0.2, 0.5, 0.8] possible_outputs = set() possible_encodings_wo_stop = { "ATGTTTTGGTAT", @@ -54,15 +55,17 @@ def test_predict_dna_sequence_non_deterministic(self): } for _ in range(num_iterations): - result = predict_dna_sequence( - protein=protein_sequence, - organism=organism, - device=self.device, - tokenizer=self.tokenizer, - model=self.model, - deterministic=False, - ) - possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon + for temperature in temperatures: + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + ) + possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon self.assertEqual(possible_outputs, possible_encodings_wo_stop) From dbb4d427aebdee11663c999f858cf326702f8033 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 16:36:10 -0400 Subject: [PATCH 07/36] Bump version to 1.6.1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c8cd8f7..dcc4d5b 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def read_readme(): setup( name="CodonTransformer", - version="1.6.0", + version="1.6.1", packages=find_packages(), install_requires=read_requirements(), author="Adibvafa Fallahpour", From dae089032eedbf4a7d496dac315e9ad802d8c57e Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 16:37:38 -0400 Subject: [PATCH 08/36] Bump version to 1.6.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e74248e..360e798 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "CodonTransformer" -version = "1.5.2" +version = "1.6.1" description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." authors = ["Adibvafa Fallahpour "] license = "Apache-2.0" From 590280e9aaf99f7b42817dd434c209637b19f6f3 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 16:40:29 -0400 Subject: [PATCH 09/36] Bump version to 1.6.2 --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 360e798..4b8bfb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "CodonTransformer" -version = "1.6.1" +version = "1.6.2" description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." 
authors = ["Adibvafa Fallahpour "] license = "Apache-2.0" diff --git a/setup.py b/setup.py index dcc4d5b..d90c946 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def read_readme(): setup( name="CodonTransformer", - version="1.6.1", + version="1.6.2", packages=find_packages(), install_requires=read_requirements(), author="Adibvafa Fallahpour", From d0f6fc5eebc59581c2b81c7900002c38d28cc387 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 17:38:41 -0400 Subject: [PATCH 10/36] Add more tests to check predict_dna_sequence. --- tests/test_CodonPrediction.py | 83 +++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 28cb49e..16a6136 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -87,6 +87,89 @@ def test_predict_dna_sequence_invalid_inputs(self): model=self.model, ) + def test_predict_dna_sequence_top_p_effect(self): + """Test that changing top_p affects the diversity of outputs.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + num_iterations = 50 + temperature = 0.5 + top_p_values = [0.8, 0.95] + outputs_by_top_p = {top_p: set() for top_p in top_p_values} + + for top_p in top_p_values: + for _ in range(num_iterations): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + outputs_by_top_p[top_p].add( + result.predicted_dna[:-3] + ) # Remove stop codon + + # Assert that higher top_p results in more diverse outputs + diversity_lower_top_p = len(outputs_by_top_p[0.8]) + diversity_higher_top_p = len(outputs_by_top_p[0.95]) + self.assertGreaterEqual( + diversity_higher_top_p, + diversity_lower_top_p, + "Higher top_p should result in more diverse outputs", + ) + + def test_predict_dna_sequence_invalid_temperature_and_top_p(self): + """Test that invalid temperature and top_p values raise ValueError.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + invalid_params = [ + {"temperature": -0.1, "top_p": 0.95}, + {"temperature": 0, "top_p": 0.95}, + {"temperature": 0.5, "top_p": -0.1}, + {"temperature": 0.5, "top_p": 1.1}, + ] + + for params in invalid_params: + with self.subTest(params=params): + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=params["temperature"], + top_p=params["top_p"], + ) + + def test_predict_dna_sequence_translation_consistency(self): + """Test that the predicted DNA translates back to the original protein.""" + from CodonTransformer.CodonData import get_amino_acid_sequence + + protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE" + organism = "Escherichia coli general" + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Translate predicted DNA back to protein + translated_protein = get_amino_acid_sequence(result.predicted_dna[:-3]) + + self.assertEqual( + translated_protein, + protein_sequence, + "Translated protein does not match the original protein sequence", + ) + if __name__ == "__main__": unittest.main() From 97a007978e73bf8d60984dcc77447e56820b2ec0 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 
17:38:56 -0400 Subject: [PATCH 11/36] Add support for top_p in non-deterministic generation. --- CodonTransformer/CodonPrediction.py | 98 +++++++++++++++++++++++++---- 1 file changed, 85 insertions(+), 13 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 27b90e6..a09a811 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -39,6 +39,7 @@ def predict_dna_sequence( attention_type: str = "original_full", deterministic: bool = True, temperature: float = 0.2, + top_p: float = 0.95, ) -> DNASequencePrediction: """ Predict the DNA sequence for a given protein using the CodonTransformer model. @@ -75,6 +76,10 @@ def predict_dna_sequence( - Medium randomness: 0.5 - High randomness: 0.8 The temperature must be a positive float. Defaults to 0.2. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. + Tokens with cumulative probability up to `top_p` are considered for sampling. + This parameter helps balance diversity and coherence in the predicted DNA sequences. + The value must be a float between 0 and 1. Defaults to 0.95. Returns: DNASequencePrediction: An object containing the prediction results: @@ -85,7 +90,7 @@ def predict_dna_sequence( Raises: ValueError: If the protein sequence is empty, if the organism is invalid, - or if the temperature is not a positive float. + if the temperature is not a positive float, or if `top_p` is not between 0 and 1. Note: This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from @@ -122,7 +127,7 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with low randomness + >>> # Predict DNA sequence with low randomness and top_p sampling >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, @@ -131,7 +136,8 @@ def predict_dna_sequence( ... model=model, ... attention_type="original_full", ... deterministic=False, - ... temperature=0.2 + ... temperature=0.2, + ... top_p=0.95 ... 
) >>> >>> print(format_model_output(output)) @@ -148,6 +154,10 @@ def predict_dna_sequence( if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") + # Validate top_p + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: + raise ValueError("top_p must be a float between 0 and 1.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -180,18 +190,11 @@ def predict_dna_sequence( # Decode the predicted DNA sequence from the model output if deterministic: - # Select the most probable tokens (argmax) predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: - # Sample tokens according to their probability distribution - # Apply temperature scaling and convert logits to probabilities - logits = logits / temperature - probabilities = torch.softmax(logits, dim=-1) - - # Sample from the probability distribution at each position - probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] - predicted_indices = ( - torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist() + # Use the standalone non-deterministic sampling function + predicted_indices = sample_non_deterministic( + logits=logits, temperature=temperature, top_p=top_p ) predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) @@ -209,6 +212,75 @@ def predict_dna_sequence( ) +def sample_non_deterministic( + logits: torch.Tensor, + temperature: float = 1.0, + top_p: float = 0.95, +) -> List[int]: + """ + Sample token indices from logits using temperature scaling and nucleus (top-p) sampling. + + This function applies temperature scaling to the logits, computes probabilities, + and then performs nucleus sampling to select token indices. It is used for + non-deterministic decoding in language models to introduce randomness while + maintaining coherence in the generated sequences. + + Args: + logits (torch.Tensor): The logits output from the model of shape + [seq_len, vocab_size] or [batch_size, seq_len, vocab_size]. + temperature (float, optional): Temperature value for scaling logits. + Must be a positive float. Defaults to 1.0. + top_p (float, optional): Cumulative probability threshold for nucleus sampling. + Must be a float between 0 and 1. Tokens with cumulative probability up to + `top_p` are considered for sampling. Defaults to 0.95. + + Returns: + List[int]: A list of sampled token indices corresponding to the predicted tokens. + + Raises: + ValueError: If `temperature` is not a positive float or if `top_p` is not between 0 and 1. 
+ + Example: + >>> logits = model_output.logits # Assume logits is a tensor of shape [seq_len, vocab_size] + >>> predicted_indices = sample_non_deterministic(logits, temperature=0.7, top_p=0.9) + """ + if not isinstance(temperature, (float, int)) or temperature <= 0: + raise ValueError("Temperature must be a positive float.") + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: + raise ValueError("top_p must be a float between 0 and 1.") + + # Apply temperature scaling and compute probabilities + logits = logits / temperature + probabilities = torch.softmax(logits, dim=-1) + + # Remove batch dimension if present + if probabilities.dim() == 3 and probabilities.size(0) == 1: + probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] + + predicted_indices = [] + for probs in probabilities: + sorted_probs, sorted_indices = torch.sort(probs, descending=True) + cumulative_probs = torch.cumsum(sorted_probs, dim=0) + + # Find the cutoff index where cumulative_probs exceeds top_p + cutoff_index = torch.where(cumulative_probs > top_p)[0] + if len(cutoff_index) > 0: + cutoff_index = cutoff_index[0].item() + # Keep only tokens up to the cutoff index + sorted_probs = sorted_probs[: cutoff_index + 1] + sorted_indices = sorted_indices[: cutoff_index + 1] + + # Re-normalize the probabilities after filtering + filtered_probs = sorted_probs / sorted_probs.sum() + + # Sample from the filtered distribution + sampled_index = torch.multinomial(filtered_probs, num_samples=1).item() + predicted_index = sorted_indices[sampled_index].item() + predicted_indices.append(predicted_index) + + return predicted_indices + + def load_model( model_path: Optional[str] = None, device: torch.device = None, From 6c5ffdf4fc7d4a4f691a28f6f12e43a5e755df95 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 17:40:37 -0400 Subject: [PATCH 12/36] Improve style. --- CodonTransformer/CodonPrediction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index a09a811..8500196 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -192,7 +192,6 @@ def predict_dna_sequence( if deterministic: predicted_indices = logits.argmax(dim=-1).squeeze().tolist() else: - # Use the standalone non-deterministic sampling function predicted_indices = sample_non_deterministic( logits=logits, temperature=temperature, top_p=top_p ) From b80f4190b0569adc9323f62b0090868883bc4c01 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:03:19 -0400 Subject: [PATCH 13/36] Add extensive testing for predict_dna_sequence. 
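Among the properties pinned down below: with a fixed torch seed, non-deterministic decoding is reproducible. A minimal sketch of that pattern, with random logits standing in for a real forward pass (the shape is an arbitrary stand-in, not the model's actual vocabulary size):

    import torch
    from CodonTransformer.CodonPrediction import sample_non_deterministic

    def draw():
        torch.manual_seed(42)        # fix the RNG for both the stand-in logits and the sampling
        logits = torch.randn(8, 90)  # stand-in forward pass: seq_len=8, vocab=90
        return sample_non_deterministic(logits, temperature=0.5, top_p=0.95)

    assert draw() == draw()          # same seed -> same sampled token indices
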
--- tests/test_CodonPrediction.py | 266 +++++++++++++++++++++++++++++++++- 1 file changed, 264 insertions(+), 2 deletions(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 16a6136..e82b535 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,8 +1,16 @@ import unittest import warnings +import random import torch +from CodonTransformer.CodonData import get_amino_acid_sequence +from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + INDEX2TOKEN, + STOP_SYMBOLS, + ORGANISM2ID, +) from CodonTransformer.CodonPrediction import ( load_model, load_tokenizer, @@ -148,8 +156,6 @@ def test_predict_dna_sequence_invalid_temperature_and_top_p(self): def test_predict_dna_sequence_translation_consistency(self): """Test that the predicted DNA translates back to the original protein.""" - from CodonTransformer.CodonData import get_amino_acid_sequence - protein_sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVE" organism = "Escherichia coli general" result = predict_dna_sequence( @@ -170,6 +176,262 @@ def test_predict_dna_sequence_translation_consistency(self): "Translated protein does not match the original protein sequence", ) + def test_predict_dna_sequence_long_protein_sequence(self): + """Test the function with a very long protein sequence to check performance and correctness.""" + protein_sequence = ( + "M" + + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + * 20 + + STOP_SYMBOLS[0] + ) + organism = "Escherichia coli general" + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Check that the predicted DNA translates back to the original protein + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], + "Translated protein does not match the original long protein sequence", + ) + + def test_predict_dna_sequence_edge_case_organisms(self): + """Test the function with organism IDs at the boundaries of the mapping.""" + protein_sequence = "MWWMW" + # Assuming ORGANISM2ID has IDs starting from 0 to N + min_organism_id = min(ORGANISM2ID.values()) + max_organism_id = max(ORGANISM2ID.values()) + organisms = [min_organism_id, max_organism_id] + + for organism_id in organisms: + with self.subTest(organism_id=organism_id): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism_id, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_concurrent_calls(self): + """Test the function's behavior under concurrent execution.""" + import threading + + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + results = [] + + def call_predict(): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + results.append(result.predicted_dna) + + threads = [threading.Thread(target=call_predict) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + self.assertEqual(len(results), 10) + self.assertTrue(all(dna == results[0] for dna in results)) + + def 
test_predict_dna_sequence_random_seed_consistency(self): + """Test that setting a random seed results in consistent outputs in non-deterministic mode.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + temperature = 0.5 + top_p = 0.95 + torch.manual_seed(42) + + result1 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + torch.manual_seed(42) + + result2 = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + temperature=temperature, + top_p=top_p, + ) + + self.assertEqual( + result1.predicted_dna, + result2.predicted_dna, + "Outputs should be consistent when random seed is set", + ) + + def test_predict_dna_sequence_invalid_tokenizer_and_model(self): + """Test that providing invalid tokenizer or model raises appropriate exceptions.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + with self.subTest("Invalid tokenizer"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer="invalid_tokenizer_path", + model=self.model, + ) + + with self.subTest("Invalid model"): + with self.assertRaises(Exception): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model="invalid_model_path", + ) + + def test_predict_dna_sequence_stop_codon_handling(self): + """Test the function's handling of protein sequences ending with a non '_' or '*' stop symbol.""" + protein_sequence = "MWW/" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_ambiguous_amino_acids(self): + """Test the function's response to ambiguous or non-standard amino acids.""" + protein_sequence = "MWWBXZ" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) + + def test_predict_dna_sequence_device_compatibility(self): + """Test that the function works correctly on both CPU and GPU devices.""" + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + + for device in devices: + with self.subTest(device=device): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) + + def test_predict_dna_sequence_random_proteins(self): + """Test random proteins to ensure translated DNA matches the original protein.""" + organism = "Escherichia coli general" + num_tests = 200 + + for _ in range(num_tests): + # Generate a random protein sequence of random length between 10 and 50 + protein_length = random.randint(10, 500) + protein_sequence = "M" + "".join( + random.choices(AMINO_ACIDS, k=protein_length - 1) + ) + protein_sequence += 
random.choice(STOP_SYMBOLS) + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + + # Translate predicted DNA back to protein + translated_protein = get_amino_acid_sequence(dna_sequence) + self.assertEqual( + translated_protein, + protein_sequence[:-1], # Remove stop symbol + f"Translated protein does not match the original protein sequence for protein: {protein_sequence}", + ) + + def test_predict_dna_sequence_long_protein_over_max_length(self): + """Test that the model handles protein sequences longer than 2048 amino acids.""" + # Create a protein sequence longer than 2048 amino acids + base_sequence = ( + "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGG" + ) + protein_sequence = base_sequence * 100 # Length > 2048 amino acids + organism = "Escherichia coli general" + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + ) + + # Remove stop codon from predicted DNA + dna_sequence = result.predicted_dna[:-3] + translated_protein = get_amino_acid_sequence(dna_sequence) + + # Due to potential model limitations, compare up to the model's max supported length + max_length = len(translated_protein) + self.assertEqual( + translated_protein[:max_length], + protein_sequence[:max_length], + "Translated protein does not match the original protein sequence up to the maximum length supported.", + ) + if __name__ == "__main__": unittest.main() From 353a14b772cef0457e110f524a402c6a92b69a53 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:03:35 -0400 Subject: [PATCH 14/36] Add a list of possible stop symbols. --- CodonTransformer/CodonPrediction.py | 4 +++- CodonTransformer/CodonUtils.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 8500196..6d3e1b1 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -22,6 +22,8 @@ from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + STOP_SYMBOLS, INDEX2TOKEN, NUM_ORGANISMS, ORGANISM2ID, @@ -147,7 +149,7 @@ def predict_dna_sequence( raise ValueError("Protein sequence cannot be empty.") # Ensure the protein sequence contains only valid amino acids - if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): + if not all(aminoacid in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") # Validate temperature diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 29711de..34b6ea4 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -38,6 +38,7 @@ "W", # Tryptophan "Y", # Tyrosine ] +STOP_SYMBOLS = ["_", "*"] # Stop codon symbols # Dictionary ambiguous amino acids to standard amino acids AMBIGUOUS_AMINOACID_MAP: Dict[str, list[str]] = { From 0e43429cca8972fdd4ca7154c79daf9f1f7031de Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:06:14 -0400 Subject: [PATCH 15/36] Add docstrings for sample_non_deterministic and STOP_SYMBOLS. 
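For reference, the input rule the series settles on: every character of the protein must be a standard amino acid or a stop symbol, and a trailing '*' is treated like the canonical '_'. A small sketch using the CodonUtils constants:

    from CodonTransformer.CodonUtils import AMINO_ACIDS, STOP_SYMBOLS

    protein = "MKTFFVLLL*"
    # Every residue must be a standard amino acid or a stop symbol ('_' or '*').
    assert all(aa in AMINO_ACIDS + STOP_SYMBOLS for aa in protein)
    # A trailing '*' is normalized to the canonical stop symbol '_'.
    normalized = protein[:-1] + "_" if protein.endswith("*") else protein
    print(normalized)  # MKTFFVLLL_
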
--- CodonTransformer/CodonPrediction.py | 2 +- README.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 6d3e1b1..af54a50 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -215,7 +215,7 @@ def predict_dna_sequence( def sample_non_deterministic( logits: torch.Tensor, - temperature: float = 1.0, + temperature: float = 0.2, top_p: float = 0.95, ) -> List[int]: """ diff --git a/README.md b/README.md index 8187f77..2baaa62 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,10 @@ This subpackage contains functions and classes that handle the core prediction f Predict the DNA sequence for a given protein using the CodonTransformer model. +- `sample_non_deterministic(logits: torch.Tensor, temperature: float = 0.2, top_p: float = 0.95) -> List[int]` + + Sample token indices from logits using temperature scaling and nucleus (top-p) sampling. + - `load_model(path: str, device: torch.device = None, num_organisms: int = None, remove_prefix: bool = True, attention_type: str = "original_full") -> torch.nn.Module` Load a BigBirdForMaskedLM model from a file or checkpoint. @@ -383,6 +387,7 @@ The CodonUtils subpackage contains constants and helper functions essential for #### Constants - `AMINO_ACIDS`: List of all standard amino acids +- `STOP_SYMBOLS`: List of possible stop symbols to end the protein with - `AMBIGUOUS_AMINOACID_MAP`: Mapping of ambiguous amino acids to standard amino acids - `START_CODONS` and `STOP_CODONS`: Lists of start and stop codons - `TOKEN2INDEX` and `INDEX2TOKEN`: Mappings between tokens and their indices From 6c0c091ff6f89fd30a57dc575b0ec2c1021d536e Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:10:05 -0400 Subject: [PATCH 16/36] Remove checking for protein sequence validity and bring it to preprocessing function. --- CodonTransformer/CodonData.py | 11 ++++++----- CodonTransformer/CodonPrediction.py | 6 ------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index ea576ce..73c31b3 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -18,6 +18,7 @@ from tqdm import tqdm from CodonTransformer.CodonUtils import ( + STOP_SYMBOLS, AMBIGUOUS_AMINOACID_MAP, AMINO2CODON_TYPE, AMINO_ACIDS, @@ -197,13 +198,13 @@ def preprocess_protein_sequence(protein: str) -> str: raise ValueError(f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}.") # Check for sequence validity - if any( - aminoacid not in AMINO_ACIDS + ["*", STOP_SYMBOL] for aminoacid in protein[:-1] - ): + if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): raise ValueError("Invalid characters in protein sequence.") - if protein[-1] not in AMINO_ACIDS + ["*", STOP_SYMBOL]: - raise ValueError("Protein sequence must end with *, or _, or an amino acid.") + if protein[-1] not in AMINO_ACIDS + STOP_SYMBOLS: + raise ValueError( + "Protein sequence must end with `*`, or `_`, or an amino acid." 
+ ) # Replace '*' at the end of protein with STOP_SYMBOL if present if protein[-1] == "*": diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index af54a50..f147105 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -22,8 +22,6 @@ from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( - AMINO_ACIDS, - STOP_SYMBOLS, INDEX2TOKEN, NUM_ORGANISMS, ORGANISM2ID, @@ -148,10 +146,6 @@ def predict_dna_sequence( if not protein: raise ValueError("Protein sequence cannot be empty.") - # Ensure the protein sequence contains only valid amino acids - if not all(aminoacid in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein): - raise ValueError("Invalid amino acid found in protein sequence.") - # Validate temperature if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") From 12e3af19059dc34ec163a2b8c07f0f6efd97f6de Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:12:33 -0400 Subject: [PATCH 17/36] Remove test_predict_dna_sequence_ambiguous_amino_acids test. --- tests/test_CodonPrediction.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index e82b535..75f1456 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -7,7 +7,6 @@ from CodonTransformer.CodonData import get_amino_acid_sequence from CodonTransformer.CodonUtils import ( AMINO_ACIDS, - INDEX2TOKEN, STOP_SYMBOLS, ORGANISM2ID, ) @@ -79,7 +78,7 @@ def test_predict_dna_sequence_non_deterministic(self): def test_predict_dna_sequence_invalid_inputs(self): test_cases = [ - ("MKTZZFVLLL", "Escherichia coli general", "invalid protein sequence"), + ("MKTZZFVLLL?", "Escherichia coli general", "invalid protein sequence"), ("MKTFFVLLL", "Alien $%#@!", "invalid organism code"), ("", "Escherichia coli general", "empty protein sequence"), ] @@ -331,20 +330,6 @@ def test_predict_dna_sequence_stop_codon_handling(self): model=self.model, ) - def test_predict_dna_sequence_ambiguous_amino_acids(self): - """Test the function's response to ambiguous or non-standard amino acids.""" - protein_sequence = "MWWBXZ" - organism = "Escherichia coli general" - - with self.assertRaises(ValueError): - predict_dna_sequence( - protein=protein_sequence, - organism=organism, - device=self.device, - tokenizer=self.tokenizer, - model=self.model, - ) - def test_predict_dna_sequence_device_compatibility(self): """Test that the function works correctly on both CPU and GPU devices.""" protein_sequence = "MWWMW" From 0723663597b476d308e4034cdafedd2d1bb7cd70 Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:16:25 -0400 Subject: [PATCH 18/36] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e0c0e74..a11f108 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,3 +32,9 @@ repos: hooks: - id: ruff - id: ruff-format + + - repo: https://github.com/psf/black + rev: 24.8.0 + hooks: + - id: black + language_version: python3 From e8c25bf0d0fd8a90b1e2137c8a5d16733dbd3738 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 18:20:17 -0400 Subject: [PATCH 19/36] Improve style. 
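For readers following the series, a usage sketch of the ConfigManager-based ambiguous amino acid handling introduced in the first two patches (the pattern mirrors tests/test_CodonData.py):

    from CodonTransformer.CodonData import preprocess_protein_sequence
    from CodonTransformer.CodonUtils import ConfigManager

    with ConfigManager() as config:
        # Default behavior is 'raise_error'; switch to deterministic replacement.
        config.set("ambiguous_aminoacid_behavior", "standardize_deterministic")
        print(preprocess_protein_sequence("MZB_"))  # Z -> Q, B -> N, giving "MQN_"
    # Leaving the with-block resets the configuration to its defaults.
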
--- CodonTransformer/CodonData.py | 2 +- tests/test_CodonPrediction.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 73c31b3..7bee937 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -18,7 +18,6 @@ from tqdm import tqdm from CodonTransformer.CodonUtils import ( - STOP_SYMBOLS, AMBIGUOUS_AMINOACID_MAP, AMINO2CODON_TYPE, AMINO_ACIDS, @@ -26,6 +25,7 @@ START_CODONS, STOP_CODONS, STOP_SYMBOL, + STOP_SYMBOLS, ConfigManager, find_pattern_in_fasta, get_taxonomy_id, diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 75f1456..1fb0d43 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,20 +1,20 @@ +import random import unittest import warnings -import random import torch from CodonTransformer.CodonData import get_amino_acid_sequence -from CodonTransformer.CodonUtils import ( - AMINO_ACIDS, - STOP_SYMBOLS, - ORGANISM2ID, -) from CodonTransformer.CodonPrediction import ( load_model, load_tokenizer, predict_dna_sequence, ) +from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, + ORGANISM2ID, + STOP_SYMBOLS, +) class TestCodonPrediction(unittest.TestCase): From 0ecc4cba75240e59d95f040de2981f27a7a015c3 Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Fri, 20 Sep 2024 18:21:41 -0400 Subject: [PATCH 20/36] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a11f108..15b9d33 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - id: ruff-format - repo: https://github.com/psf/black - rev: 24.8.0 - hooks: - - id: black - language_version: python3 + rev: 24.8.0 + hooks: + - id: black + language_version: python3 From 1ffe33ac8009abfc9d41da23dbe6443b05462594 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 20:40:44 -0400 Subject: [PATCH 21/36] Fix issue with top_p sampling. --- tests/test_CodonPrediction.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 1fb0d43..079cbe0 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -60,7 +60,6 @@ def test_predict_dna_sequence_non_deterministic(self): "ATGTTTTGGTAC", "ATGTTCTGGTAC", } - for _ in range(num_iterations): for temperature in temperatures: result = predict_dna_sequence( From b5201a9e60177e0d9f24a1261f2524b7ef2367ac Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 20:41:08 -0400 Subject: [PATCH 22/36] Fix issue with top_p sampling. 
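The reworked filtering below keeps, at each position, the highest-probability tokens whose cumulative mass stays within top_p, then renormalizes the surviving nucleus before sampling. A toy walk-through with made-up probabilities:

    import torch

    top_p = 0.7
    # Two toy positions with made-up codon probabilities (each row sums to 1).
    probs = torch.tensor([[0.60, 0.30, 0.10],
                          [0.40, 0.40, 0.20]])
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p ...
    probs_sort[probs_sum - probs_sort > top_p] = 0.0
    # ... and renormalize what is left before sampling one token per row.
    probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sort, num_samples=1)
    print(torch.gather(probs_idx, -1, next_token).squeeze(-1).tolist())
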
--- CodonTransformer/CodonPrediction.py | 49 +++++++++++++---------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index f147105..adb85ed 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -241,39 +241,32 @@ def sample_non_deterministic( """ if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") + if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: raise ValueError("top_p must be a float between 0 and 1.") - # Apply temperature scaling and compute probabilities - logits = logits / temperature - probabilities = torch.softmax(logits, dim=-1) + # Compute probabilities using temperature scaling + logits /= temperature + probs = torch.softmax(logits, dim=-1) # Remove batch dimension if present - if probabilities.dim() == 3 and probabilities.size(0) == 1: - probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] - - predicted_indices = [] - for probs in probabilities: - sorted_probs, sorted_indices = torch.sort(probs, descending=True) - cumulative_probs = torch.cumsum(sorted_probs, dim=0) - - # Find the cutoff index where cumulative_probs exceeds top_p - cutoff_index = torch.where(cumulative_probs > top_p)[0] - if len(cutoff_index) > 0: - cutoff_index = cutoff_index[0].item() - # Keep only tokens up to the cutoff index - sorted_probs = sorted_probs[: cutoff_index + 1] - sorted_indices = sorted_indices[: cutoff_index + 1] - - # Re-normalize the probabilities after filtering - filtered_probs = sorted_probs / sorted_probs.sum() - - # Sample from the filtered distribution - sampled_index = torch.multinomial(filtered_probs, num_samples=1).item() - predicted_index = sorted_indices[sampled_index].item() - predicted_indices.append(predicted_index) - - return predicted_indices + if probs.dim() == 3: + probs = probs.squeeze(0) # Shape: [seq_len, vocab_size] + + # Sort probabilities in descending order + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + + # Zero out probabilities for tokens beyond the top-p threshold + probs_sort[mask] = 0.0 + + # Renormalize the probabilities + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + predicted_indices = torch.gather(probs_idx, -1, next_token).squeeze(-1) + + return predicted_indices.tolist() def load_model( From d66cf72382b1e274fa2c174ce03aaf4506617985 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Fri, 20 Sep 2024 20:43:04 -0400 Subject: [PATCH 23/36] Bump version to 1.6.3 --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4b8bfb7..7729c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "CodonTransformer" -version = "1.6.2" +version = "1.6.3" description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." 
authors = ["Adibvafa Fallahpour "] license = "Apache-2.0" diff --git a/setup.py b/setup.py index d90c946..03f4f4f 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def read_readme(): setup( name="CodonTransformer", - version="1.6.2", + version="1.6.3", packages=find_packages(), install_requires=read_requirements(), author="Adibvafa Fallahpour", From 28589ba445f21a0fc9fd9dece17a53ef7b47dcd4 Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Sat, 21 Sep 2024 10:08:20 -0400 Subject: [PATCH 24/36] Update issue templates --- .github/ISSUE_TEMPLATE/feature_request.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..bbcbbe7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. From 8247187fac958ae204fea85d7a7bcdc9e20857ed Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Sat, 21 Sep 2024 10:11:14 -0400 Subject: [PATCH 25/36] Update issue templates --- .github/ISSUE_TEMPLATE/bug_report.md | 38 +++++++++++++++++++++++ .github/ISSUE_TEMPLATE/feature_request.md | 5 +-- .github/ISSUE_TEMPLATE/other.md | 10 ++++++ 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/other.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..dd84ea7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,38 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Desktop (please complete the following information):** + - OS: [e.g. iOS] + - Browser [e.g. chrome, safari] + - Version [e.g. 22] + +**Smartphone (please complete the following information):** + - Device: [e.g. iPhone6] + - OS: [e.g. iOS8.1] + - Browser [e.g. stock browser, safari] + - Version [e.g. 22] + +**Additional context** +Add any other context about the problem here. 
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index bbcbbe7..59094e2 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -2,7 +2,7 @@ name: Feature request about: Suggest an idea for this project title: '' -labels: '' +labels: enhancement assignees: '' --- @@ -13,8 +13,5 @@ A clear and concise description of what the problem is. Ex. I'm always frustrate **Describe the solution you'd like** A clear and concise description of what you want to happen. -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - **Additional context** Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/other.md b/.github/ISSUE_TEMPLATE/other.md new file mode 100644 index 0000000..10efe40 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/other.md @@ -0,0 +1,10 @@ +--- +name: Other +about: Any other issue +title: '' +labels: bug +assignees: '' + +--- + +**Describe your issue here** From eb054195fad1eb992dbcf9f03f57e4226cebcc6e Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Sat, 21 Sep 2024 10:12:03 -0400 Subject: [PATCH 26/36] Create CODE_OF_CONDUCT.md --- CODE_OF_CONDUCT.md | 128 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..9faefd4 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. 
+ +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +Adibvafa.fallahpour@mail.utoronto.ca. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. 
From 126b6c2b4156dcb64c397764e4d35cf4b1b76512 Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Sat, 21 Sep 2024 10:28:02 -0400 Subject: [PATCH 27/36] Add support for multiple sequence generation. --- CodonTransformer/CodonPrediction.py | 87 ++++++++++++++++++----------- 1 file changed, 53 insertions(+), 34 deletions(-) diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index adb85ed..83b9251 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -40,19 +40,20 @@ def predict_dna_sequence( deterministic: bool = True, temperature: float = 0.2, top_p: float = 0.95, -) -> DNASequencePrediction: + num_sequences: int = 1, +) -> Union[DNASequencePrediction, List[DNASequencePrediction]]: """ - Predict the DNA sequence for a given protein using the CodonTransformer model. + Predict the DNA sequence(s) for a given protein using the CodonTransformer model. This function takes a protein sequence and an organism (as ID or name) as input - and returns the predicted DNA sequence using the CodonTransformer model. It can use + and returns the predicted DNA sequence(s) using the CodonTransformer model. It can use either provided tokenizer and model objects or load them from specified paths. Args: protein (str): The input protein sequence for which to predict the DNA sequence. organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). If a string is provided, it will be converted - to the corresponding ID using `ORGANISM2ID`. + to the corresponding ID using ORGANISM2ID. device (torch.device): The device (CPU or GPU) to run the model on. tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, a pre-loaded tokenizer object, or None. If @@ -77,12 +78,15 @@ def predict_dna_sequence( - High randomness: 0.8 The temperature must be a positive float. Defaults to 0.2. top_p (float, optional): The cumulative probability threshold for nucleus sampling. - Tokens with cumulative probability up to `top_p` are considered for sampling. + Tokens with cumulative probability up to top_p are considered for sampling. This parameter helps balance diversity and coherence in the predicted DNA sequences. The value must be a float between 0 and 1. Defaults to 0.95. + num_sequences (int, optional): The number of DNA sequences to generate. Only applicable + when deterministic is False. Defaults to 1. Returns: - DNASequencePrediction: An object containing the prediction results: + Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects + containing the prediction results: - organism (str): Name of the organism used for prediction. - protein (str): Input protein sequence for which DNA sequence is predicted. - processed_input (str): Processed input sequence (merged protein and DNA). @@ -90,12 +94,13 @@ def predict_dna_sequence( Raises: ValueError: If the protein sequence is empty, if the organism is invalid, - if the temperature is not a positive float, or if `top_p` is not between 0 and 1. + if the temperature is not a positive float, if top_p is not between 0 and 1, + or if num_sequences is less than 1 or used with deterministic mode. Note: - This function uses `ORGANISM2ID` and `INDEX2TOKEN` dictionaries imported from - `CodonTransformer.CodonUtils`. `ORGANISM2ID` maps organism names to their - corresponding IDs. 
`INDEX2TOKEN` maps model output indices (token IDs) to + This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from + CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their + corresponding IDs. INDEX2TOKEN maps model output indices (token IDs) to respective codons. Example: @@ -116,7 +121,7 @@ def predict_dna_sequence( >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" >>> organism = "Escherichia coli general" >>> - >>> # Predict DNA sequence with deterministic decoding + >>> # Predict DNA sequence with deterministic decoding (single sequence) >>> output = predict_dna_sequence( ... protein=protein, ... organism=organism, @@ -127,7 +132,7 @@ def predict_dna_sequence( ... deterministic=True ... ) >>> - >>> # Predict DNA sequence with low randomness and top_p sampling + >>> # Predict multiple DNA sequences with low randomness and top_p sampling >>> output_random = predict_dna_sequence( ... protein=protein, ... organism=organism, @@ -137,23 +142,33 @@ def predict_dna_sequence( ... attention_type="original_full", ... deterministic=False, ... temperature=0.2, - ... top_p=0.95 + ... top_p=0.95, + ... num_sequences=3 ... ) >>> >>> print(format_model_output(output)) - >>> print(format_model_output(output_random)) + >>> for i, seq in enumerate(output_random, 1): + ... print(f"Sequence {i}:") + ... print(format_model_output(seq)) + ... print() """ if not protein: raise ValueError("Protein sequence cannot be empty.") - # Validate temperature if not isinstance(temperature, (float, int)) or temperature <= 0: raise ValueError("Temperature must be a positive float.") - # Validate top_p if not isinstance(top_p, (float, int)) or not 0 < top_p <= 1.0: raise ValueError("top_p must be a float between 0 and 1.") + if not isinstance(num_sequences, int) or num_sequences < 1: + raise ValueError("num_sequences must be a positive integer.") + + if deterministic and num_sequences > 1: + raise ValueError( + "Multiple sequences can only be generated in non-deterministic mode." 
+ ) + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -184,27 +199,31 @@ def predict_dna_sequence( output_dict = model(**tokenized_input, return_dict=True) logits = output_dict.logits.detach().cpu() - # Decode the predicted DNA sequence from the model output - if deterministic: - predicted_indices = logits.argmax(dim=-1).squeeze().tolist() - else: - predicted_indices = sample_non_deterministic( - logits=logits, temperature=temperature, top_p=top_p + predictions = [] + for _ in range(num_sequences): + # Decode the predicted DNA sequence from the model output + if deterministic: + predicted_indices = logits.argmax(dim=-1).squeeze().tolist() + else: + predicted_indices = sample_non_deterministic( + logits=logits, temperature=temperature, top_p=top_p + ) + + predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) + predicted_dna = ( + "".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper() ) - predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) - - # Skip special tokens [CLS] and [SEP] to create the predicted_dna - predicted_dna = ( - "".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper() - ) + predictions.append( + DNASequencePrediction( + organism=organism_name, + protein=protein, + processed_input=merged_seq, + predicted_dna=predicted_dna, + ) + ) - return DNASequencePrediction( - organism=organism_name, - protein=protein, - processed_input=merged_seq, - predicted_dna=predicted_dna, - ) + return predictions[0] if num_sequences == 1 else predictions def sample_non_deterministic( From c35d3c76ca9e4e8b8c701f2a9e96d8fab21b815e Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Sat, 21 Sep 2024 10:28:29 -0400 Subject: [PATCH 28/36] Test multiple sequence generation. 
--- tests/test_CodonPrediction.py | 76 +++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 079cbe0..310193e 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -14,6 +14,7 @@ AMINO_ACIDS, ORGANISM2ID, STOP_SYMBOLS, + DNASequencePrediction, ) @@ -416,6 +417,81 @@ def test_predict_dna_sequence_long_protein_over_max_length(self): "Translated protein does not match the original protein sequence up to the maximum length supported.", ) + def test_predict_dna_sequence_multi_output(self): + """Test that the function returns multiple sequences when num_sequences > 1.""" + protein_sequence = "MFQLLAPWY" + organism = "Escherichia coli general" + num_sequences = 20 + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + num_sequences=num_sequences, + ) + + self.assertIsInstance(result, list) + self.assertEqual(len(result), num_sequences) + + for prediction in result: + self.assertIsInstance(prediction, DNASequencePrediction) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in prediction.predicted_dna) + ) + + # Check that all predicted DNA sequences translate back to the original protein + translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3]) + self.assertEqual(translated_protein, protein_sequence) + + def test_predict_dna_sequence_deterministic_multi_raises_error(self): + """Test that requesting multiple sequences in deterministic mode raises an error.""" + protein_sequence = "MFWY" + organism = "Escherichia coli general" + + with self.assertRaises(ValueError): + predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=True, + num_sequences=3, + ) + + def test_predict_dna_sequence_multi_diversity(self): + """Test that multiple sequences generated are diverse.""" + protein_sequence = "MFWYMFWY" + organism = "Escherichia coli general" + num_sequences = 10 + + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + num_sequences=num_sequences, + temperature=0.8, + ) + + unique_sequences = set(prediction.predicted_dna for prediction in result) + + self.assertGreater( + len(unique_sequences), + 2, + "Multiple sequence generation should produce diverse results", + ) + + # Check that all sequences are valid translations of the input protein + for prediction in result: + translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3]) + self.assertEqual(translated_protein, protein_sequence) + if __name__ == "__main__": unittest.main() From 17343c70d762acccc4af18a2faa6e055233ebb1d Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Sat, 21 Sep 2024 15:11:22 -0400 Subject: [PATCH 29/36] Update README. --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2baaa62..598c7bb 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ The genetic code is degenerate allowing a multitude of possible DNA sequences to **For an interactive demo, check out our [Google Colab Notebook.](https://adibvafa.github.io/CodonTransformer/GoogleColab)**

After installing CodonTransformer, you can use: + ```python import torch from transformers import AutoTokenizer, BigBirdForMaskedLM @@ -156,7 +157,7 @@ To finetune CodonTransformer on your own data, follow these steps: --learning_rate 0.00005 \ --warmup_fraction 0.1 \ --save_every_n_steps 512 \ - --seed 123 + --seed 23 ``` This script automatically loads the pretrained model from Hugging Face and finetunes it on your dataset. For an example of a SLURM job request, see the `slurm` directory in the repository. @@ -271,7 +272,7 @@ This subpackage contains functions and classes that handle the core prediction f ### Available Functions and Classes -- `predict_dna_sequence(protein: str, organism: Union[int, str], device: torch.device, tokenizer: Union[str, PreTrainedTokenizerFast], model: Union[str, torch.nn.Module], attention_type: str = "original_full") -> DNASequencePrediction` +- `predict_dna_sequence(protein: str, organism: Union[int, str], device: torch.device, tokenizer: Union[str, PreTrainedTokenizerFast], model: Union[str, torch.nn.Module], attention_type: str = "original_full", deterministic: bool = True, temperature: float = 0.2, top_p: float = 0.95, num_sequences: int = 1) -> DNASequencePrediction` Predict the DNA sequence for a given protein using the CodonTransformer model. From 73cef8aff8d8e9a16890a490818d299ac7c3b1aa Mon Sep 17 00:00:00 2001 From: Adibvafa Date: Sat, 21 Sep 2024 15:19:12 -0400 Subject: [PATCH 30/36] Bump version to 1.6.4 --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7729c53..ef821e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "CodonTransformer" -version = "1.6.3" +version = "1.6.4" description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." authors = ["Adibvafa Fallahpour "] license = "Apache-2.0" diff --git a/setup.py b/setup.py index 03f4f4f..2c51966 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def read_readme(): setup( name="CodonTransformer", - version="1.6.3", + version="1.6.4", packages=find_packages(), install_requires=read_requirements(), author="Adibvafa Fallahpour", From 349cc6830924aa1ff3b84e419b54c07f012b2294 Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:17:48 -0400 Subject: [PATCH 31/36] Update README.md --- README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 598c7bb..5554bfe 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ output = predict_dna_sequence( tokenizer=tokenizer, model=model, attention_type="original_full", + deterministic=True ) print(format_model_output(output)) ``` @@ -86,13 +87,23 @@ M_UNK A_UNK L_UNK W_UNK M_UNK R_UNK L_UNK L_UNK P_UNK L_UNK L_UNK A_UNK L_UNK L_ ----------------------------- ATGGCTTTATGGATGCGTCTGCTGCCGCTGCTGGCGCTGCTGGCGCTGTGGGGCCCGGACCCGGCGGCGGCGTTTGTGAATCAGCACCTGTGCGGCAGCCACCTGGTGGAAGCGCTGTATCTGGTGTGCGGTGAGCGCGGCTTCTTCTACACGCCCAAAACCCGCCGCGAAGCGGAAGATCTGCAGGTGGGCCAGGTGGAGCTGGGCGGCTAA ``` + +### Generating Multiple Variable Sequences + +Set `deterministic=False` to generate variable sequences. 
Control the variability using `temperature`: + +- `temperature`: (recommended between 0.2 and 0.8) + - Lower values (e.g., 0.2): More conservative predictions + - Higher values (e.g., 0.8): More diverse predictions + +Using very high temperatures might result in prediction of DNA sequences that do not translate to the exact input protein.
+Generate multiple sequences by setting `num_sequences` to a value greater than 1.
**You can use the [inference template](https://github.com/Adibvafa/CodonTransformer/raw/main/src/CodonTransformer_inference_template.xlsx) for batch inference in [Google Colab](https://adibvafa.github.io/CodonTransformer/GoogleColab).**
- ## Installation Install CodonTransformer via pip: From 78a818bf7e34a35d206ec44ca90d4c17216a0e83 Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Sat, 21 Sep 2024 15:18:12 -0400 Subject: [PATCH 32/36] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5554bfe..e39da48 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Set `deterministic=False` to generate variable sequences. Control the variabilit - Lower values (e.g., 0.2): More conservative predictions - Higher values (e.g., 0.8): More diverse predictions -Using very high temperatures might result in prediction of DNA sequences that do not translate to the exact input protein.
+Using high temperatures might result in prediction of DNA sequences that do not translate to the input protein.
Generate multiple sequences by setting `num_sequences` to a value greater than 1.
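Patches 27, 31, and 32 together add non-deterministic, multi-sequence generation to `predict_dna_sequence` and document the `temperature`, `top_p`, and `num_sequences` parameters in the README. The snippet below is a minimal usage sketch of that combined API and is not part of any patch in this series; it assumes the tokenizer and model are published under the Hugging Face ID `adibvafa/CodonTransformer` (as in the README quick-start) and uses an illustrative protein sequence.

```python
# Minimal usage sketch (not part of the patch series): sampling several
# candidate DNA sequences for one protein with the new num_sequences parameter.
import torch
from transformers import AutoTokenizer, BigBirdForMaskedLM

from CodonTransformer.CodonPrediction import predict_dna_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed checkpoint ID, matching the README quick-start.
tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)

# deterministic=False is required whenever num_sequences > 1 (patch 27).
outputs = predict_dna_sequence(
    protein="MALWMRLLPLLALLALWGPDPAAA",  # illustrative protein
    organism="Escherichia coli general",
    device=device,
    tokenizer=tokenizer,
    model=model,
    deterministic=False,
    temperature=0.4,  # README recommends 0.2-0.8; higher means more diverse
    top_p=0.95,
    num_sequences=3,
)

# With num_sequences > 1, a list of DNASequencePrediction objects is returned.
for i, prediction in enumerate(outputs, start=1):
    print(f"Candidate {i}: {prediction.predicted_dna}")
```

Because `deterministic=True` combined with `num_sequences > 1` raises a `ValueError`, sampling mode has to be enabled explicitly whenever several candidates are requested.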
From b7a66957a029023dc56adcec39b661ca299115ae Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Tue, 24 Sep 2024 10:13:20 +0200 Subject: [PATCH 33/36] change ConfigManager to abc and add ProteinConfig for issue #5 --- CodonTransformer/CodonData.py | 4 +-- CodonTransformer/CodonUtils.py | 66 +++++++++++++++++++++------------- tests/test_CodonData.py | 4 +-- tests/test_CodonUtils.py | 6 ++-- 4 files changed, 49 insertions(+), 31 deletions(-) diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 7bee937..a40f6f6 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -26,7 +26,7 @@ STOP_CODONS, STOP_SYMBOL, STOP_SYMBOLS, - ConfigManager, + ProteinConfig, find_pattern_in_fasta, get_taxonomy_id, sort_amino2codon_skeleton, @@ -175,7 +175,7 @@ def preprocess_protein_sequence(protein: str) -> str: ) # Handle ambiguous amino acids based on the specified behavior - config = ConfigManager() + config = ProteinConfig() ambiguous_aminoacid_map_override = config.get('ambiguous_aminoacid_map_override') ambiguous_aminoacid_behavior = config.get('ambiguous_aminoacid_behavior') ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy() diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 34b6ea4..5a339d5 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -8,6 +8,7 @@ import os import pickle import re +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Tuple @@ -545,30 +546,12 @@ def __init__(self, data_path: str, train: bool = True, **kwargs): self.train = train -class ConfigManager: +class ConfigManager(ABC): """ - A class to manage configuration settings. - - This class ensures that the configuration is a singleton. - It provides methods to get, set, and update configuration values. - - Attributes: - _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager. - _config (Dict[str, Any]): The configuration dictionary. + Abstract base class for managing configuration settings. """ - _instance = None - - def __new__(cls): - """ - Create a new instance of the ConfigManager class. - - Returns: - ConfigManager: The singleton instance of the ConfigManager. - """ - if cls._instance is None: - cls._instance = super(ConfigManager, cls).__new__(cls) - cls._instance.reset_config() - return cls._instance + def __init__(self): + self._config: Dict[str, Any] = {} def __enter__(self): return self @@ -578,6 +561,11 @@ def __exit__(self, exc_type, exc_value, traceback): print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}") self.reset_config() + @abstractmethod + def reset_config(self) -> None: + """Reset the configuration to default values.""" + pass + def get(self, key: str) -> Any: """ Get the value of a configuration key. @@ -610,8 +598,37 @@ def update(self, config_dict: dict) -> None: """ for key, value in config_dict.items(): self.validate_inputs(key, value) - for key, value in config_dict.items(): - self.set(key, value) + self._config.update(config_dict) + + @abstractmethod + def validate_inputs(self, key: str, value: Any) -> None: + """Validate the inputs for the configuration.""" + pass + +class ProteinConfig(ConfigManager): + """ + A class to manage configuration settings for protein sequences. + + This class ensures that the configuration is a singleton. + It provides methods to get, set, and update configuration values. 
+ + Attributes: + _instance (Optional[ConfigManager]): The singleton instance of the ConfigManager. + _config (Dict[str, Any]): The configuration dictionary. + """ + _instance = None + + def __new__(cls): + """ + Create a new instance of the ProteinConfig class. + + Returns: + ProteinConfig: The singleton instance of the ProteinConfig. + """ + if cls._instance is None: + cls._instance = super(ProteinConfig, cls).__new__(cls) + cls._instance.reset_config() + return cls._instance def validate_inputs(self, key: str, value: Any) -> None: """ @@ -644,6 +661,7 @@ def validate_inputs(self, key: str, value: Any) -> None: raise ValueError(f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}") else: raise ValueError(f"Invalid configuration key: {key}") + def reset_config(self) -> None: """ Reset the configuration to the default values. diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py index 51efdd3..5d20704 100644 --- a/tests/test_CodonData.py +++ b/tests/test_CodonData.py @@ -11,11 +11,11 @@ read_fasta_file, preprocess_protein_sequence, ) -from CodonTransformer.CodonUtils import ConfigManager +from CodonTransformer.CodonUtils import ProteinConfig class TestCodonData(unittest.TestCase): def test_preprocess_protein_sequence(self): - with ConfigManager() as config: + with ProteinConfig() as config: protein = "Z_" try: preprocess_protein_sequence(protein) diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index 832f816..5780314 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -4,7 +4,7 @@ import unittest from CodonTransformer.CodonUtils import ( - ConfigManager, + ProteinConfig, find_pattern_in_fasta, get_organism2id_dict, get_taxonomy_id, @@ -17,7 +17,7 @@ class TestCodonUtils(unittest.TestCase): def test_config_manager(self): - with ConfigManager() as config: + with ProteinConfig() as config: config.set( "ambiguous_aminoacid_behavior", "standardize_deterministic" @@ -51,7 +51,7 @@ def test_config_manager(self): self.fail("Expected ValueError") except ValueError: pass - with ConfigManager() as config: + with ProteinConfig() as config: self.assertEqual( config.get("ambiguous_aminoacid_behavior"), "raise_error" From 0fcba6aa4ac08e5e5ed5684e4878590dca621279 Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Wed, 25 Sep 2024 12:14:21 +0200 Subject: [PATCH 34/36] change default behavior of ProteinConfig --- CodonTransformer/CodonUtils.py | 2 +- tests/test_CodonData.py | 1 + tests/test_CodonUtils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 5a339d5..4c7a071 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -667,7 +667,7 @@ def reset_config(self) -> None: Reset the configuration to the default values. 
""" self._config = { - 'ambiguous_aminoacid_behavior': 'raise_error', + 'ambiguous_aminoacid_behavior': 'standardize_random', 'ambiguous_aminoacid_map_override': {} } diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py index 5d20704..1c718bb 100644 --- a/tests/test_CodonData.py +++ b/tests/test_CodonData.py @@ -16,6 +16,7 @@ class TestCodonData(unittest.TestCase): def test_preprocess_protein_sequence(self): with ProteinConfig() as config: + config.set("ambiguous_aminoacid_behavior", "raise_error") protein = "Z_" try: preprocess_protein_sequence(protein) diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index 5780314..92dbcea 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -54,7 +54,7 @@ def test_config_manager(self): with ProteinConfig() as config: self.assertEqual( config.get("ambiguous_aminoacid_behavior"), - "raise_error" + "standardize_random" ) self.assertEqual( config.get("ambiguous_aminoacid_map_override"), From f4fe97ac9b0afc95c76c70306cb5f0791d79d864 Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Thu, 26 Sep 2024 12:37:56 +0200 Subject: [PATCH 35/36] fix init behavior of ConfigManager --- CodonTransformer/CodonUtils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 4c7a071..349a26d 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -550,9 +550,6 @@ class ConfigManager(ABC): """ Abstract base class for managing configuration settings. """ - def __init__(self): - self._config: Dict[str, Any] = {} - def __enter__(self): return self From d7091191993341e514f8233dd95331f349705589 Mon Sep 17 00:00:00 2001 From: derpbuffalo Date: Thu, 26 Sep 2024 12:38:28 +0200 Subject: [PATCH 36/36] fix testcase for ConfigManager --- tests/test_CodonUtils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index 92dbcea..128b638 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -28,11 +28,11 @@ def test_config_manager(self): ) config.set( "ambiguous_aminoacid_map_override", - {"R": ["A", "G"]} + {"X": ["A", "G"]} ) self.assertEqual( config.get("ambiguous_aminoacid_map_override"), - {"R": ["A", "G"]} + {"X": ["A", "G"]} ) config.update({ "ambiguous_aminoacid_behavior": "raise_error",