diff --git a/.gitignore b/.gitignore index 0a03dca..46f2a2f 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,4 @@ lightning_logs/ # PyTorch model weights *.pth -*.pt \ No newline at end of file +*.pt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e0c0e74 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +# Don't run pre-commit on files under third-party/ +exclude: "^\ + (third-party/.*)\ + " + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-added-large-files # prevents giant files from being committed. + - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems. + - id: check-merge-conflict # checks for files that contain merge conflict strings. + - id: check-yaml # checks yaml files for parseable syntax. + - id: detect-private-key # detects the presence of private keys. + - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline. + - id: fix-byte-order-marker # removes utf-8 byte order marker. + - id: mixed-line-ending # replaces or checks mixed line ending. + - id: requirements-txt-fixer # sorts entries in requirements.txt. + - id: trailing-whitespace # trims trailing whitespace. + + - repo: https://github.com/sirosen/check-jsonschema + rev: 0.23.2 + hooks: + - id: check-github-actions + - id: check-github-workflows + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.13 + hooks: + - id: ruff + - id: ruff-format diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 5932ff8..b6f6d86 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -5,32 +5,30 @@ preparing the data for training and inference of the CodonTransformer model. """ -import os import json +import os +from typing import Dict, List, Optional, Tuple, Union + import pandas as pd +import python_codon_tables as pct +from Bio import SeqIO +from Bio.Seq import Seq from sklearn.utils import shuffle as sk_shuffle +from tqdm import tqdm from CodonTransformer.CodonUtils import ( + AMBIGUOUS_AMINOACID_MAP, + AMINO2CODON_TYPE, AMINO_ACIDS, + ORGANISM2ID, START_CODONS, STOP_CODONS, STOP_SYMBOL, - AMINO2CODON_TYPE, - AMBIGUOUS_AMINOACID_MAP, - ORGANISM2ID, find_pattern_in_fasta, - sort_amino2codon_skeleton, get_taxonomy_id, + sort_amino2codon_skeleton, ) -from Bio import SeqIO -from Bio.Seq import Seq - -import python_codon_tables as pct - -from typing import List, Dict, Tuple, Union, Optional -from tqdm import tqdm - def prepare_training_data( dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True @@ -50,7 +48,8 @@ def prepare_training_data( Args: dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format. output_file (str): Path to save the output JSON dataset. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -78,7 +77,7 @@ def prepare_training_data( def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None: """ - Convert a pandas DataFrame to a JSON file format suitable for training CodonTransformer. + Convert pandas DataFrame to JSON file format suitable for training CodonTransformer. This function takes a preprocessed DataFrame and writes it to a JSON file where each line is a JSON object representing a single record. @@ -86,7 +85,8 @@ def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) Args: df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns. output_file (str): Path to the output JSON file. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -123,8 +123,9 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) It validates the input against a provided mapping of organism names to IDs. Args: - organism (Union[str, int]): The input organism, either as a name (str) or ID (int). - organism_to_id (Dict[str, int]): A dictionary mapping organism names to their corresponding IDs. + organism (Union[str, int]): Input organism, either as a name (str) or ID (int). + organism_to_id (Dict[str, int]): Dictionary mapping organism names to their + corresponding IDs. Returns: int: The validated organism ID. @@ -150,7 +151,8 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) def preprocess_protein_sequence(protein: str) -> str: """ - Preprocess a protein sequence by cleaning, standardizing, and handling ambiguous amino acids. + Preprocess a protein sequence by cleaning, standardizing, and handling + ambiguous amino acids. Args: protein (str): The input protein sequence. @@ -221,7 +223,8 @@ def replace_ambiguous_codons(dna: str) -> str: def preprocess_dna_sequence(dna: str) -> str: """ - Cleans and preprocesses a DNA sequence by standardizing it and replacing ambiguous codons. + Cleans and preprocesses a DNA sequence by standardizing it and replacing + ambiguous codons. Args: dna (str): The DNA sequence to preprocess. @@ -247,8 +250,9 @@ def preprocess_dna_sequence(dna: str) -> str: def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: """ - Return the merged sequence of protein amino acids and DNA codons in the form of tokens - separated by space, where each token is composed of an amino acid + separator + codon. + Return the merged sequence of protein amino acids and DNA codons in the form + of tokens separated by space, where each token is composed of an amino acid + + separator + codon. Args: protein (str): Protein sequence. @@ -274,8 +278,9 @@ def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: # Check if the length of protein and dna sequences are equal if len(dna) > 0 and len(protein) != len(dna) / 3: raise ValueError( - 'Length of protein (including stop symbol such as "_") and \ - the number of codons in DNA sequence (including stop codon) must be equal.' + 'Length of protein (including stop symbol such as "_") and ' + "the number of codons in DNA sequence (including stop codon) " + "must be equal." ) # Merge protein and DNA sequences into tokens @@ -331,8 +336,8 @@ def get_amino_acid_sequence( return_correct_seq (bool): Whether to return if the sequence is correct. Returns: - Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if return_correct_seq is True, - otherwise just the protein sequence. + Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if + return_correct_seq is True, otherwise just the protein sequence. """ dna_seq = Seq(dna).strip() @@ -365,12 +370,15 @@ def read_fasta_file( Args: input_file (str): Path to the input FASTA file. - save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned. - organism (str): Name of the organism. If empty, it will be extracted from the FASTA description. + save_to_file (Optional[str]): Path to save the output DataFrame. If None, + data is only returned. + organism (str): Name of the organism. If empty, it will be extracted from + the FASTA description. buffer_size (int): Number of records to process before writing to file. Returns: - pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None. + pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe + is True, else None. Raises: FileNotFoundError: If the input file does not exist. @@ -498,7 +506,8 @@ def download_codon_frequencies_from_kazusa( def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: """ - Return the empty skeleton of the amino2codon dictionary, needed for get_codon_frequencies. + Return the empty skeleton of the amino2codon dictionary, needed for + get_codon_frequencies. Args: organism (str): Name of the organism. @@ -514,7 +523,8 @@ def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: return_correct_seq=False, ) - # Initialize the amino2codon skeleton with all possible codons and set their frequencies to 0 + # Initialize the amino2codon skeleton with all possible codons and set their + # frequencies to 0 for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)): if amino not in amino2codon: amino2codon[amino] = ([], []) @@ -543,7 +553,8 @@ def get_codon_frequencies( organism (Optional[str]): Name of the organism. Returns: - AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons and frequencies. + AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons + and frequencies. """ if organism: codon_table = get_codon_table(organism) @@ -583,7 +594,8 @@ def get_organism_to_codon_frequencies( organisms (List[str]): List of organisms. Returns: - Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon frequency distribution. + Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon + frequency distribution. """ organism2frequencies = {} @@ -617,7 +629,8 @@ def get_codon_table(organism: str) -> int: "Arabidopsis thaliana", "Caenorhabditis elegans", "Chlamydomonas reinhardtii", - "Saccharomyces cerevisiae" "Danio rerio", + "Saccharomyces cerevisiae", + "Danio rerio", "Drosophila melanogaster", "Homo sapiens", "Mus musculus", diff --git a/CodonTransformer/CodonEvaluation.py b/CodonTransformer/CodonEvaluation.py index e94423a..d09ba92 100644 --- a/CodonTransformer/CodonEvaluation.py +++ b/CodonTransformer/CodonEvaluation.py @@ -1,18 +1,17 @@ """ File: CodonEvaluation.py --------------------------- -Includes functions to calculate various evaluation metrics along with helper functions. +Includes functions to calculate various evaluation metrics along with helper +functions. """ -import pandas as pd +from typing import Dict, List, Tuple +import pandas as pd from CAI import CAI, relative_adaptiveness - -from typing import List, Dict, Tuple from tqdm import tqdm - def get_CSI_weights(sequences: List[str]) -> Dict[str, float]: """ Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences. @@ -47,7 +46,7 @@ def get_organism_to_CSI_weights( Calculate the Codon Similarity Index (CSI) weights for a list of organisms. Args: - dataset (pd.DataFrame): The dataset containing organism and DNA sequence information. + dataset (pd.DataFrame): Dataset containing organism and DNA sequence info. organisms (List[str]): List of organism names. Returns: @@ -91,7 +90,8 @@ def get_cfd( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. threshold (float): Frequency threshold for counting rare codons. Returns: @@ -127,7 +127,8 @@ def get_min_max_percentage( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. window_size (int): Size of the window to calculate %MinMax. Returns: @@ -147,14 +148,12 @@ def get_min_max_percentage( # Iterate through the DNA sequence using the specified window size for i in range(len(codons) - window_size + 1): - codon_window = codons[ - i : i + window_size - ] # List of the codons in the current window + codon_window = codons[i : i + window_size] # Codons in the current window Actual = 0.0 # Average of the actual codon frequencies Max = 0.0 # Average of the min codon frequencies Min = 0.0 # Average of the max codon frequencies - Avg = 0.0 # Average of the averages of all the frequencies associated with each amino acid + Avg = 0.0 # Average of the averages of all frequencies for each amino acid # Sum the frequencies for codons in the current window for codon in codon_window: @@ -210,7 +209,7 @@ def sum_up_to(x): return x + sum_up_to(x - 1) def f(x): - """Function that returns 4 if x is greater than or equal to 4, else returns x.""" + """Returns 4 if x is greater than or equal to 4, else returns x.""" if x >= 4: return 4 elif x < 4: @@ -242,8 +241,10 @@ def get_sequence_similarity( Args: original (str): The original sequence. predicted (str): The predicted sequence. - truncate (bool): If True, truncate the original sequence to match the length of the predicted sequence. - window_length (int): Length of the window for comparison (1 for amino acids, 3 for codons). + truncate (bool): If True, truncate the original sequence to match the length + of the predicted sequence. + window_length (int): Length of the window for comparison (1 for amino acids, + 3 for codons). Returns: float: The sequence similarity as a percentage. diff --git a/CodonTransformer/CodonJupyter.py b/CodonTransformer/CodonJupyter.py index 6c730cd..fdf7164 100644 --- a/CodonTransformer/CodonJupyter.py +++ b/CodonTransformer/CodonJupyter.py @@ -7,13 +7,13 @@ from typing import Dict, List, Tuple import ipywidgets as widgets -from IPython.display import display, HTML +from IPython.display import HTML, display from CodonTransformer.CodonUtils import ( - DNASequencePrediction, COMMON_ORGANISMS, - ORGANISM2ID, ID2ORGANISM, + ORGANISM2ID, + DNASequencePrediction, ) @@ -49,11 +49,11 @@ def create_styled_options( organism_id = organism2id[organism] if is_fine_tuned: if organism_id < 10: - styled_options.append(f"\u200B{organism_id:>6}. {organism}") + styled_options.append(f"\u200b{organism_id:>6}. {organism}") elif organism_id < 100: - styled_options.append(f"\u200B{organism_id:>5}. {organism}") + styled_options.append(f"\u200b{organism_id:>5}. {organism}") else: - styled_options.append(f"\u200B{organism_id:>4}. {organism}") + styled_options.append(f"\u200b{organism_id:>4}. {organism}") else: if organism_id < 10: styled_options.append(f"{organism_id:>6}. {organism}") @@ -160,7 +160,7 @@ def get_dropdown_style() -> str: flex-direction: column; align-items: flex-start; } - .widget-dropdown option[value^="\u200B"] { + .widget-dropdown option[value^="\u200b"] { font-family: sans-serif; font-weight: bold; font-size: 18px; @@ -188,7 +188,8 @@ def display_organism_dropdown(container: UserContainer) -> None: """ dropdown = create_organism_dropdown(container) header = widgets.HTML( - 'Select Organism:
' + 'Select Organism:' + '' ) container_widget = widgets.VBox( [header, dropdown], @@ -242,7 +243,8 @@ def save_protein(change: Dict[str, str]) -> None: Save the input protein sequence to the container. Args: - change (Dict[str, str]): A dictionary containing information about the change in textarea value. + change (Dict[str, str]): A dictionary containing information about + the change in textarea value. """ container.protein = ( change["new"] @@ -258,7 +260,8 @@ def save_protein(change: Dict[str, str]) -> None: # Display the input widget header = widgets.HTML( - 'Enter Protein Sequence:' + 'Enter Protein Sequence:' + '' ) container_widget = widgets.VBox( [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px") @@ -270,13 +273,13 @@ def save_protein(change: Dict[str, str]) -> None: def format_model_output(output: DNASequencePrediction) -> str: """ - Format the DNA sequence prediction output in a visually appealing and easy-to-read manner. + Format DNA sequence prediction output in an appealing and easy-to-read manner. This function takes the prediction output and formats it into a structured string with clear section headers and separators. Args: - output (DNASequencePrediction): An object containing the prediction output. + output (DNASequencePrediction): Object containing the prediction output. Expected attributes: - organism (str): The organism name. - protein (str): The input protein sequence. diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 7df8520..a4d5434 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -5,27 +5,28 @@ helper functions related to processing data for passing to the model. """ -from typing import Any, List, Dict, Tuple, Optional, Union -import onnxruntime as rt +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np +import onnxruntime as rt import torch import transformers from transformers import ( + AutoTokenizer, BatchEncoding, - PreTrainedTokenizerFast, BigBirdConfig, - AutoTokenizer, BigBirdForMaskedLM, + PreTrainedTokenizerFast, ) -import numpy as np from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( AMINO_ACIDS, - ORGANISM2ID, - TOKEN2INDEX, INDEX2TOKEN, NUM_ORGANISMS, + ORGANISM2ID, + TOKEN2INDEX, DNASequencePrediction, ) @@ -37,6 +38,7 @@ def predict_dna_sequence( tokenizer: Union[str, PreTrainedTokenizerFast] = None, model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", + deterministic: bool = True, ) -> DNASequencePrediction: """ Predict the DNA sequence for a given protein using CodonTransformer model. @@ -47,30 +49,38 @@ def predict_dna_sequence( Args: protein (str): The input protein sequence to predict the DNA sequence for. - organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). - If a string is provided, it will be converted to the corresponding ID using ORGANISM2ID. + organism (Union[int, str]): Either the ID of the organism or its name (e.g., + "Escherichia coli general"). If a string is provided, it will be converted + to the corresponding ID using ORGANISM2ID. device (torch.device): The device (CPU or GPU) to run the model on. - tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, - a pre-loaded tokenizer object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - model (Union[str, torch.nn.Module, None], optional): Either a file path to load the model from, - a pre-loaded model object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - attention_type (str, optional): The type of attention mechanism to use in the model. - Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". + tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file + path to load the tokenizer from, a pre-loaded tokenizer object, or None. If + None, it will be loaded from HuggingFace. Defaults to None. + model (Union[str, torch.nn.Module, None], optional): Either a file path to load + the model from, a pre-loaded model object, or None. If None, it will be + loaded from HuggingFace. Defaults to None. + attention_type (str, optional): The type of attention mechanism to use in the + model. Can be either 'block_sparse' or 'original_full'. Defaults to + "original_full". + deterministic (bool, optional): Whether to use deterministic decoding (most + likely tokens). If False, samples tokens according to their probabilities. + Defaults to True. Returns: DNASequencePrediction: An object containing the prediction results: - - organism (str): The name of the organism used for prediction. - - protein (str): The input protein sequence for which DNA sequence is predicted. - - processed_input (str): The processed input sequence (merged protein and DNA). - - predicted_dna (str): The predicted DNA sequence. + - organism (str): Name of the organism used for prediction. + - protein (str): Input protein sequence for which DNA sequence is predicted. + - processed_input (str): Processed input sequence (merged protein and DNA). + - predicted_dna (str): Predicted DNA sequence. Raises: ValueError: If the protein sequence is empty or if the organism is invalid. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from CodonTransformer.CodonUtils. - ORGANISM2ID maps organism names to their corresponding IDs. - INDEX2TOKEN maps model output indices (token ids) to respective codons. + This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from + CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their + corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to + respective codons. Example: >>> import torch @@ -83,30 +93,41 @@ def predict_dna_sequence( >>> >>> # Load tokenizer and model >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") - >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device) + >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") + >>> model = model.to(device) >>> >>> # Define protein sequence and organism - >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" + >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" >>> organism = "Escherichia coli general" >>> - >>> # Predict DNA sequence + >>> # Predict DNA sequence with deterministic decoding >>> output = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, - ... attention_type="original_full" + ... attention_type="original_full", + ... deterministic=True + ... ) + >>> + >>> # Predict DNA sequence with probabilistic sampling + >>> output_sampler = predict_dna_sequence( + ... protein=protein, + ... organism=organism, + ... device=device, + ... tokenizer=tokenizer, + ... model=model, + ... attention_type="original_full", + ... deterministic=False ... ) >>> >>> print(format_model_output(output)) + >>> print(format_model_output(output_sampler)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") - if not isinstance(protein, str): - raise ValueError("Protein sequence must be a string.") - # Test that the input protein sequence contains only valid amino acids if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") @@ -139,12 +160,23 @@ def predict_dna_sequence( # Get the model predictions output_dict = model(**tokenized_input, return_dict=True) - output = output_dict.logits.detach().cpu().numpy() + logits = output_dict.logits.detach().cpu() # Decode the predicted DNA sequence from the model output - predicted_dna = list( - map(INDEX2TOKEN.__getitem__, output.argmax(axis=-1).squeeze().tolist()) - ) + if deterministic: + # Select the most probable tokens (argmax) + predicted_indices = logits.argmax(dim=-1).squeeze().tolist() + else: + # Sample tokens according to their probability distribution + # Convert logits to probabilities using softmax + probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position + probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] + predicted_indices = ( + torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist() + ) + + predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) # Skip special tokens [CLS] and [SEP] to create the predicted_dna predicted_dna = ( @@ -170,17 +202,21 @@ def load_model( Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace. Args: - model_path (Optional[str]): Path to the model file or checkpoint. If None, load from HuggingFace. + model_path (Optional[str]): Path to the model file or checkpoint. If None, + load from HuggingFace. device (torch.device, optional): The device to load the model onto. - attention_type (str, optional): The type of attention, 'block_sparse' or 'original_full'. - num_organisms (int, optional): Number of organisms, needed if loading from a checkpoint that requires this. - remove_prefix (bool, optional): Whether to remove the "model." prefix from the keys in the state dict. + attention_type (str, optional): The type of attention, 'block_sparse' + or 'original_full'. + num_organisms (int, optional): Number of organisms, needed if loading from a + checkpoint that requires this. + remove_prefix (bool, optional): Whether to remove the "model." prefix from the + keys in the state dict. Returns: torch.nn.Module: The loaded model. """ if not model_path: - print("Warning: Model path not provided. Loading from HuggingFace.") + warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning) model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") elif model_path.endswith(".ckpt"): @@ -209,7 +245,8 @@ def load_model( else: raise ValueError( - "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace." + "Unsupported file type. Please provide a .ckpt or .pt file, " + "or None to load from HuggingFace." ) # Prepare model for evaluation @@ -260,16 +297,19 @@ def create_model_from_checkpoint( def load_tokenizer(tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast: """ - Create and return a tokenizer object from the given tokenizer path or HuggingFace. + Create and return a tokenizer object from tokenizer path or HuggingFace. Args: - tokenizer_path (Optional[str]): Path to the tokenizer file. If None, load from HuggingFace. + tokenizer_path (Optional[str]): Path to the tokenizer file. If None, + load from HuggingFace. Returns: PreTrainedTokenizerFast: The tokenizer object. """ if not tokenizer_path: - print("Warning: Tokenizer path not provided. Loading from HuggingFace.") + warnings.warn( + "Tokenizer path not provided. Loading from HuggingFace.", UserWarning + ) return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") return transformers.PreTrainedTokenizerFast( @@ -291,11 +331,14 @@ def tokenize( ) -> BatchEncoding: """ Return the tokenized sequences given a batch of input data. - Each data in the batch is expected to be a dictionary with "codons" and "organism" keys. + Each data in the batch is expected to be a dictionary with "codons" and + "organism" keys. Args: - batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and "organism" keys. - tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or path to the tokenizer file. + batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and + "organism" keys. + tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or + path to the tokenizer file. max_len (int, optional): Maximum length of the tokenized sequence. Returns: @@ -326,28 +369,32 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: """ Validate and convert the organism input to both ID and name. - This function takes either an organism ID or name as input and returns both the ID and name. - It performs validation to ensure the input corresponds to a valid organism in the ORGANISM2ID dictionary. + This function takes either an organism ID or name as input and returns both + the ID and name. It performs validation to ensure the input corresponds to + a valid organism in the ORGANISM2ID dictionary. Args: - organism (Union[int, str]): Either the ID of the organism (int) or its name (str). + organism (Union[int, str]): Either the ID of the organism (int) or its + name (str). Returns: Tuple[int, str]: A tuple containing the organism ID (int) and name (str). Raises: - ValueError: If the input is neither a string nor an integer, if the organism name is not found - in ORGANISM2ID, if the organism ID is not a value in ORGANISM2ID, or if no name - is found for a given ID. + ValueError: If the input is neither a string nor an integer, if the + organism name is not found in ORGANISM2ID, if the organism ID is not a + value in ORGANISM2ID, or if no name is found for a given ID. Note: - This function relies on the ORGANISM2ID dictionary imported from CodonTransformer.CodonUtils, - which maps organism names to their corresponding IDs. + This function relies on the ORGANISM2ID dictionary imported from + CodonTransformer.CodonUtils, which maps organism names to their + corresponding IDs. """ if isinstance(organism, str): if organism not in ORGANISM2ID: raise ValueError( - f"Invalid organism name: {organism}. Please use a valid organism name or ID." + f"Invalid organism name: {organism}. " + "Please use a valid organism name or ID." ) organism_id = ORGANISM2ID[organism] organism_name = organism @@ -355,7 +402,8 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: elif isinstance(organism, int): if organism not in ORGANISM2ID.values(): raise ValueError( - f"Invalid organism ID: {organism}. Please use a valid organism name or ID." + f"Invalid organism ID: {organism}. " + "Please use a valid organism name or ID." ) organism_id = organism @@ -365,9 +413,6 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: if organism_name is None: raise ValueError(f"No organism name found for ID: {organism}") - else: - raise ValueError("Organism must be either a string (name) or an integer (ID).") - return organism_id, organism_name @@ -375,12 +420,13 @@ def get_high_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using High Frequency Choice (HFC) approach in which - the most frequent codon for a given amino acid is always chosen. + Return the DNA sequence optimized using High Frequency Choice (HFC) approach + in which the most frequent codon for a given amino acid is always chosen. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -400,7 +446,8 @@ def precompute_most_frequent_codons( Precompute the most frequent codon for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, str]: The most frequent codon for each amino acid. @@ -421,7 +468,8 @@ def get_high_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -436,17 +484,20 @@ def get_background_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Background Frequency Choice (BFC) approach in which - a random codon for a given amino acid is chosen using the codon frequencies probability distribution. + Return the DNA sequence optimized using Background Frequency Choice (BFC) + approach in which a random codon for a given amino acid is chosen using + the codon frequencies probability distribution. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. """ - # Select a random codon for each amino acid based on the codon frequencies probability distribution + # Select a random codon for each amino acid based on the codon frequencies + # probability distribution dna_codons = [ np.random.choice( codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1] @@ -463,7 +514,8 @@ def precompute_cdf( Precompute the cumulative distribution function (CDF) for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid. @@ -486,7 +538,8 @@ def get_background_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -507,12 +560,14 @@ def get_uniform_random_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Uniform Random Choice (URC) approach in which - a random codon for a given amino acid is chosen using a uniform prior. + Return the DNA sequence optimized using Uniform Random Choice (URC) approach + in which a random codon for a given amino acid is chosen using a uniform + prior. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -529,7 +584,8 @@ def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> st Return the optimized codon sequence for the given protein sequence using ICOR. Credit: ICOR: improving codon optimization with recurrent neural networks - Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas Densmore + Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas + Densmore Args: input_seq (str): The input protein sequence. diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index cf84b24..dd91e3f 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -4,18 +4,16 @@ Includes constants and helper functions used by other Python scripts. """ +import itertools import os -import re import pickle -import requests -import itertools - -import torch -import pandas as pd +import re from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple -from typing import Any, List, Dict, Tuple, Optional, Iterator - +import pandas as pd +import requests +import torch # List of all amino acids AMINO_ACIDS: List[str] = [ @@ -470,10 +468,10 @@ class DNASequencePrediction: A class to hold the output of the DNA sequence prediction. Attributes: - organism (str): The name of the organism used for prediction. - protein (str): The input protein sequence for which DNA sequence is predicted. - processed_input (str): The processed input sequence (merged protein and DNA). - predicted_dna (str): The predicted DNA sequence. + organism (str): Name of the organism used for prediction. + protein (str): Input protein sequence for which DNA sequence is predicted. + processed_input (str): Processed input sequence (merged protein and DNA). + predicted_dna (str): Predicted DNA sequence. """ organism: str @@ -488,7 +486,8 @@ class IterableData(torch.utils.data.IterableDataset): data) in parallel multi-processing environments, e.g., multi-GPU. Args: - dist_env (Optional[str]): The distribution environment identifier (e.g., "slurm"). + dist_env (Optional[str]): The distribution environment identifier + (e.g., "slurm"). Credit: Guillaume Filion """ @@ -501,7 +500,7 @@ def __init__(self, dist_env: Optional[str] = None): @property def iterator(self) -> Iterator: - """Define the stream logic for the dataset. Should be implemented in subclasses.""" + """Define the stream logic for the dataset. Implement in subclasses.""" raise NotImplementedError def __iter__(self) -> Iterator: @@ -573,7 +572,8 @@ def save_python_object_to_disk(input_object: Any, file_path: str) -> None: def find_pattern_in_fasta(keyword: str, text: str) -> str: """ - Find a specific keyword pattern in text. Helpful for identifying parts of a FASTA sequence. + Find a specific keyword pattern in text. Helpful for identifying parts + of a FASTA sequence. Args: keyword (str): The keyword pattern to search for. @@ -589,17 +589,19 @@ def find_pattern_in_fasta(keyword: str, text: str) -> str: def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: """ - Return a dictionary mapping each organism in training data to an index used for training. + Return a dictionary mapping each organism in training data to an index + used for training. Args: - organism_reference (str): Path to a CSV file containing a list of all organisms. - The format of the CSV file should be as follows: - 0,Escherichia coli - 1,Homo sapiens - 2,Mus musculus + organism_reference (str): Path to a CSV file containing a list of + all organisms. The format of the CSV file should be as follows: + + 0,Escherichia coli + 1,Homo sapiens + 2,Mus musculus Returns: - Dict[str, int]: A dictionary mapping organism names to their respective indices. + Dict[str, int]: Dictionary mapping organism names to their respective indices. """ # Read the CSV file and create a dictionary mapping organisms to their indices organisms = pd.read_csv(organism_reference, index_col=0, header=None) diff --git a/README.md b/README.md index e1b4304..6b0b7a3 100644 --- a/README.md +++ b/README.md @@ -115,14 +115,14 @@ The package requires `python>=3.9`. The requirements are [availabe here](require To finetune CodonTransformer on your own data, follow these steps: 1. **Prepare your dataset** - + Create a CSV file with the following columns: - `dna`: DNA sequences (string, preferably uppercase ATCG) - `protein`: Protein sequences (string, preferably uppercase amino acid letters) - `organism`: Target organism (string or int, must be from `ORGANISM2ID` in `CodonUtils`) - Note: + Note: - Use organisms from the `FINE_TUNE_ORGANISMS` list for best results. - For E. coli, use `Escherichia coli general`. - DNA sequences should ideally contain only A, T, C, and G. Ambiguous codons are replaced with 'UNK' for tokenization. @@ -132,7 +132,7 @@ To finetune CodonTransformer on your own data, follow these steps: