diff --git a/.gitignore b/.gitignore index 0a03dca..46f2a2f 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,4 @@ lightning_logs/ # PyTorch model weights *.pth -*.pt \ No newline at end of file +*.pt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e0c0e74 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +# Don't run pre-commit on files under third-party/ +exclude: "^\ + (third-party/.*)\ + " + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-added-large-files # prevents giant files from being committed. + - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems. + - id: check-merge-conflict # checks for files that contain merge conflict strings. + - id: check-yaml # checks yaml files for parseable syntax. + - id: detect-private-key # detects the presence of private keys. + - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline. + - id: fix-byte-order-marker # removes utf-8 byte order marker. + - id: mixed-line-ending # replaces or checks mixed line ending. + - id: requirements-txt-fixer # sorts entries in requirements.txt. + - id: trailing-whitespace # trims trailing whitespace. + + - repo: https://github.com/sirosen/check-jsonschema + rev: 0.23.2 + hooks: + - id: check-github-actions + - id: check-github-workflows + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.13 + hooks: + - id: ruff + - id: ruff-format diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 5932ff8..b6f6d86 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -5,32 +5,30 @@ preparing the data for training and inference of the CodonTransformer model. """ -import os import json +import os +from typing import Dict, List, Optional, Tuple, Union + import pandas as pd +import python_codon_tables as pct +from Bio import SeqIO +from Bio.Seq import Seq from sklearn.utils import shuffle as sk_shuffle +from tqdm import tqdm from CodonTransformer.CodonUtils import ( + AMBIGUOUS_AMINOACID_MAP, + AMINO2CODON_TYPE, AMINO_ACIDS, + ORGANISM2ID, START_CODONS, STOP_CODONS, STOP_SYMBOL, - AMINO2CODON_TYPE, - AMBIGUOUS_AMINOACID_MAP, - ORGANISM2ID, find_pattern_in_fasta, - sort_amino2codon_skeleton, get_taxonomy_id, + sort_amino2codon_skeleton, ) -from Bio import SeqIO -from Bio.Seq import Seq - -import python_codon_tables as pct - -from typing import List, Dict, Tuple, Union, Optional -from tqdm import tqdm - def prepare_training_data( dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True @@ -50,7 +48,8 @@ def prepare_training_data( Args: dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format. output_file (str): Path to save the output JSON dataset. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -78,7 +77,7 @@ def prepare_training_data( def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None: """ - Convert a pandas DataFrame to a JSON file format suitable for training CodonTransformer. + Convert pandas DataFrame to JSON file format suitable for training CodonTransformer. This function takes a preprocessed DataFrame and writes it to a JSON file where each line is a JSON object representing a single record. @@ -86,7 +85,8 @@ def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) Args: df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns. output_file (str): Path to the output JSON file. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -123,8 +123,9 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) It validates the input against a provided mapping of organism names to IDs. Args: - organism (Union[str, int]): The input organism, either as a name (str) or ID (int). - organism_to_id (Dict[str, int]): A dictionary mapping organism names to their corresponding IDs. + organism (Union[str, int]): Input organism, either as a name (str) or ID (int). + organism_to_id (Dict[str, int]): Dictionary mapping organism names to their + corresponding IDs. Returns: int: The validated organism ID. @@ -150,7 +151,8 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) def preprocess_protein_sequence(protein: str) -> str: """ - Preprocess a protein sequence by cleaning, standardizing, and handling ambiguous amino acids. + Preprocess a protein sequence by cleaning, standardizing, and handling + ambiguous amino acids. Args: protein (str): The input protein sequence. @@ -221,7 +223,8 @@ def replace_ambiguous_codons(dna: str) -> str: def preprocess_dna_sequence(dna: str) -> str: """ - Cleans and preprocesses a DNA sequence by standardizing it and replacing ambiguous codons. + Cleans and preprocesses a DNA sequence by standardizing it and replacing + ambiguous codons. Args: dna (str): The DNA sequence to preprocess. @@ -247,8 +250,9 @@ def preprocess_dna_sequence(dna: str) -> str: def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: """ - Return the merged sequence of protein amino acids and DNA codons in the form of tokens - separated by space, where each token is composed of an amino acid + separator + codon. + Return the merged sequence of protein amino acids and DNA codons in the form + of tokens separated by space, where each token is composed of an amino acid + + separator + codon. Args: protein (str): Protein sequence. @@ -274,8 +278,9 @@ def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: # Check if the length of protein and dna sequences are equal if len(dna) > 0 and len(protein) != len(dna) / 3: raise ValueError( - 'Length of protein (including stop symbol such as "_") and \ - the number of codons in DNA sequence (including stop codon) must be equal.' + 'Length of protein (including stop symbol such as "_") and ' + "the number of codons in DNA sequence (including stop codon) " + "must be equal." ) # Merge protein and DNA sequences into tokens @@ -331,8 +336,8 @@ def get_amino_acid_sequence( return_correct_seq (bool): Whether to return if the sequence is correct. Returns: - Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if return_correct_seq is True, - otherwise just the protein sequence. + Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if + return_correct_seq is True, otherwise just the protein sequence. """ dna_seq = Seq(dna).strip() @@ -365,12 +370,15 @@ def read_fasta_file( Args: input_file (str): Path to the input FASTA file. - save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned. - organism (str): Name of the organism. If empty, it will be extracted from the FASTA description. + save_to_file (Optional[str]): Path to save the output DataFrame. If None, + data is only returned. + organism (str): Name of the organism. If empty, it will be extracted from + the FASTA description. buffer_size (int): Number of records to process before writing to file. Returns: - pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None. + pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe + is True, else None. Raises: FileNotFoundError: If the input file does not exist. @@ -498,7 +506,8 @@ def download_codon_frequencies_from_kazusa( def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: """ - Return the empty skeleton of the amino2codon dictionary, needed for get_codon_frequencies. + Return the empty skeleton of the amino2codon dictionary, needed for + get_codon_frequencies. Args: organism (str): Name of the organism. @@ -514,7 +523,8 @@ def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: return_correct_seq=False, ) - # Initialize the amino2codon skeleton with all possible codons and set their frequencies to 0 + # Initialize the amino2codon skeleton with all possible codons and set their + # frequencies to 0 for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)): if amino not in amino2codon: amino2codon[amino] = ([], []) @@ -543,7 +553,8 @@ def get_codon_frequencies( organism (Optional[str]): Name of the organism. Returns: - AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons and frequencies. + AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons + and frequencies. """ if organism: codon_table = get_codon_table(organism) @@ -583,7 +594,8 @@ def get_organism_to_codon_frequencies( organisms (List[str]): List of organisms. Returns: - Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon frequency distribution. + Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon + frequency distribution. """ organism2frequencies = {} @@ -617,7 +629,8 @@ def get_codon_table(organism: str) -> int: "Arabidopsis thaliana", "Caenorhabditis elegans", "Chlamydomonas reinhardtii", - "Saccharomyces cerevisiae" "Danio rerio", + "Saccharomyces cerevisiae", + "Danio rerio", "Drosophila melanogaster", "Homo sapiens", "Mus musculus", diff --git a/CodonTransformer/CodonEvaluation.py b/CodonTransformer/CodonEvaluation.py index e94423a..d09ba92 100644 --- a/CodonTransformer/CodonEvaluation.py +++ b/CodonTransformer/CodonEvaluation.py @@ -1,18 +1,17 @@ """ File: CodonEvaluation.py --------------------------- -Includes functions to calculate various evaluation metrics along with helper functions. +Includes functions to calculate various evaluation metrics along with helper +functions. """ -import pandas as pd +from typing import Dict, List, Tuple +import pandas as pd from CAI import CAI, relative_adaptiveness - -from typing import List, Dict, Tuple from tqdm import tqdm - def get_CSI_weights(sequences: List[str]) -> Dict[str, float]: """ Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences. @@ -47,7 +46,7 @@ def get_organism_to_CSI_weights( Calculate the Codon Similarity Index (CSI) weights for a list of organisms. Args: - dataset (pd.DataFrame): The dataset containing organism and DNA sequence information. + dataset (pd.DataFrame): Dataset containing organism and DNA sequence info. organisms (List[str]): List of organism names. Returns: @@ -91,7 +90,8 @@ def get_cfd( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. threshold (float): Frequency threshold for counting rare codons. Returns: @@ -127,7 +127,8 @@ def get_min_max_percentage( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. window_size (int): Size of the window to calculate %MinMax. Returns: @@ -147,14 +148,12 @@ def get_min_max_percentage( # Iterate through the DNA sequence using the specified window size for i in range(len(codons) - window_size + 1): - codon_window = codons[ - i : i + window_size - ] # List of the codons in the current window + codon_window = codons[i : i + window_size] # Codons in the current window Actual = 0.0 # Average of the actual codon frequencies Max = 0.0 # Average of the min codon frequencies Min = 0.0 # Average of the max codon frequencies - Avg = 0.0 # Average of the averages of all the frequencies associated with each amino acid + Avg = 0.0 # Average of the averages of all frequencies for each amino acid # Sum the frequencies for codons in the current window for codon in codon_window: @@ -210,7 +209,7 @@ def sum_up_to(x): return x + sum_up_to(x - 1) def f(x): - """Function that returns 4 if x is greater than or equal to 4, else returns x.""" + """Returns 4 if x is greater than or equal to 4, else returns x.""" if x >= 4: return 4 elif x < 4: @@ -242,8 +241,10 @@ def get_sequence_similarity( Args: original (str): The original sequence. predicted (str): The predicted sequence. - truncate (bool): If True, truncate the original sequence to match the length of the predicted sequence. - window_length (int): Length of the window for comparison (1 for amino acids, 3 for codons). + truncate (bool): If True, truncate the original sequence to match the length + of the predicted sequence. + window_length (int): Length of the window for comparison (1 for amino acids, + 3 for codons). Returns: float: The sequence similarity as a percentage. diff --git a/CodonTransformer/CodonJupyter.py b/CodonTransformer/CodonJupyter.py index 6c730cd..fdf7164 100644 --- a/CodonTransformer/CodonJupyter.py +++ b/CodonTransformer/CodonJupyter.py @@ -7,13 +7,13 @@ from typing import Dict, List, Tuple import ipywidgets as widgets -from IPython.display import display, HTML +from IPython.display import HTML, display from CodonTransformer.CodonUtils import ( - DNASequencePrediction, COMMON_ORGANISMS, - ORGANISM2ID, ID2ORGANISM, + ORGANISM2ID, + DNASequencePrediction, ) @@ -49,11 +49,11 @@ def create_styled_options( organism_id = organism2id[organism] if is_fine_tuned: if organism_id < 10: - styled_options.append(f"\u200B{organism_id:>6}. {organism}") + styled_options.append(f"\u200b{organism_id:>6}. {organism}") elif organism_id < 100: - styled_options.append(f"\u200B{organism_id:>5}. {organism}") + styled_options.append(f"\u200b{organism_id:>5}. {organism}") else: - styled_options.append(f"\u200B{organism_id:>4}. {organism}") + styled_options.append(f"\u200b{organism_id:>4}. {organism}") else: if organism_id < 10: styled_options.append(f"{organism_id:>6}. {organism}") @@ -160,7 +160,7 @@ def get_dropdown_style() -> str: flex-direction: column; align-items: flex-start; } - .widget-dropdown option[value^="\u200B"] { + .widget-dropdown option[value^="\u200b"] { font-family: sans-serif; font-weight: bold; font-size: 18px; @@ -188,7 +188,8 @@ def display_organism_dropdown(container: UserContainer) -> None: """ dropdown = create_organism_dropdown(container) header = widgets.HTML( - 'Select Organism:
' + 'Select Organism:' + '
' ) container_widget = widgets.VBox( [header, dropdown], @@ -242,7 +243,8 @@ def save_protein(change: Dict[str, str]) -> None: Save the input protein sequence to the container. Args: - change (Dict[str, str]): A dictionary containing information about the change in textarea value. + change (Dict[str, str]): A dictionary containing information about + the change in textarea value. """ container.protein = ( change["new"] @@ -258,7 +260,8 @@ def save_protein(change: Dict[str, str]) -> None: # Display the input widget header = widgets.HTML( - 'Enter Protein Sequence:
' + 'Enter Protein Sequence:' + '
' ) container_widget = widgets.VBox( [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px") @@ -270,13 +273,13 @@ def save_protein(change: Dict[str, str]) -> None: def format_model_output(output: DNASequencePrediction) -> str: """ - Format the DNA sequence prediction output in a visually appealing and easy-to-read manner. + Format DNA sequence prediction output in an appealing and easy-to-read manner. This function takes the prediction output and formats it into a structured string with clear section headers and separators. Args: - output (DNASequencePrediction): An object containing the prediction output. + output (DNASequencePrediction): Object containing the prediction output. Expected attributes: - organism (str): The organism name. - protein (str): The input protein sequence. diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 7df8520..a4d5434 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -5,27 +5,28 @@ helper functions related to processing data for passing to the model. """ -from typing import Any, List, Dict, Tuple, Optional, Union -import onnxruntime as rt +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np +import onnxruntime as rt import torch import transformers from transformers import ( + AutoTokenizer, BatchEncoding, - PreTrainedTokenizerFast, BigBirdConfig, - AutoTokenizer, BigBirdForMaskedLM, + PreTrainedTokenizerFast, ) -import numpy as np from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( AMINO_ACIDS, - ORGANISM2ID, - TOKEN2INDEX, INDEX2TOKEN, NUM_ORGANISMS, + ORGANISM2ID, + TOKEN2INDEX, DNASequencePrediction, ) @@ -37,6 +38,7 @@ def predict_dna_sequence( tokenizer: Union[str, PreTrainedTokenizerFast] = None, model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", + deterministic: bool = True, ) -> DNASequencePrediction: """ Predict the DNA sequence for a given protein using CodonTransformer model. @@ -47,30 +49,38 @@ def predict_dna_sequence( Args: protein (str): The input protein sequence to predict the DNA sequence for. - organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). - If a string is provided, it will be converted to the corresponding ID using ORGANISM2ID. + organism (Union[int, str]): Either the ID of the organism or its name (e.g., + "Escherichia coli general"). If a string is provided, it will be converted + to the corresponding ID using ORGANISM2ID. device (torch.device): The device (CPU or GPU) to run the model on. - tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, - a pre-loaded tokenizer object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - model (Union[str, torch.nn.Module, None], optional): Either a file path to load the model from, - a pre-loaded model object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - attention_type (str, optional): The type of attention mechanism to use in the model. - Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". + tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file + path to load the tokenizer from, a pre-loaded tokenizer object, or None. If + None, it will be loaded from HuggingFace. Defaults to None. + model (Union[str, torch.nn.Module, None], optional): Either a file path to load + the model from, a pre-loaded model object, or None. If None, it will be + loaded from HuggingFace. Defaults to None. + attention_type (str, optional): The type of attention mechanism to use in the + model. Can be either 'block_sparse' or 'original_full'. Defaults to + "original_full". + deterministic (bool, optional): Whether to use deterministic decoding (most + likely tokens). If False, samples tokens according to their probabilities. + Defaults to True. Returns: DNASequencePrediction: An object containing the prediction results: - - organism (str): The name of the organism used for prediction. - - protein (str): The input protein sequence for which DNA sequence is predicted. - - processed_input (str): The processed input sequence (merged protein and DNA). - - predicted_dna (str): The predicted DNA sequence. + - organism (str): Name of the organism used for prediction. + - protein (str): Input protein sequence for which DNA sequence is predicted. + - processed_input (str): Processed input sequence (merged protein and DNA). + - predicted_dna (str): Predicted DNA sequence. Raises: ValueError: If the protein sequence is empty or if the organism is invalid. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from CodonTransformer.CodonUtils. - ORGANISM2ID maps organism names to their corresponding IDs. - INDEX2TOKEN maps model output indices (token ids) to respective codons. + This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from + CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their + corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to + respective codons. Example: >>> import torch @@ -83,30 +93,41 @@ def predict_dna_sequence( >>> >>> # Load tokenizer and model >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") - >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device) + >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") + >>> model = model.to(device) >>> >>> # Define protein sequence and organism - >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" + >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" >>> organism = "Escherichia coli general" >>> - >>> # Predict DNA sequence + >>> # Predict DNA sequence with deterministic decoding >>> output = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, - ... attention_type="original_full" + ... attention_type="original_full", + ... deterministic=True + ... ) + >>> + >>> # Predict DNA sequence with probabilistic sampling + >>> output_sampler = predict_dna_sequence( + ... protein=protein, + ... organism=organism, + ... device=device, + ... tokenizer=tokenizer, + ... model=model, + ... attention_type="original_full", + ... deterministic=False ... ) >>> >>> print(format_model_output(output)) + >>> print(format_model_output(output_sampler)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") - if not isinstance(protein, str): - raise ValueError("Protein sequence must be a string.") - # Test that the input protein sequence contains only valid amino acids if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") @@ -139,12 +160,23 @@ def predict_dna_sequence( # Get the model predictions output_dict = model(**tokenized_input, return_dict=True) - output = output_dict.logits.detach().cpu().numpy() + logits = output_dict.logits.detach().cpu() # Decode the predicted DNA sequence from the model output - predicted_dna = list( - map(INDEX2TOKEN.__getitem__, output.argmax(axis=-1).squeeze().tolist()) - ) + if deterministic: + # Select the most probable tokens (argmax) + predicted_indices = logits.argmax(dim=-1).squeeze().tolist() + else: + # Sample tokens according to their probability distribution + # Convert logits to probabilities using softmax + probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position + probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] + predicted_indices = ( + torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist() + ) + + predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) # Skip special tokens [CLS] and [SEP] to create the predicted_dna predicted_dna = ( @@ -170,17 +202,21 @@ def load_model( Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace. Args: - model_path (Optional[str]): Path to the model file or checkpoint. If None, load from HuggingFace. + model_path (Optional[str]): Path to the model file or checkpoint. If None, + load from HuggingFace. device (torch.device, optional): The device to load the model onto. - attention_type (str, optional): The type of attention, 'block_sparse' or 'original_full'. - num_organisms (int, optional): Number of organisms, needed if loading from a checkpoint that requires this. - remove_prefix (bool, optional): Whether to remove the "model." prefix from the keys in the state dict. + attention_type (str, optional): The type of attention, 'block_sparse' + or 'original_full'. + num_organisms (int, optional): Number of organisms, needed if loading from a + checkpoint that requires this. + remove_prefix (bool, optional): Whether to remove the "model." prefix from the + keys in the state dict. Returns: torch.nn.Module: The loaded model. """ if not model_path: - print("Warning: Model path not provided. Loading from HuggingFace.") + warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning) model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") elif model_path.endswith(".ckpt"): @@ -209,7 +245,8 @@ def load_model( else: raise ValueError( - "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace." + "Unsupported file type. Please provide a .ckpt or .pt file, " + "or None to load from HuggingFace." ) # Prepare model for evaluation @@ -260,16 +297,19 @@ def create_model_from_checkpoint( def load_tokenizer(tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast: """ - Create and return a tokenizer object from the given tokenizer path or HuggingFace. + Create and return a tokenizer object from tokenizer path or HuggingFace. Args: - tokenizer_path (Optional[str]): Path to the tokenizer file. If None, load from HuggingFace. + tokenizer_path (Optional[str]): Path to the tokenizer file. If None, + load from HuggingFace. Returns: PreTrainedTokenizerFast: The tokenizer object. """ if not tokenizer_path: - print("Warning: Tokenizer path not provided. Loading from HuggingFace.") + warnings.warn( + "Tokenizer path not provided. Loading from HuggingFace.", UserWarning + ) return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") return transformers.PreTrainedTokenizerFast( @@ -291,11 +331,14 @@ def tokenize( ) -> BatchEncoding: """ Return the tokenized sequences given a batch of input data. - Each data in the batch is expected to be a dictionary with "codons" and "organism" keys. + Each data in the batch is expected to be a dictionary with "codons" and + "organism" keys. Args: - batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and "organism" keys. - tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or path to the tokenizer file. + batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and + "organism" keys. + tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or + path to the tokenizer file. max_len (int, optional): Maximum length of the tokenized sequence. Returns: @@ -326,28 +369,32 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: """ Validate and convert the organism input to both ID and name. - This function takes either an organism ID or name as input and returns both the ID and name. - It performs validation to ensure the input corresponds to a valid organism in the ORGANISM2ID dictionary. + This function takes either an organism ID or name as input and returns both + the ID and name. It performs validation to ensure the input corresponds to + a valid organism in the ORGANISM2ID dictionary. Args: - organism (Union[int, str]): Either the ID of the organism (int) or its name (str). + organism (Union[int, str]): Either the ID of the organism (int) or its + name (str). Returns: Tuple[int, str]: A tuple containing the organism ID (int) and name (str). Raises: - ValueError: If the input is neither a string nor an integer, if the organism name is not found - in ORGANISM2ID, if the organism ID is not a value in ORGANISM2ID, or if no name - is found for a given ID. + ValueError: If the input is neither a string nor an integer, if the + organism name is not found in ORGANISM2ID, if the organism ID is not a + value in ORGANISM2ID, or if no name is found for a given ID. Note: - This function relies on the ORGANISM2ID dictionary imported from CodonTransformer.CodonUtils, - which maps organism names to their corresponding IDs. + This function relies on the ORGANISM2ID dictionary imported from + CodonTransformer.CodonUtils, which maps organism names to their + corresponding IDs. """ if isinstance(organism, str): if organism not in ORGANISM2ID: raise ValueError( - f"Invalid organism name: {organism}. Please use a valid organism name or ID." + f"Invalid organism name: {organism}. " + "Please use a valid organism name or ID." ) organism_id = ORGANISM2ID[organism] organism_name = organism @@ -355,7 +402,8 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: elif isinstance(organism, int): if organism not in ORGANISM2ID.values(): raise ValueError( - f"Invalid organism ID: {organism}. Please use a valid organism name or ID." + f"Invalid organism ID: {organism}. " + "Please use a valid organism name or ID." ) organism_id = organism @@ -365,9 +413,6 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: if organism_name is None: raise ValueError(f"No organism name found for ID: {organism}") - else: - raise ValueError("Organism must be either a string (name) or an integer (ID).") - return organism_id, organism_name @@ -375,12 +420,13 @@ def get_high_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using High Frequency Choice (HFC) approach in which - the most frequent codon for a given amino acid is always chosen. + Return the DNA sequence optimized using High Frequency Choice (HFC) approach + in which the most frequent codon for a given amino acid is always chosen. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -400,7 +446,8 @@ def precompute_most_frequent_codons( Precompute the most frequent codon for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, str]: The most frequent codon for each amino acid. @@ -421,7 +468,8 @@ def get_high_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -436,17 +484,20 @@ def get_background_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Background Frequency Choice (BFC) approach in which - a random codon for a given amino acid is chosen using the codon frequencies probability distribution. + Return the DNA sequence optimized using Background Frequency Choice (BFC) + approach in which a random codon for a given amino acid is chosen using + the codon frequencies probability distribution. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. """ - # Select a random codon for each amino acid based on the codon frequencies probability distribution + # Select a random codon for each amino acid based on the codon frequencies + # probability distribution dna_codons = [ np.random.choice( codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1] @@ -463,7 +514,8 @@ def precompute_cdf( Precompute the cumulative distribution function (CDF) for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid. @@ -486,7 +538,8 @@ def get_background_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -507,12 +560,14 @@ def get_uniform_random_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Uniform Random Choice (URC) approach in which - a random codon for a given amino acid is chosen using a uniform prior. + Return the DNA sequence optimized using Uniform Random Choice (URC) approach + in which a random codon for a given amino acid is chosen using a uniform + prior. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -529,7 +584,8 @@ def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> st Return the optimized codon sequence for the given protein sequence using ICOR. Credit: ICOR: improving codon optimization with recurrent neural networks - Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas Densmore + Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas + Densmore Args: input_seq (str): The input protein sequence. diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index cf84b24..dd91e3f 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -4,18 +4,16 @@ Includes constants and helper functions used by other Python scripts. """ +import itertools import os -import re import pickle -import requests -import itertools - -import torch -import pandas as pd +import re from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple -from typing import Any, List, Dict, Tuple, Optional, Iterator - +import pandas as pd +import requests +import torch # List of all amino acids AMINO_ACIDS: List[str] = [ @@ -470,10 +468,10 @@ class DNASequencePrediction: A class to hold the output of the DNA sequence prediction. Attributes: - organism (str): The name of the organism used for prediction. - protein (str): The input protein sequence for which DNA sequence is predicted. - processed_input (str): The processed input sequence (merged protein and DNA). - predicted_dna (str): The predicted DNA sequence. + organism (str): Name of the organism used for prediction. + protein (str): Input protein sequence for which DNA sequence is predicted. + processed_input (str): Processed input sequence (merged protein and DNA). + predicted_dna (str): Predicted DNA sequence. """ organism: str @@ -488,7 +486,8 @@ class IterableData(torch.utils.data.IterableDataset): data) in parallel multi-processing environments, e.g., multi-GPU. Args: - dist_env (Optional[str]): The distribution environment identifier (e.g., "slurm"). + dist_env (Optional[str]): The distribution environment identifier + (e.g., "slurm"). Credit: Guillaume Filion """ @@ -501,7 +500,7 @@ def __init__(self, dist_env: Optional[str] = None): @property def iterator(self) -> Iterator: - """Define the stream logic for the dataset. Should be implemented in subclasses.""" + """Define the stream logic for the dataset. Implement in subclasses.""" raise NotImplementedError def __iter__(self) -> Iterator: @@ -573,7 +572,8 @@ def save_python_object_to_disk(input_object: Any, file_path: str) -> None: def find_pattern_in_fasta(keyword: str, text: str) -> str: """ - Find a specific keyword pattern in text. Helpful for identifying parts of a FASTA sequence. + Find a specific keyword pattern in text. Helpful for identifying parts + of a FASTA sequence. Args: keyword (str): The keyword pattern to search for. @@ -589,17 +589,19 @@ def find_pattern_in_fasta(keyword: str, text: str) -> str: def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: """ - Return a dictionary mapping each organism in training data to an index used for training. + Return a dictionary mapping each organism in training data to an index + used for training. Args: - organism_reference (str): Path to a CSV file containing a list of all organisms. - The format of the CSV file should be as follows: - 0,Escherichia coli - 1,Homo sapiens - 2,Mus musculus + organism_reference (str): Path to a CSV file containing a list of + all organisms. The format of the CSV file should be as follows: + + 0,Escherichia coli + 1,Homo sapiens + 2,Mus musculus Returns: - Dict[str, int]: A dictionary mapping organism names to their respective indices. + Dict[str, int]: Dictionary mapping organism names to their respective indices. """ # Read the CSV file and create a dictionary mapping organisms to their indices organisms = pd.read_csv(organism_reference, index_col=0, header=None) diff --git a/README.md b/README.md index e1b4304..6b0b7a3 100644 --- a/README.md +++ b/README.md @@ -115,14 +115,14 @@ The package requires `python>=3.9`. The requirements are [availabe here](require To finetune CodonTransformer on your own data, follow these steps: 1. **Prepare your dataset** - + Create a CSV file with the following columns: - `dna`: DNA sequences (string, preferably uppercase ATCG) - `protein`: Protein sequences (string, preferably uppercase amino acid letters) - `organism`: Target organism (string or int, must be from `ORGANISM2ID` in `CodonUtils`) - Note: + Note: - Use organisms from the `FINE_TUNE_ORGANISMS` list for best results. - For E. coli, use `Escherichia coli general`. - DNA sequences should ideally contain only A, T, C, and G. Ambiguous codons are replaced with 'UNK' for tokenization. @@ -132,7 +132,7 @@ To finetune CodonTransformer on your own data, follow these steps:
2. **Prepare training data** - + Use the `prepare_training_data` function from `CodonData` to prepare training data from your dataset. ```python @@ -142,7 +142,7 @@ To finetune CodonTransformer on your own data, follow these steps:
3. **Run the finetuning script** - + Execute finetune.py with appropriate arguments: ```bash python finetune.py \ diff --git a/finetune.py b/finetune.py index dfc9dc3..fce46f3 100644 --- a/finetune.py +++ b/finetune.py @@ -4,22 +4,22 @@ Finetune the CodonTransformer model. The pretrained model is loaded directly from Hugging Face. -The dataset is a JSON file. You can use prepare_training_data from CodonData to prepare the dataset. -The repository Readme has a guide on how to prepare the dataset and use this script. +The dataset is a JSON file. You can use prepare_training_data from CodonData to +prepare the dataset. The repository README has a guide on how to prepare the +dataset and use this script. """ -import os -import torch import argparse +import os import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader - from transformers import AutoTokenizer, BigBirdForMaskedLM from CodonTransformer.CodonUtils import ( - TOKEN2MASK, MAX_LEN, + TOKEN2MASK, IterableJSONData, ) diff --git a/pretrain.py b/pretrain.py index bafd738..1b253cc 100644 --- a/pretrain.py +++ b/pretrain.py @@ -3,23 +3,23 @@ ------------------- Pretrain the CodonTransformer model. -The dataset is a JSON file. You can use prepare_training_data from CodonData to prepare the dataset. -The repository Readme has a guide on how to prepare the dataset and use this script. +The dataset is a JSON file. You can use prepare_training_data from CodonData to +prepare the dataset. The repository README has a guide on how to prepare the +dataset and use this script. """ -import os -import torch import argparse +import os import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader - -from transformers import PreTrainedTokenizerFast, BigBirdConfig, BigBirdForMaskedLM +from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast from CodonTransformer.CodonUtils import ( - TOKEN2MASK, - NUM_ORGANISMS, MAX_LEN, + NUM_ORGANISMS, + TOKEN2MASK, IterableJSONData, ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e74248e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[tool.poetry] +name = "CodonTransformer" +version = "1.5.2" +description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." +authors = ["Adibvafa Fallahpour "] +license = "Apache-2.0" +readme = "README.md" +homepage = "https://github.com/adibvafa/CodonTransformer" +repository = "https://github.com/adibvafa/CodonTransformer" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] + +[tool.poetry.dependencies] +python = "^3.9" +biopython = "^1.83" +ipywidgets = "^7.0.0" +numpy = "^1.26.4" +onnxruntime = "^1.17.3" +pandas = "^2.0.0" +python_codon_tables = "^0.1.12" +pytorch_lightning = "^2.2.1" +scikit-learn = "^1.2.2" +scipy = "^1.13.1" +setuptools = "^70.0.0" +torch = "^2.0.0" +tqdm = "^4.66.2" +transformers = "^4.40.0" +CAI-PyPI = "^2.0.1" + +[tool.poetry.dev-dependencies] +coverage = "^7.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 88 +indent-width = 4 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/requirements.txt b/requirements.txt index b02f5d5..61b3643 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ biopython>=1.83,<2.0 +CAI-PyPI>=2.0.1,<3.0 ipywidgets>=7.0.0,<10.0 numpy>=1.26.4,<3.0 onnxruntime>=1.17.3,<3.0 @@ -11,4 +12,3 @@ setuptools>=70.0.0 torch>=2.0.0,<3.0 tqdm>=4.66.2,<5.0 transformers>=4.40.0,<5.0 -CAI-PyPI>=2.0.1,<3.0 \ No newline at end of file diff --git a/setup.py b/setup.py index e4e9648..6ca0e55 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # setup.py import os -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def read_requirements(): @@ -23,7 +24,11 @@ def read_readme(): install_requires=read_requirements(), author="Adibvafa Fallahpour", author_email="Adibvafa.fallahpour@mail.utoronto.ca", - description="The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms.", + description=( + "The ultimate tool for codon optimization, " + "transforming protein sequences into optimized DNA sequences " + "specific for your target organisms." + ), long_description=read_readme(), long_description_content_type="text/markdown", url="https://github.com/adibvafa/CodonTransformer", diff --git a/slurm/finetune.sh b/slurm/finetune.sh index dbaffba..980be12 100644 --- a/slurm/finetune.sh +++ b/slurm/finetune.sh @@ -35,4 +35,4 @@ stdbuf -oL -eL srun python finetune.py \ --learning_rate 0.00005 \ --warmup_fraction 0.1 \ --save_every_n_steps 512 \ - --seed 123 \ No newline at end of file + --seed 123 diff --git a/slurm/pretrain.sh b/slurm/pretrain.sh index 0935499..50d31b2 100644 --- a/slurm/pretrain.sh +++ b/slurm/pretrain.sh @@ -34,4 +34,4 @@ stdbuf -oL -eL srun python pretrain.py \ --learning_rate 0.00005 \ --warmup_fraction 0.1 \ --save_interval 5 \ - --seed 123 \ No newline at end of file + --seed 123 diff --git a/src/CodonTransformerTokenizer.json b/src/CodonTransformerTokenizer.json index 0d27966..7db063d 100644 --- a/src/CodonTransformerTokenizer.json +++ b/src/CodonTransformerTokenizer.json @@ -1 +1 @@ -{"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "special": true, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 1, "special": true, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 2, "special": true, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 3, "special": true, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 4, "special": true, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}], "normalizer": {"type": "Sequence", "normalizers": [{"type": "Lowercase"}]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"String": " "}, "behavior": "Isolated", "invert": false}, {"type": "Whitespace"}]}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": null, "model": {"type": "WordPiece", "unk_token": "[UNK]", "continuing_subword_prefix": "##", "max_input_chars_per_word": 100, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "a_unk": 5, "c_unk": 6, "d_unk": 7, "e_unk": 8, "f_unk": 9, "g_unk": 10, "h_unk": 11, "i_unk": 12, "k_unk": 13, "l_unk": 14, "m_unk": 15, "n_unk": 16, "p_unk": 17, "q_unk": 18, "r_unk": 19, "s_unk": 20, "t_unk": 21, "v_unk": 22, "w_unk": 23, "y_unk": 24, "__unk": 25, "k_aaa": 26, "n_aac": 27, "k_aag": 28, "n_aat": 29, "t_aca": 30, "t_acc": 31, "t_acg": 32, "t_act": 33, "r_aga": 34, "s_agc": 35, "r_agg": 36, "s_agt": 37, "i_ata": 38, "i_atc": 39, "m_atg": 40, "i_att": 41, "q_caa": 42, "h_cac": 43, "q_cag": 44, "h_cat": 45, "p_cca": 46, "p_ccc": 47, "p_ccg": 48, "p_cct": 49, "r_cga": 50, "r_cgc": 51, "r_cgg": 52, "r_cgt": 53, "l_cta": 54, "l_ctc": 55, "l_ctg": 56, "l_ctt": 57, "e_gaa": 58, "d_gac": 59, "e_gag": 60, "d_gat": 61, "a_gca": 62, "a_gcc": 63, "a_gcg": 64, "a_gct": 65, "g_gga": 66, "g_ggc": 67, "g_ggg": 68, "g_ggt": 69, "v_gta": 70, "v_gtc": 71, "v_gtg": 72, "v_gtt": 73, "__taa": 74, "y_tac": 75, "__tag": 76, "y_tat": 77, "s_tca": 78, "s_tcc": 79, "s_tcg": 80, "s_tct": 81, "__tga": 82, "c_tgc": 83, "w_tgg": 84, "c_tgt": 85, "l_tta": 86, "f_ttc": 87, "l_ttg": 88, "f_ttt": 89}}} \ No newline at end of file +{"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "special": true, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 1, "special": true, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 2, "special": true, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 3, "special": true, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 4, "special": true, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}], "normalizer": {"type": "Sequence", "normalizers": [{"type": "Lowercase"}]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"String": " "}, "behavior": "Isolated", "invert": false}, {"type": "Whitespace"}]}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": null, "model": {"type": "WordPiece", "unk_token": "[UNK]", "continuing_subword_prefix": "##", "max_input_chars_per_word": 100, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "a_unk": 5, "c_unk": 6, "d_unk": 7, "e_unk": 8, "f_unk": 9, "g_unk": 10, "h_unk": 11, "i_unk": 12, "k_unk": 13, "l_unk": 14, "m_unk": 15, "n_unk": 16, "p_unk": 17, "q_unk": 18, "r_unk": 19, "s_unk": 20, "t_unk": 21, "v_unk": 22, "w_unk": 23, "y_unk": 24, "__unk": 25, "k_aaa": 26, "n_aac": 27, "k_aag": 28, "n_aat": 29, "t_aca": 30, "t_acc": 31, "t_acg": 32, "t_act": 33, "r_aga": 34, "s_agc": 35, "r_agg": 36, "s_agt": 37, "i_ata": 38, "i_atc": 39, "m_atg": 40, "i_att": 41, "q_caa": 42, "h_cac": 43, "q_cag": 44, "h_cat": 45, "p_cca": 46, "p_ccc": 47, "p_ccg": 48, "p_cct": 49, "r_cga": 50, "r_cgc": 51, "r_cgg": 52, "r_cgt": 53, "l_cta": 54, "l_ctc": 55, "l_ctg": 56, "l_ctt": 57, "e_gaa": 58, "d_gac": 59, "e_gag": 60, "d_gat": 61, "a_gca": 62, "a_gcc": 63, "a_gcg": 64, "a_gct": 65, "g_gga": 66, "g_ggc": 67, "g_ggg": 68, "g_ggt": 69, "v_gta": 70, "v_gtc": 71, "v_gtg": 72, "v_gtt": 73, "__taa": 74, "y_tac": 75, "__tag": 76, "y_tat": 77, "s_tca": 78, "s_tcc": 79, "s_tcg": 80, "s_tct": 81, "__tga": 82, "c_tgc": 83, "w_tgg": 84, "c_tgt": 85, "l_tta": 86, "f_ttc": 87, "l_ttg": 88, "f_ttt": 89}}} diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py index ca7d464..42342c9 100644 --- a/tests/test_CodonData.py +++ b/tests/test_CodonData.py @@ -1,13 +1,15 @@ import tempfile import unittest + import pandas as pd +from Bio.Data.CodonTable import TranslationError + from CodonTransformer.CodonData import ( - read_fasta_file, build_amino2codon_skeleton, get_amino_acid_sequence, is_correct_seq, + read_fasta_file, ) -from Bio.Data.CodonTable import TranslationError class TestCodonData(unittest.TestCase): diff --git a/tests/test_CodonJupyter.py b/tests/test_CodonJupyter.py index 0ff6afa..6816813 100644 --- a/tests/test_CodonJupyter.py +++ b/tests/test_CodonJupyter.py @@ -1,13 +1,15 @@ import unittest + import ipywidgets + from CodonTransformer.CodonJupyter import ( + DNASequencePrediction, UserContainer, - create_organism_dropdown, create_dropdown_options, + create_organism_dropdown, display_organism_dropdown, display_protein_input, format_model_output, - DNASequencePrediction, ) from CodonTransformer.CodonUtils import ORGANISM2ID diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 199c383..3617c2b 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,50 +1,88 @@ import unittest +import warnings + import torch + from CodonTransformer.CodonPrediction import ( + load_model, + load_tokenizer, predict_dna_sequence, - # add other imported functions or classes as needed ) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class TestCodonPrediction(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Suppress warnings about loading from HuggingFace + for message in [ + "Tokenizer path not provided. Loading from HuggingFace.", + "Model path not provided. Loading from HuggingFace.", + ]: + warnings.filterwarnings("ignore", message=message) + + cls.model = load_model(device=cls.device) + cls.tokenizer = load_tokenizer() + def test_predict_dna_sequence_valid_input(self): - # Test predict_dna_sequence with a valid protein sequence and organism code protein_sequence = "MWWMW" organism = "Escherichia coli general" - result = predict_dna_sequence(protein_sequence, organism, device) - # Test if the output is a string and contains only A, T, C, G. + result = predict_dna_sequence( + protein_sequence, + organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) self.assertEqual(result.predicted_dna, "ATGTGGTGGATGTGGTGA") - def test_predict_dna_sequence_invalid_protein_sequence(self): - # Test predict_dna_sequence with an invalid protein sequence - protein_sequence = "MKTZZFVLLL" # 'Z' is not a valid amino acid + def test_predict_dna_sequence_non_deterministic(self): + protein_sequence = "MFWY" organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) - - def test_predict_dna_sequence_invalid_organism_code(self): - # Test predict_dna_sequence with an invalid organism code - protein_sequence = "MKTFFVLLL" - organism = "Alien $%#@!" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) - - def test_predict_dna_sequence_empty_protein_sequence(self): - # Test predict_dna_sequence with an empty protein sequence - protein_sequence = "" - organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) + num_iterations = 64 + possible_outputs = set() + possible_encodings_wo_stop = { + "ATGTTTTGGTAT", + "ATGTTCTGGTAT", + "ATGTTTTGGTAC", + "ATGTTCTGGTAC", + } - def test_predict_dna_sequence_none_protein_sequence(self): - # Test predict_dna_sequence with None as protein sequence - protein_sequence = None - organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) + for _ in range(num_iterations): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + ) + possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon + + self.assertEqual(possible_outputs, possible_encodings_wo_stop) + + def test_predict_dna_sequence_invalid_inputs(self): + test_cases = [ + ("MKTZZFVLLL", "Escherichia coli general", "invalid protein sequence"), + ("MKTFFVLLL", "Alien $%#@!", "invalid organism code"), + ("", "Escherichia coli general", "empty protein sequence"), + ] + + for protein_sequence, organism, error_type in test_cases: + with self.subTest(error_type=error_type): + with self.assertRaises(ValueError): + predict_dna_sequence( + protein_sequence, + organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) if __name__ == "__main__": diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index c7faf11..1d9a94c 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -2,14 +2,15 @@ import pickle import tempfile import unittest + from CodonTransformer.CodonUtils import ( - load_python_object_from_disk, - save_python_object_to_disk, find_pattern_in_fasta, get_organism2id_dict, get_taxonomy_id, - sort_amino2codon_skeleton, load_pkl_from_url, + load_python_object_from_disk, + save_python_object_to_disk, + sort_amino2codon_skeleton, ) @@ -32,7 +33,10 @@ def test_save_python_object_to_disk(self): os.remove(temp_file_name) def test_find_pattern_in_fasta(self): - text = ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n>seq2 [keyword=value2]\nGGTACGATCGATCGAT" + text = ( + ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n" + ">seq2 [keyword=value2]\nGGTACGATCGATCGAT" + ) self.assertEqual(find_pattern_in_fasta("keyword", text), "value1") self.assertEqual(find_pattern_in_fasta("nonexistent", text), "")