From a89b980c92a930e5d8177d146e6e63074e52d13f Mon Sep 17 00:00:00 2001 From: gui11aume Date: Wed, 18 Sep 2024 12:20:32 -0400 Subject: [PATCH 1/6] Add missing requirement for sklearn --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 23d93fc..db408c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ onnxruntime>=1.17.3,<3.0 pandas>=2.0.0,<3.0 python_codon_tables>=0.1.12,<1.0 pytorch_lightning>=2.2.1,<3.0 +scikit-learn>=1.2.2,<2.0 scipy>=1.13.1,<3.0 setuptools>=68.2.2,<70.0 torch>=2.0.0,<3.0 From 7e935e7f8224dd354321fd75aef2e57b1d655e50 Mon Sep 17 00:00:00 2001 From: gui11aume Date: Wed, 18 Sep 2024 15:59:51 -0400 Subject: [PATCH 2/6] Add test code and Codecov integration - Added test cases for various modules. - Integrated Codecov for test coverage reporting. - Improved .gitignore to exclude additional files and directories. --- .github/workflows/ci.yml | 39 ++++++++++ .gitignore | 17 ++++ CodonTransformer/CodonData.py | 86 ++++++++++++--------- CodonTransformer/CodonPrediction.py | 32 +++++--- CodonTransformer/CodonUtils.py | 4 + Makefile | 9 +++ tests/__init__.py | 0 tests/test_CodonData.py | 85 ++++++++++++++++++++ tests/test_CodonJupyter.py | 116 ++++++++++++++++++++++++++++ tests/test_CodonPrediction.py | 51 ++++++++++++ tests/test_CodonUtils.py | 89 +++++++++++++++++++++ 11 files changed, 479 insertions(+), 49 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 Makefile create mode 100644 tests/__init__.py create mode 100644 tests/test_CodonData.py create mode 100644 tests/test_CodonJupyter.py create mode 100644 tests/test_CodonPrediction.py create mode 100644 tests/test_CodonUtils.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..080d1e7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,39 @@ +# .github/workflows/ci.yml + +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install coverage + + - name: Run tests with coverage + run: | + make test_with_coverage + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 68bc17f..0a03dca 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,20 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Coverage reports +coverage.xml + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Temporary files +*.tmp +*.temp + +# PyTorch Lightning checkpoints +lightning_logs/ + +# PyTorch model weights +*.pth +*.pt \ No newline at end of file diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 102d7b2..5932ff8 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -355,24 +355,29 @@ def get_amino_acid_sequence( def read_fasta_file( input_file: str, - output_path: str, + save_to_file: Optional[str] = None, organism: str = "", - return_dataframe: bool = True, buffer_size: int = 50000, ) -> pd.DataFrame: """ - Read a FASTA file of DNA sequences and save it to a Pandas DataFrame. + Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame. + Optionally, save the DataFrame to a CSV file. Args: input_file (str): Path to the input FASTA file. - output_path (str): Path to save the output DataFrame. - organism (str): Name of the organism. - return_dataframe (bool): Whether to return the DataFrame. - buffer_size (int): Buffer size for reading the file. + save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned. + organism (str): Name of the organism. If empty, it will be extracted from the FASTA description. + buffer_size (int): Number of records to process before writing to file. Returns: - pd.DataFrame: DataFrame containing the DNA sequences. + pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None. + + Raises: + FileNotFoundError: If the input file does not exist. """ + if not os.path.exists(input_file): + raise FileNotFoundError(f"Input file not found: {input_file}") + buffer = [] columns = [ "dna", @@ -384,20 +389,25 @@ def read_fasta_file( "tokenized", ] - # Read the FASTA file and process each sequence record + # Initialize DataFrame to store all data if return_dataframe is True + all_data = pd.DataFrame(columns=columns) + with open(input_file, "r") as fasta_file: for record in tqdm( - SeqIO.parse(fasta_file, "fasta"), desc=f"{organism}", unit=" Rows" + SeqIO.parse(fasta_file, "fasta"), + desc=f"Processing {organism}", + unit=" Records", ): - dna = str(record.seq).strip() + dna = str(record.seq).strip().upper() # Ensure uppercase DNA sequence # Determine the organism from the record if not provided - if not organism: - organism = find_pattern_in_fasta("organism", record.description) - GeneID = find_pattern_in_fasta("GeneID", record.description) + current_organism = organism or find_pattern_in_fasta( + "organism", record.description + ) + gene_id = find_pattern_in_fasta("GeneID", record.description) # Get the appropriate codon table for the organism - codon_table = get_codon_table(organism) + codon_table = get_codon_table(current_organism) # Translate DNA to protein sequence protein, correct_seq = get_amino_acid_sequence( @@ -406,44 +416,46 @@ def read_fasta_file( codon_table=codon_table, return_correct_seq=True, ) - description = str(record.description[: record.description.find("[")]) - tokenized = get_merged_seq(protein, dna, seperator=STOP_SYMBOL) + description = record.description.split("[", 1)[0].strip() + tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL) # Create a data row for the current sequence data_row = { "dna": dna, "protein": protein, "correct_seq": correct_seq, - "organism": organism, - "GeneID": GeneID, + "organism": current_organism, + "GeneID": gene_id, "description": description, "tokenized": tokenized, } buffer.append(data_row) # Write buffer to CSV file when buffer size is reached - if len(buffer) >= buffer_size: - buffer_df = pd.DataFrame(buffer, columns=columns) - buffer_df.to_csv( - output_path, - mode="a", - header=(not os.path.exists(output_path)), - index=True, - ) + if save_to_file and len(buffer) >= buffer_size: + write_buffer_to_csv(buffer, save_to_file, columns) buffer = [] - # Write remaining buffer to CSV file - if buffer: - buffer_df = pd.DataFrame(buffer, columns=columns) - buffer_df.to_csv( - output_path, - mode="a", - header=(not os.path.exists(output_path)), - index=True, + all_data = pd.concat( + [all_data, pd.DataFrame([data_row])], ignore_index=True ) - if return_dataframe: - return pd.read_csv(output_path, index_col=0) + # Write remaining buffer to CSV file + if save_to_file and buffer: + write_buffer_to_csv(buffer, save_to_file, columns) + + return all_data + + +def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]): + """Helper function to write buffer to CSV file.""" + buffer_df = pd.DataFrame(buffer, columns=columns) + buffer_df.to_csv( + output_path, + mode="a", + header=(not os.path.exists(output_path)), + index=True, + ) def download_codon_frequencies_from_kazusa( diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 2240b53..7df8520 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -15,12 +15,13 @@ PreTrainedTokenizerFast, BigBirdConfig, AutoTokenizer, - BigBirdForMaskedLM + BigBirdForMaskedLM, ) import numpy as np from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, ORGANISM2ID, TOKEN2INDEX, INDEX2TOKEN, @@ -76,18 +77,18 @@ def predict_dna_sequence( >>> from transformers import AutoTokenizer, BigBirdForMaskedLM >>> from CodonTransformer.CodonPrediction import predict_dna_sequence >>> from CodonTransformer.CodonJupyter import format_model_output - >>> + >>> >>> # Set up device >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - >>> + >>> >>> # Load tokenizer and model >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device) - >>> + >>> >>> # Define protein sequence and organism >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" >>> organism = "Escherichia coli general" - >>> + >>> >>> # Predict DNA sequence >>> output = predict_dna_sequence( ... protein=protein, @@ -97,12 +98,19 @@ def predict_dna_sequence( ... model=model, ... attention_type="original_full" ... ) - >>> + >>> >>> print(format_model_output(output)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") + if not isinstance(protein, str): + raise ValueError("Protein sequence must be a string.") + + # Test that the input protein sequence contains only valid amino acids + if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): + raise ValueError("Invalid amino acid found in protein sequence.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -127,9 +135,7 @@ def predict_dna_sequence( "codons": merged_seq, "organism": organism_id, } - tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to( - device - ) + tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device) # Get the model predictions output_dict = model(**tokenized_input, return_dict=True) @@ -202,7 +208,9 @@ def load_model( model.load_state_dict(state_dict) else: - raise ValueError("Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace.") + raise ValueError( + "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace." + ) # Prepare model for evaluation model.bert.set_attention_type(attention_type) @@ -386,7 +394,7 @@ def get_high_frequency_choice_sequence( def precompute_most_frequent_codons( - codon_frequencies: Dict[str, Tuple[List[str], List[float]]] + codon_frequencies: Dict[str, Tuple[List[str], List[float]]], ) -> Dict[str, str]: """ Precompute the most frequent codon for each amino acid. @@ -449,7 +457,7 @@ def get_background_frequency_choice_sequence( def precompute_cdf( - codon_frequencies: Dict[str, Tuple[List[str], List[float]]] + codon_frequencies: Dict[str, Tuple[List[str], List[float]]], ) -> Dict[str, Tuple[List[str], Any]]: """ Precompute the cumulative distribution function (CDF) for each amino acid. diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 8595ac1..cf84b24 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -593,6 +593,10 @@ def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: Args: organism_reference (str): Path to a CSV file containing a list of all organisms. + The format of the CSV file should be as follows: + 0,Escherichia coli + 1,Homo sapiens + 2,Mus musculus Returns: Dict[str, int]: A dictionary mapping organism names to their respective indices. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2527ef0 --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +# Makefile + +.PHONY: test +test: + python -m unittest discover -s tests + +.PHONY: test_with_coverage +test_with_coverage: + coverage run -m unittest discover -s tests diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py new file mode 100644 index 0000000..ca7d464 --- /dev/null +++ b/tests/test_CodonData.py @@ -0,0 +1,85 @@ +import tempfile +import unittest +import pandas as pd +from CodonTransformer.CodonData import ( + read_fasta_file, + build_amino2codon_skeleton, + get_amino_acid_sequence, + is_correct_seq, +) +from Bio.Data.CodonTable import TranslationError + + +class TestCodonData(unittest.TestCase): + def test_read_fasta_file(self): + fasta_content = ">sequence1\n" "ATGATGATGATGATG\n" ">sequence2\n" "TGATGATGATGA" + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".fasta" + ) as temp_file: + temp_file.write(fasta_content) + temp_file_name = temp_file.name + + try: + sequences = read_fasta_file(temp_file_name, save_to_file=None) + self.assertIsInstance(sequences, pd.DataFrame) + self.assertEqual(len(sequences), 2) + self.assertEqual(sequences.iloc[0]["dna"], "ATGATGATGATGATG") + self.assertEqual(sequences.iloc[1]["dna"], "TGATGATGATGA") + finally: + import os + + os.unlink(temp_file_name) + + def test_build_amino2codon_skeleton(self): + organism = "Homo sapiens" + codon_skeleton = build_amino2codon_skeleton(organism) + + expected_amino_acids = "ARNDCQEGHILKMFPSTWYV_" + + for amino_acid in expected_amino_acids: + self.assertIn(amino_acid, codon_skeleton) + codons, frequencies = codon_skeleton[amino_acid] + self.assertIsInstance(codons, list) + self.assertIsInstance(frequencies, list) + self.assertEqual(len(codons), len(frequencies)) + self.assertTrue(all(isinstance(codon, str) for codon in codons)) + self.assertTrue(all(freq == 0 for freq in frequencies)) + + all_codons = set( + codon for codons, _ in codon_skeleton.values() for codon in codons + ) + self.assertEqual(len(all_codons), 64) # There should be 64 unique codons + + def test_get_amino_acid_sequence(self): + dna = "ATGGCCTGA" + protein, is_correct = get_amino_acid_sequence(dna, return_correct_seq=True) + self.assertEqual(protein, "MA_") + self.assertTrue(is_correct) + + def test_is_correct_seq(self): + dna = "ATGGCCTGA" + protein = "MA_" + self.assertTrue(is_correct_seq(dna, protein)) + + def test_read_fasta_file_raises_exception_for_non_dna(self): + non_dna_content = ">sequence1\nATGATGATGXYZATG\n>sequence2\nTGATGATGATGA" + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".fasta" + ) as temp_file: + temp_file.write(non_dna_content) + temp_file_name = temp_file.name + + try: + with self.assertRaises(TranslationError) as context: + read_fasta_file(temp_file_name) + self.assertIn("Codon 'XYZ' is invalid", str(context.exception)) + finally: + import os + + os.unlink(temp_file_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonJupyter.py b/tests/test_CodonJupyter.py new file mode 100644 index 0000000..0ff6afa --- /dev/null +++ b/tests/test_CodonJupyter.py @@ -0,0 +1,116 @@ +import unittest +import ipywidgets +from CodonTransformer.CodonJupyter import ( + UserContainer, + create_organism_dropdown, + create_dropdown_options, + display_organism_dropdown, + display_protein_input, + format_model_output, + DNASequencePrediction, +) +from CodonTransformer.CodonUtils import ORGANISM2ID + + +class TestCodonJupyter(unittest.TestCase): + def test_UserContainer(self): + user_container = UserContainer() + self.assertEqual(user_container.organism, -1) + self.assertEqual(user_container.protein, "") + + def test_create_organism_dropdown(self): + container = UserContainer() + dropdown = create_organism_dropdown(container) + + self.assertIsInstance(dropdown, ipywidgets.Dropdown) + self.assertGreater(len(dropdown.options), 0) + self.assertEqual(dropdown.description, "") + self.assertEqual(dropdown.layout.width, "40%") + self.assertEqual(dropdown.layout.margin, "0 0 10px 0") + self.assertEqual(dropdown.style.description_width, "initial") + + # Test the dropdown options + options = dropdown.options + self.assertIn("", options) + self.assertIn("Selected Organisms", options) + self.assertIn("All Organisms", options) + + def test_create_dropdown_options(self): + options = create_dropdown_options(ORGANISM2ID) + self.assertIsInstance(options, list) + self.assertGreater(len(options), 0) + + def test_display_organism_dropdown(self): + container = UserContainer() + with unittest.mock.patch( + "CodonTransformer.CodonJupyter.display" + ) as mock_display: + display_organism_dropdown(container) + + # Check that display was called twice (for container_widget and HTML) + self.assertEqual(mock_display.call_count, 2) + + # Check that the first call to display was with a VBox widget + self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) + + # Check that the VBox contains a Dropdown + dropdown = mock_display.call_args_list[0][0][0].children[1] + self.assertIsInstance(dropdown, ipywidgets.Dropdown) + self.assertGreater(len(dropdown.options), 0) + + def test_display_protein_input(self): + container = UserContainer() + with unittest.mock.patch( + "CodonTransformer.CodonJupyter.display" + ) as mock_display: + display_protein_input(container) + + # Check that display was called twice (for container_widget and HTML) + self.assertEqual(mock_display.call_count, 2) + + # Check that the first call to display was with a VBox widget + self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) + + # Check that the VBox contains a Textarea + textarea = mock_display.call_args_list[0][0][0].children[1] + self.assertIsInstance(textarea, ipywidgets.Textarea) + + # Verify the properties of the Textarea + self.assertEqual(textarea.value, "") + self.assertEqual(textarea.placeholder, "Enter here...") + self.assertEqual(textarea.description, "") + self.assertEqual(textarea.layout.width, "100%") + self.assertEqual(textarea.layout.height, "100px") + self.assertEqual(textarea.layout.margin, "0 0 10px 0") + self.assertEqual(textarea.style.description_width, "initial") + + def test_format_model_output(self): + output = DNASequencePrediction( + organism="Escherichia coli", + protein="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + processed_input="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + predicted_dna="ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA", + ) + formatted_output = format_model_output(output) + self.assertIsInstance(formatted_output, str) + self.assertIn("Organism", formatted_output) + self.assertIn("Escherichia coli", formatted_output) + self.assertIn("Input Protein", formatted_output) + self.assertIn( + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + formatted_output, + ) + self.assertIn("Processed Input", formatted_output) + self.assertIn( + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + formatted_output, + ) + self.assertIn("Predicted DNA", formatted_output) + self.assertIn( + "ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA", + formatted_output, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py new file mode 100644 index 0000000..199c383 --- /dev/null +++ b/tests/test_CodonPrediction.py @@ -0,0 +1,51 @@ +import unittest +import torch +from CodonTransformer.CodonPrediction import ( + predict_dna_sequence, + # add other imported functions or classes as needed +) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class TestCodonPrediction(unittest.TestCase): + def test_predict_dna_sequence_valid_input(self): + # Test predict_dna_sequence with a valid protein sequence and organism code + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + result = predict_dna_sequence(protein_sequence, organism, device) + # Test if the output is a string and contains only A, T, C, G. + self.assertIsInstance(result.predicted_dna, str) + self.assertEqual(result.predicted_dna, "ATGTGGTGGATGTGGTGA") + + def test_predict_dna_sequence_invalid_protein_sequence(self): + # Test predict_dna_sequence with an invalid protein sequence + protein_sequence = "MKTZZFVLLL" # 'Z' is not a valid amino acid + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_invalid_organism_code(self): + # Test predict_dna_sequence with an invalid organism code + protein_sequence = "MKTFFVLLL" + organism = "Alien $%#@!" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_empty_protein_sequence(self): + # Test predict_dna_sequence with an empty protein sequence + protein_sequence = "" + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_none_protein_sequence(self): + # Test predict_dna_sequence with None as protein sequence + protein_sequence = None + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py new file mode 100644 index 0000000..c7faf11 --- /dev/null +++ b/tests/test_CodonUtils.py @@ -0,0 +1,89 @@ +import os +import pickle +import tempfile +import unittest +from CodonTransformer.CodonUtils import ( + load_python_object_from_disk, + save_python_object_to_disk, + find_pattern_in_fasta, + get_organism2id_dict, + get_taxonomy_id, + sort_amino2codon_skeleton, + load_pkl_from_url, +) + + +class TestCodonUtils(unittest.TestCase): + def test_load_python_object_from_disk(self): + test_obj = {"key1": "value1", "key2": 2} + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(test_obj, temp_file_name) + loaded_obj = load_python_object_from_disk(temp_file_name) + self.assertEqual(test_obj, loaded_obj) + os.remove(temp_file_name) + + def test_save_python_object_to_disk(self): + test_obj = [1, 2, 3, 4, 5] + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(test_obj, temp_file_name) + self.assertTrue(os.path.exists(temp_file_name)) + os.remove(temp_file_name) + + def test_find_pattern_in_fasta(self): + text = ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n>seq2 [keyword=value2]\nGGTACGATCGATCGAT" + self.assertEqual(find_pattern_in_fasta("keyword", text), "value1") + self.assertEqual(find_pattern_in_fasta("nonexistent", text), "") + + def test_get_organism2id_dict(self): + with tempfile.NamedTemporaryFile( + mode="w", delete=True, suffix=".csv" + ) as temp_file: + temp_file.write("0,Escherichia coli\n1,Homo sapiens\n2,Mus musculus") + temp_file.flush() + organism2id = get_organism2id_dict(temp_file.name) + self.assertEqual( + organism2id, + {"Escherichia coli": 0, "Homo sapiens": 1, "Mus musculus": 2}, + ) + + def test_get_taxonomy_id(self): + taxonomy_dict = { + "Escherichia coli": 562, + "Homo sapiens": 9606, + "Mus musculus": 10090, + } + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=True) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(taxonomy_dict, temp_file_name) + self.assertEqual(get_taxonomy_id(temp_file_name, "Escherichia coli"), 562) + self.assertEqual( + get_taxonomy_id(temp_file_name, return_dict=True), taxonomy_dict + ) + + def test_sort_amino2codon_skeleton(self): + amino2codon = { + "A": (["GCT", "GCC", "GCA", "GCG"], [0.0, 0.0, 0.0, 0.0]), + "C": (["TGT", "TGC"], [0.0, 0.0]), + } + sorted_amino2codon = sort_amino2codon_skeleton(amino2codon) + self.assertEqual( + sorted_amino2codon, + { + "A": (["GCA", "GCC", "GCG", "GCT"], [0.0, 0.0, 0.0, 0.0]), + "C": (["TGC", "TGT"], [0.0, 0.0]), + }, + ) + + def test_load_pkl_from_url(self): + url = "https://example.com/test.pkl" + expected_obj = {"key": "value"} + with unittest.mock.patch("requests.get") as mock_get: + mock_get.return_value.content = pickle.dumps(expected_obj) + loaded_obj = load_pkl_from_url(url) + self.assertEqual(loaded_obj, expected_obj) + + +if __name__ == "__main__": + unittest.main() From 694055cf21685db58f3c54016efd0e7c46d925df Mon Sep 17 00:00:00 2001 From: gui11aume Date: Wed, 18 Sep 2024 16:19:12 -0400 Subject: [PATCH 3/6] Require setuptools >=70.0.0 to fix security issue --- requirements.txt | 2 +- setup.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index db408c4..b02f5d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ python_codon_tables>=0.1.12,<1.0 pytorch_lightning>=2.2.1,<3.0 scikit-learn>=1.2.2,<2.0 scipy>=1.13.1,<3.0 -setuptools>=68.2.2,<70.0 +setuptools>=70.0.0 torch>=2.0.0,<3.0 tqdm>=4.66.2,<5.0 transformers>=4.40.0,<5.0 diff --git a/setup.py b/setup.py index 6741ad3..e4e9648 100644 --- a/setup.py +++ b/setup.py @@ -10,11 +10,12 @@ def read_requirements(): def read_readme(): here = os.path.abspath(os.path.dirname(__file__)) - readme_path = os.path.join(here, 'README.md') - - with open(readme_path, 'r', encoding='utf-8') as f: + readme_path = os.path.join(here, "README.md") + + with open(readme_path, "r", encoding="utf-8") as f: return f.read() + setup( name="CodonTransformer", version="1.5.2", From 04cdfad1f4837ab4d584dd136c2285b4b1a8c2a4 Mon Sep 17 00:00:00 2001 From: gui11aume Date: Wed, 18 Sep 2024 16:23:07 -0400 Subject: [PATCH 4/6] Update GitHub Actions to use Node16 --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 080d1e7..bd91e3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,10 +10,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: '3.10' @@ -30,7 +30,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml From 08edd7ffa30eb70f3b025eab63e093e0d08fbd7a Mon Sep 17 00:00:00 2001 From: gui11aume Date: Wed, 18 Sep 2024 16:28:18 -0400 Subject: [PATCH 5/6] Update GitHub Actions to use Node20 --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bd91e3e..26201c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,10 +10,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' @@ -30,7 +30,7 @@ jobs: coverage xml - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml From 372748d35e717c3bc1f8fa1e271df2451d553160 Mon Sep 17 00:00:00 2001 From: gui11aume Date: Thu, 19 Sep 2024 00:05:24 -0400 Subject: [PATCH 6/6] Fix issue #1 and add formatting hooks --- .gitignore | 2 +- .pre-commit-config.yaml | 34 +++++ CodonTransformer/CodonData.py | 81 ++++++----- CodonTransformer/CodonEvaluation.py | 31 ++--- CodonTransformer/CodonJupyter.py | 27 ++-- CodonTransformer/CodonPrediction.py | 202 ++++++++++++++++++---------- CodonTransformer/CodonUtils.py | 46 ++++--- README.md | 8 +- finetune.py | 12 +- pretrain.py | 16 +-- pyproject.toml | 53 ++++++++ requirements.txt | 2 +- setup.py | 9 +- slurm/finetune.sh | 2 +- slurm/pretrain.sh | 2 +- src/CodonTransformerTokenizer.json | 2 +- tests/test_CodonData.py | 6 +- tests/test_CodonJupyter.py | 6 +- tests/test_CodonPrediction.py | 100 +++++++++----- tests/test_CodonUtils.py | 12 +- 20 files changed, 433 insertions(+), 220 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 0a03dca..46f2a2f 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,4 @@ lightning_logs/ # PyTorch model weights *.pth -*.pt \ No newline at end of file +*.pt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e0c0e74 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +# Don't run pre-commit on files under third-party/ +exclude: "^\ + (third-party/.*)\ + " + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.1.0 + hooks: + - id: check-added-large-files # prevents giant files from being committed. + - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems. + - id: check-merge-conflict # checks for files that contain merge conflict strings. + - id: check-yaml # checks yaml files for parseable syntax. + - id: detect-private-key # detects the presence of private keys. + - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline. + - id: fix-byte-order-marker # removes utf-8 byte order marker. + - id: mixed-line-ending # replaces or checks mixed line ending. + - id: requirements-txt-fixer # sorts entries in requirements.txt. + - id: trailing-whitespace # trims trailing whitespace. + + - repo: https://github.com/sirosen/check-jsonschema + rev: 0.23.2 + hooks: + - id: check-github-actions + - id: check-github-workflows + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.13 + hooks: + - id: ruff + - id: ruff-format diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 5932ff8..b6f6d86 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -5,32 +5,30 @@ preparing the data for training and inference of the CodonTransformer model. """ -import os import json +import os +from typing import Dict, List, Optional, Tuple, Union + import pandas as pd +import python_codon_tables as pct +from Bio import SeqIO +from Bio.Seq import Seq from sklearn.utils import shuffle as sk_shuffle +from tqdm import tqdm from CodonTransformer.CodonUtils import ( + AMBIGUOUS_AMINOACID_MAP, + AMINO2CODON_TYPE, AMINO_ACIDS, + ORGANISM2ID, START_CODONS, STOP_CODONS, STOP_SYMBOL, - AMINO2CODON_TYPE, - AMBIGUOUS_AMINOACID_MAP, - ORGANISM2ID, find_pattern_in_fasta, - sort_amino2codon_skeleton, get_taxonomy_id, + sort_amino2codon_skeleton, ) -from Bio import SeqIO -from Bio.Seq import Seq - -import python_codon_tables as pct - -from typing import List, Dict, Tuple, Union, Optional -from tqdm import tqdm - def prepare_training_data( dataset: Union[str, pd.DataFrame], output_file: str, shuffle: bool = True @@ -50,7 +48,8 @@ def prepare_training_data( Args: dataset (Union[str, pd.DataFrame]): Input dataset in CSV or DataFrame format. output_file (str): Path to save the output JSON dataset. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -78,7 +77,7 @@ def prepare_training_data( def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) -> None: """ - Convert a pandas DataFrame to a JSON file format suitable for training CodonTransformer. + Convert pandas DataFrame to JSON file format suitable for training CodonTransformer. This function takes a preprocessed DataFrame and writes it to a JSON file where each line is a JSON object representing a single record. @@ -86,7 +85,8 @@ def dataframe_to_json(df: pd.DataFrame, output_file: str, shuffle: bool = True) Args: df (pd.DataFrame): The input DataFrame with 'codons' and 'organism' columns. output_file (str): Path to the output JSON file. - shuffle (bool, optional): Whether to shuffle the dataset before saving. Defaults to True. + shuffle (bool, optional): Whether to shuffle the dataset before saving. + Defaults to True. Returns: None @@ -123,8 +123,9 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) It validates the input against a provided mapping of organism names to IDs. Args: - organism (Union[str, int]): The input organism, either as a name (str) or ID (int). - organism_to_id (Dict[str, int]): A dictionary mapping organism names to their corresponding IDs. + organism (Union[str, int]): Input organism, either as a name (str) or ID (int). + organism_to_id (Dict[str, int]): Dictionary mapping organism names to their + corresponding IDs. Returns: int: The validated organism ID. @@ -150,7 +151,8 @@ def process_organism(organism: Union[str, int], organism_to_id: Dict[str, int]) def preprocess_protein_sequence(protein: str) -> str: """ - Preprocess a protein sequence by cleaning, standardizing, and handling ambiguous amino acids. + Preprocess a protein sequence by cleaning, standardizing, and handling + ambiguous amino acids. Args: protein (str): The input protein sequence. @@ -221,7 +223,8 @@ def replace_ambiguous_codons(dna: str) -> str: def preprocess_dna_sequence(dna: str) -> str: """ - Cleans and preprocesses a DNA sequence by standardizing it and replacing ambiguous codons. + Cleans and preprocesses a DNA sequence by standardizing it and replacing + ambiguous codons. Args: dna (str): The DNA sequence to preprocess. @@ -247,8 +250,9 @@ def preprocess_dna_sequence(dna: str) -> str: def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: """ - Return the merged sequence of protein amino acids and DNA codons in the form of tokens - separated by space, where each token is composed of an amino acid + separator + codon. + Return the merged sequence of protein amino acids and DNA codons in the form + of tokens separated by space, where each token is composed of an amino acid + + separator + codon. Args: protein (str): Protein sequence. @@ -274,8 +278,9 @@ def get_merged_seq(protein: str, dna: str = "", separator: str = "_") -> str: # Check if the length of protein and dna sequences are equal if len(dna) > 0 and len(protein) != len(dna) / 3: raise ValueError( - 'Length of protein (including stop symbol such as "_") and \ - the number of codons in DNA sequence (including stop codon) must be equal.' + 'Length of protein (including stop symbol such as "_") and ' + "the number of codons in DNA sequence (including stop codon) " + "must be equal." ) # Merge protein and DNA sequences into tokens @@ -331,8 +336,8 @@ def get_amino_acid_sequence( return_correct_seq (bool): Whether to return if the sequence is correct. Returns: - Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if return_correct_seq is True, - otherwise just the protein sequence. + Union[str, Tuple[str, bool]]: Protein sequence and correctness flag if + return_correct_seq is True, otherwise just the protein sequence. """ dna_seq = Seq(dna).strip() @@ -365,12 +370,15 @@ def read_fasta_file( Args: input_file (str): Path to the input FASTA file. - save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned. - organism (str): Name of the organism. If empty, it will be extracted from the FASTA description. + save_to_file (Optional[str]): Path to save the output DataFrame. If None, + data is only returned. + organism (str): Name of the organism. If empty, it will be extracted from + the FASTA description. buffer_size (int): Number of records to process before writing to file. Returns: - pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None. + pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe + is True, else None. Raises: FileNotFoundError: If the input file does not exist. @@ -498,7 +506,8 @@ def download_codon_frequencies_from_kazusa( def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: """ - Return the empty skeleton of the amino2codon dictionary, needed for get_codon_frequencies. + Return the empty skeleton of the amino2codon dictionary, needed for + get_codon_frequencies. Args: organism (str): Name of the organism. @@ -514,7 +523,8 @@ def build_amino2codon_skeleton(organism: str) -> AMINO2CODON_TYPE: return_correct_seq=False, ) - # Initialize the amino2codon skeleton with all possible codons and set their frequencies to 0 + # Initialize the amino2codon skeleton with all possible codons and set their + # frequencies to 0 for i, (codon, amino) in enumerate(zip(possible_codons, possible_aminoacids)): if amino not in amino2codon: amino2codon[amino] = ([], []) @@ -543,7 +553,8 @@ def get_codon_frequencies( organism (Optional[str]): Name of the organism. Returns: - AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons and frequencies. + AMINO2CODON_TYPE: Dictionary mapping each amino acid to a tuple of codons + and frequencies. """ if organism: codon_table = get_codon_table(organism) @@ -583,7 +594,8 @@ def get_organism_to_codon_frequencies( organisms (List[str]): List of organisms. Returns: - Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon frequency distribution. + Dict[str, AMINO2CODON_TYPE]: Dictionary mapping each organism to its codon + frequency distribution. """ organism2frequencies = {} @@ -617,7 +629,8 @@ def get_codon_table(organism: str) -> int: "Arabidopsis thaliana", "Caenorhabditis elegans", "Chlamydomonas reinhardtii", - "Saccharomyces cerevisiae" "Danio rerio", + "Saccharomyces cerevisiae", + "Danio rerio", "Drosophila melanogaster", "Homo sapiens", "Mus musculus", diff --git a/CodonTransformer/CodonEvaluation.py b/CodonTransformer/CodonEvaluation.py index e94423a..d09ba92 100644 --- a/CodonTransformer/CodonEvaluation.py +++ b/CodonTransformer/CodonEvaluation.py @@ -1,18 +1,17 @@ """ File: CodonEvaluation.py --------------------------- -Includes functions to calculate various evaluation metrics along with helper functions. +Includes functions to calculate various evaluation metrics along with helper +functions. """ -import pandas as pd +from typing import Dict, List, Tuple +import pandas as pd from CAI import CAI, relative_adaptiveness - -from typing import List, Dict, Tuple from tqdm import tqdm - def get_CSI_weights(sequences: List[str]) -> Dict[str, float]: """ Calculate the Codon Similarity Index (CSI) weights for a list of DNA sequences. @@ -47,7 +46,7 @@ def get_organism_to_CSI_weights( Calculate the Codon Similarity Index (CSI) weights for a list of organisms. Args: - dataset (pd.DataFrame): The dataset containing organism and DNA sequence information. + dataset (pd.DataFrame): Dataset containing organism and DNA sequence info. organisms (List[str]): List of organism names. Returns: @@ -91,7 +90,8 @@ def get_cfd( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. threshold (float): Frequency threshold for counting rare codons. Returns: @@ -127,7 +127,8 @@ def get_min_max_percentage( Args: dna (str): The DNA sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequency distribution per amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequency distribution per amino acid. window_size (int): Size of the window to calculate %MinMax. Returns: @@ -147,14 +148,12 @@ def get_min_max_percentage( # Iterate through the DNA sequence using the specified window size for i in range(len(codons) - window_size + 1): - codon_window = codons[ - i : i + window_size - ] # List of the codons in the current window + codon_window = codons[i : i + window_size] # Codons in the current window Actual = 0.0 # Average of the actual codon frequencies Max = 0.0 # Average of the min codon frequencies Min = 0.0 # Average of the max codon frequencies - Avg = 0.0 # Average of the averages of all the frequencies associated with each amino acid + Avg = 0.0 # Average of the averages of all frequencies for each amino acid # Sum the frequencies for codons in the current window for codon in codon_window: @@ -210,7 +209,7 @@ def sum_up_to(x): return x + sum_up_to(x - 1) def f(x): - """Function that returns 4 if x is greater than or equal to 4, else returns x.""" + """Returns 4 if x is greater than or equal to 4, else returns x.""" if x >= 4: return 4 elif x < 4: @@ -242,8 +241,10 @@ def get_sequence_similarity( Args: original (str): The original sequence. predicted (str): The predicted sequence. - truncate (bool): If True, truncate the original sequence to match the length of the predicted sequence. - window_length (int): Length of the window for comparison (1 for amino acids, 3 for codons). + truncate (bool): If True, truncate the original sequence to match the length + of the predicted sequence. + window_length (int): Length of the window for comparison (1 for amino acids, + 3 for codons). Returns: float: The sequence similarity as a percentage. diff --git a/CodonTransformer/CodonJupyter.py b/CodonTransformer/CodonJupyter.py index 6c730cd..fdf7164 100644 --- a/CodonTransformer/CodonJupyter.py +++ b/CodonTransformer/CodonJupyter.py @@ -7,13 +7,13 @@ from typing import Dict, List, Tuple import ipywidgets as widgets -from IPython.display import display, HTML +from IPython.display import HTML, display from CodonTransformer.CodonUtils import ( - DNASequencePrediction, COMMON_ORGANISMS, - ORGANISM2ID, ID2ORGANISM, + ORGANISM2ID, + DNASequencePrediction, ) @@ -49,11 +49,11 @@ def create_styled_options( organism_id = organism2id[organism] if is_fine_tuned: if organism_id < 10: - styled_options.append(f"\u200B{organism_id:>6}. {organism}") + styled_options.append(f"\u200b{organism_id:>6}. {organism}") elif organism_id < 100: - styled_options.append(f"\u200B{organism_id:>5}. {organism}") + styled_options.append(f"\u200b{organism_id:>5}. {organism}") else: - styled_options.append(f"\u200B{organism_id:>4}. {organism}") + styled_options.append(f"\u200b{organism_id:>4}. {organism}") else: if organism_id < 10: styled_options.append(f"{organism_id:>6}. {organism}") @@ -160,7 +160,7 @@ def get_dropdown_style() -> str: flex-direction: column; align-items: flex-start; } - .widget-dropdown option[value^="\u200B"] { + .widget-dropdown option[value^="\u200b"] { font-family: sans-serif; font-weight: bold; font-size: 18px; @@ -188,7 +188,8 @@ def display_organism_dropdown(container: UserContainer) -> None: """ dropdown = create_organism_dropdown(container) header = widgets.HTML( - 'Select Organism:
' + 'Select Organism:' + '
' ) container_widget = widgets.VBox( [header, dropdown], @@ -242,7 +243,8 @@ def save_protein(change: Dict[str, str]) -> None: Save the input protein sequence to the container. Args: - change (Dict[str, str]): A dictionary containing information about the change in textarea value. + change (Dict[str, str]): A dictionary containing information about + the change in textarea value. """ container.protein = ( change["new"] @@ -258,7 +260,8 @@ def save_protein(change: Dict[str, str]) -> None: # Display the input widget header = widgets.HTML( - 'Enter Protein Sequence:
' + 'Enter Protein Sequence:' + '
' ) container_widget = widgets.VBox( [header, protein_input], layout=widgets.Layout(padding="12px 12px 0 25px") @@ -270,13 +273,13 @@ def save_protein(change: Dict[str, str]) -> None: def format_model_output(output: DNASequencePrediction) -> str: """ - Format the DNA sequence prediction output in a visually appealing and easy-to-read manner. + Format DNA sequence prediction output in an appealing and easy-to-read manner. This function takes the prediction output and formats it into a structured string with clear section headers and separators. Args: - output (DNASequencePrediction): An object containing the prediction output. + output (DNASequencePrediction): Object containing the prediction output. Expected attributes: - organism (str): The organism name. - protein (str): The input protein sequence. diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 7df8520..a4d5434 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -5,27 +5,28 @@ helper functions related to processing data for passing to the model. """ -from typing import Any, List, Dict, Tuple, Optional, Union -import onnxruntime as rt +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np +import onnxruntime as rt import torch import transformers from transformers import ( + AutoTokenizer, BatchEncoding, - PreTrainedTokenizerFast, BigBirdConfig, - AutoTokenizer, BigBirdForMaskedLM, + PreTrainedTokenizerFast, ) -import numpy as np from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( AMINO_ACIDS, - ORGANISM2ID, - TOKEN2INDEX, INDEX2TOKEN, NUM_ORGANISMS, + ORGANISM2ID, + TOKEN2INDEX, DNASequencePrediction, ) @@ -37,6 +38,7 @@ def predict_dna_sequence( tokenizer: Union[str, PreTrainedTokenizerFast] = None, model: Union[str, torch.nn.Module] = None, attention_type: str = "original_full", + deterministic: bool = True, ) -> DNASequencePrediction: """ Predict the DNA sequence for a given protein using CodonTransformer model. @@ -47,30 +49,38 @@ def predict_dna_sequence( Args: protein (str): The input protein sequence to predict the DNA sequence for. - organism (Union[int, str]): Either the ID of the organism or its name (e.g., "Escherichia coli general"). - If a string is provided, it will be converted to the corresponding ID using ORGANISM2ID. + organism (Union[int, str]): Either the ID of the organism or its name (e.g., + "Escherichia coli general"). If a string is provided, it will be converted + to the corresponding ID using ORGANISM2ID. device (torch.device): The device (CPU or GPU) to run the model on. - tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file path to load the tokenizer from, - a pre-loaded tokenizer object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - model (Union[str, torch.nn.Module, None], optional): Either a file path to load the model from, - a pre-loaded model object, or None. If None, it will be loaded from HuggingFace. Defaults to None. - attention_type (str, optional): The type of attention mechanism to use in the model. - Can be either 'block_sparse' or 'original_full'. Defaults to "original_full". + tokenizer (Union[str, PreTrainedTokenizerFast, None], optional): Either a file + path to load the tokenizer from, a pre-loaded tokenizer object, or None. If + None, it will be loaded from HuggingFace. Defaults to None. + model (Union[str, torch.nn.Module, None], optional): Either a file path to load + the model from, a pre-loaded model object, or None. If None, it will be + loaded from HuggingFace. Defaults to None. + attention_type (str, optional): The type of attention mechanism to use in the + model. Can be either 'block_sparse' or 'original_full'. Defaults to + "original_full". + deterministic (bool, optional): Whether to use deterministic decoding (most + likely tokens). If False, samples tokens according to their probabilities. + Defaults to True. Returns: DNASequencePrediction: An object containing the prediction results: - - organism (str): The name of the organism used for prediction. - - protein (str): The input protein sequence for which DNA sequence is predicted. - - processed_input (str): The processed input sequence (merged protein and DNA). - - predicted_dna (str): The predicted DNA sequence. + - organism (str): Name of the organism used for prediction. + - protein (str): Input protein sequence for which DNA sequence is predicted. + - processed_input (str): Processed input sequence (merged protein and DNA). + - predicted_dna (str): Predicted DNA sequence. Raises: ValueError: If the protein sequence is empty or if the organism is invalid. Note: - This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from CodonTransformer.CodonUtils. - ORGANISM2ID maps organism names to their corresponding IDs. - INDEX2TOKEN maps model output indices (token ids) to respective codons. + This function uses ORGANISM2ID and INDEX2TOKEN dictionaries imported from + CodonTransformer.CodonUtils. ORGANISM2ID maps organism names to their + corresponding IDs. INDEX2TOKEN maps model output indices (token ids) to + respective codons. Example: >>> import torch @@ -83,30 +93,41 @@ def predict_dna_sequence( >>> >>> # Load tokenizer and model >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") - >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device) + >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") + >>> model = model.to(device) >>> >>> # Define protein sequence and organism - >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" + >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA" >>> organism = "Escherichia coli general" >>> - >>> # Predict DNA sequence + >>> # Predict DNA sequence with deterministic decoding >>> output = predict_dna_sequence( ... protein=protein, ... organism=organism, ... device=device, ... tokenizer=tokenizer, ... model=model, - ... attention_type="original_full" + ... attention_type="original_full", + ... deterministic=True + ... ) + >>> + >>> # Predict DNA sequence with probabilistic sampling + >>> output_sampler = predict_dna_sequence( + ... protein=protein, + ... organism=organism, + ... device=device, + ... tokenizer=tokenizer, + ... model=model, + ... attention_type="original_full", + ... deterministic=False ... ) >>> >>> print(format_model_output(output)) + >>> print(format_model_output(output_sampler)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") - if not isinstance(protein, str): - raise ValueError("Protein sequence must be a string.") - # Test that the input protein sequence contains only valid amino acids if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): raise ValueError("Invalid amino acid found in protein sequence.") @@ -139,12 +160,23 @@ def predict_dna_sequence( # Get the model predictions output_dict = model(**tokenized_input, return_dict=True) - output = output_dict.logits.detach().cpu().numpy() + logits = output_dict.logits.detach().cpu() # Decode the predicted DNA sequence from the model output - predicted_dna = list( - map(INDEX2TOKEN.__getitem__, output.argmax(axis=-1).squeeze().tolist()) - ) + if deterministic: + # Select the most probable tokens (argmax) + predicted_indices = logits.argmax(dim=-1).squeeze().tolist() + else: + # Sample tokens according to their probability distribution + # Convert logits to probabilities using softmax + probabilities = torch.softmax(logits, dim=-1) + # Sample from the probability distribution at each position + probabilities = probabilities.squeeze(0) # Shape: [seq_len, vocab_size] + predicted_indices = ( + torch.multinomial(probabilities, num_samples=1).squeeze(-1).tolist() + ) + + predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices)) # Skip special tokens [CLS] and [SEP] to create the predicted_dna predicted_dna = ( @@ -170,17 +202,21 @@ def load_model( Load a BigBirdForMaskedLM model from a model file, checkpoint, or HuggingFace. Args: - model_path (Optional[str]): Path to the model file or checkpoint. If None, load from HuggingFace. + model_path (Optional[str]): Path to the model file or checkpoint. If None, + load from HuggingFace. device (torch.device, optional): The device to load the model onto. - attention_type (str, optional): The type of attention, 'block_sparse' or 'original_full'. - num_organisms (int, optional): Number of organisms, needed if loading from a checkpoint that requires this. - remove_prefix (bool, optional): Whether to remove the "model." prefix from the keys in the state dict. + attention_type (str, optional): The type of attention, 'block_sparse' + or 'original_full'. + num_organisms (int, optional): Number of organisms, needed if loading from a + checkpoint that requires this. + remove_prefix (bool, optional): Whether to remove the "model." prefix from the + keys in the state dict. Returns: torch.nn.Module: The loaded model. """ if not model_path: - print("Warning: Model path not provided. Loading from HuggingFace.") + warnings.warn("Model path not provided. Loading from HuggingFace.", UserWarning) model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer") elif model_path.endswith(".ckpt"): @@ -209,7 +245,8 @@ def load_model( else: raise ValueError( - "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace." + "Unsupported file type. Please provide a .ckpt or .pt file, " + "or None to load from HuggingFace." ) # Prepare model for evaluation @@ -260,16 +297,19 @@ def create_model_from_checkpoint( def load_tokenizer(tokenizer_path: Optional[str] = None) -> PreTrainedTokenizerFast: """ - Create and return a tokenizer object from the given tokenizer path or HuggingFace. + Create and return a tokenizer object from tokenizer path or HuggingFace. Args: - tokenizer_path (Optional[str]): Path to the tokenizer file. If None, load from HuggingFace. + tokenizer_path (Optional[str]): Path to the tokenizer file. If None, + load from HuggingFace. Returns: PreTrainedTokenizerFast: The tokenizer object. """ if not tokenizer_path: - print("Warning: Tokenizer path not provided. Loading from HuggingFace.") + warnings.warn( + "Tokenizer path not provided. Loading from HuggingFace.", UserWarning + ) return AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") return transformers.PreTrainedTokenizerFast( @@ -291,11 +331,14 @@ def tokenize( ) -> BatchEncoding: """ Return the tokenized sequences given a batch of input data. - Each data in the batch is expected to be a dictionary with "codons" and "organism" keys. + Each data in the batch is expected to be a dictionary with "codons" and + "organism" keys. Args: - batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and "organism" keys. - tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or path to the tokenizer file. + batch (List[Dict[str, Any]]): A list of dictionaries with "codons" and + "organism" keys. + tokenizer (PreTrainedTokenizerFast, str, optional): The tokenizer object or + path to the tokenizer file. max_len (int, optional): Maximum length of the tokenized sequence. Returns: @@ -326,28 +369,32 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: """ Validate and convert the organism input to both ID and name. - This function takes either an organism ID or name as input and returns both the ID and name. - It performs validation to ensure the input corresponds to a valid organism in the ORGANISM2ID dictionary. + This function takes either an organism ID or name as input and returns both + the ID and name. It performs validation to ensure the input corresponds to + a valid organism in the ORGANISM2ID dictionary. Args: - organism (Union[int, str]): Either the ID of the organism (int) or its name (str). + organism (Union[int, str]): Either the ID of the organism (int) or its + name (str). Returns: Tuple[int, str]: A tuple containing the organism ID (int) and name (str). Raises: - ValueError: If the input is neither a string nor an integer, if the organism name is not found - in ORGANISM2ID, if the organism ID is not a value in ORGANISM2ID, or if no name - is found for a given ID. + ValueError: If the input is neither a string nor an integer, if the + organism name is not found in ORGANISM2ID, if the organism ID is not a + value in ORGANISM2ID, or if no name is found for a given ID. Note: - This function relies on the ORGANISM2ID dictionary imported from CodonTransformer.CodonUtils, - which maps organism names to their corresponding IDs. + This function relies on the ORGANISM2ID dictionary imported from + CodonTransformer.CodonUtils, which maps organism names to their + corresponding IDs. """ if isinstance(organism, str): if organism not in ORGANISM2ID: raise ValueError( - f"Invalid organism name: {organism}. Please use a valid organism name or ID." + f"Invalid organism name: {organism}. " + "Please use a valid organism name or ID." ) organism_id = ORGANISM2ID[organism] organism_name = organism @@ -355,7 +402,8 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: elif isinstance(organism, int): if organism not in ORGANISM2ID.values(): raise ValueError( - f"Invalid organism ID: {organism}. Please use a valid organism name or ID." + f"Invalid organism ID: {organism}. " + "Please use a valid organism name or ID." ) organism_id = organism @@ -365,9 +413,6 @@ def validate_and_convert_organism(organism: Union[int, str]) -> Tuple[int, str]: if organism_name is None: raise ValueError(f"No organism name found for ID: {organism}") - else: - raise ValueError("Organism must be either a string (name) or an integer (ID).") - return organism_id, organism_name @@ -375,12 +420,13 @@ def get_high_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using High Frequency Choice (HFC) approach in which - the most frequent codon for a given amino acid is always chosen. + Return the DNA sequence optimized using High Frequency Choice (HFC) approach + in which the most frequent codon for a given amino acid is always chosen. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -400,7 +446,8 @@ def precompute_most_frequent_codons( Precompute the most frequent codon for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, str]: The most frequent codon for each amino acid. @@ -421,7 +468,8 @@ def get_high_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -436,17 +484,20 @@ def get_background_frequency_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Background Frequency Choice (BFC) approach in which - a random codon for a given amino acid is chosen using the codon frequencies probability distribution. + Return the DNA sequence optimized using Background Frequency Choice (BFC) + approach in which a random codon for a given amino acid is chosen using + the codon frequencies probability distribution. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. """ - # Select a random codon for each amino acid based on the codon frequencies probability distribution + # Select a random codon for each amino acid based on the codon frequencies + # probability distribution dna_codons = [ np.random.choice( codon_frequencies[aminoacid][0], p=codon_frequencies[aminoacid][1] @@ -463,7 +514,8 @@ def precompute_cdf( Precompute the cumulative distribution function (CDF) for each amino acid. Args: - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: Dict[str, Tuple[List[str], Any]]: CDFs for each amino acid. @@ -486,7 +538,8 @@ def get_background_frequency_choice_sequence_optimized( Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -507,12 +560,14 @@ def get_uniform_random_choice_sequence( protein: str, codon_frequencies: Dict[str, Tuple[List[str], List[float]]] ) -> str: """ - Return the DNA sequence optimized using Uniform Random Choice (URC) approach in which - a random codon for a given amino acid is chosen using a uniform prior. + Return the DNA sequence optimized using Uniform Random Choice (URC) approach + in which a random codon for a given amino acid is chosen using a uniform + prior. Args: protein (str): The protein sequence. - codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon frequencies for each amino acid. + codon_frequencies (Dict[str, Tuple[List[str], List[float]]]): Codon + frequencies for each amino acid. Returns: str: The optimized DNA sequence. @@ -529,7 +584,8 @@ def get_icor_prediction(input_seq: str, model_path: str, stop_symbol: str) -> st Return the optimized codon sequence for the given protein sequence using ICOR. Credit: ICOR: improving codon optimization with recurrent neural networks - Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas Densmore + Rishab Jain, Aditya Jain, Elizabeth Mauro, Kevin LeShane, Douglas + Densmore Args: input_seq (str): The input protein sequence. diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index cf84b24..dd91e3f 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -4,18 +4,16 @@ Includes constants and helper functions used by other Python scripts. """ +import itertools import os -import re import pickle -import requests -import itertools - -import torch -import pandas as pd +import re from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple -from typing import Any, List, Dict, Tuple, Optional, Iterator - +import pandas as pd +import requests +import torch # List of all amino acids AMINO_ACIDS: List[str] = [ @@ -470,10 +468,10 @@ class DNASequencePrediction: A class to hold the output of the DNA sequence prediction. Attributes: - organism (str): The name of the organism used for prediction. - protein (str): The input protein sequence for which DNA sequence is predicted. - processed_input (str): The processed input sequence (merged protein and DNA). - predicted_dna (str): The predicted DNA sequence. + organism (str): Name of the organism used for prediction. + protein (str): Input protein sequence for which DNA sequence is predicted. + processed_input (str): Processed input sequence (merged protein and DNA). + predicted_dna (str): Predicted DNA sequence. """ organism: str @@ -488,7 +486,8 @@ class IterableData(torch.utils.data.IterableDataset): data) in parallel multi-processing environments, e.g., multi-GPU. Args: - dist_env (Optional[str]): The distribution environment identifier (e.g., "slurm"). + dist_env (Optional[str]): The distribution environment identifier + (e.g., "slurm"). Credit: Guillaume Filion """ @@ -501,7 +500,7 @@ def __init__(self, dist_env: Optional[str] = None): @property def iterator(self) -> Iterator: - """Define the stream logic for the dataset. Should be implemented in subclasses.""" + """Define the stream logic for the dataset. Implement in subclasses.""" raise NotImplementedError def __iter__(self) -> Iterator: @@ -573,7 +572,8 @@ def save_python_object_to_disk(input_object: Any, file_path: str) -> None: def find_pattern_in_fasta(keyword: str, text: str) -> str: """ - Find a specific keyword pattern in text. Helpful for identifying parts of a FASTA sequence. + Find a specific keyword pattern in text. Helpful for identifying parts + of a FASTA sequence. Args: keyword (str): The keyword pattern to search for. @@ -589,17 +589,19 @@ def find_pattern_in_fasta(keyword: str, text: str) -> str: def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: """ - Return a dictionary mapping each organism in training data to an index used for training. + Return a dictionary mapping each organism in training data to an index + used for training. Args: - organism_reference (str): Path to a CSV file containing a list of all organisms. - The format of the CSV file should be as follows: - 0,Escherichia coli - 1,Homo sapiens - 2,Mus musculus + organism_reference (str): Path to a CSV file containing a list of + all organisms. The format of the CSV file should be as follows: + + 0,Escherichia coli + 1,Homo sapiens + 2,Mus musculus Returns: - Dict[str, int]: A dictionary mapping organism names to their respective indices. + Dict[str, int]: Dictionary mapping organism names to their respective indices. """ # Read the CSV file and create a dictionary mapping organisms to their indices organisms = pd.read_csv(organism_reference, index_col=0, header=None) diff --git a/README.md b/README.md index e1b4304..6b0b7a3 100644 --- a/README.md +++ b/README.md @@ -115,14 +115,14 @@ The package requires `python>=3.9`. The requirements are [availabe here](require To finetune CodonTransformer on your own data, follow these steps: 1. **Prepare your dataset** - + Create a CSV file with the following columns: - `dna`: DNA sequences (string, preferably uppercase ATCG) - `protein`: Protein sequences (string, preferably uppercase amino acid letters) - `organism`: Target organism (string or int, must be from `ORGANISM2ID` in `CodonUtils`) - Note: + Note: - Use organisms from the `FINE_TUNE_ORGANISMS` list for best results. - For E. coli, use `Escherichia coli general`. - DNA sequences should ideally contain only A, T, C, and G. Ambiguous codons are replaced with 'UNK' for tokenization. @@ -132,7 +132,7 @@ To finetune CodonTransformer on your own data, follow these steps:
2. **Prepare training data** - + Use the `prepare_training_data` function from `CodonData` to prepare training data from your dataset. ```python @@ -142,7 +142,7 @@ To finetune CodonTransformer on your own data, follow these steps:
3. **Run the finetuning script** - + Execute finetune.py with appropriate arguments: ```bash python finetune.py \ diff --git a/finetune.py b/finetune.py index dfc9dc3..fce46f3 100644 --- a/finetune.py +++ b/finetune.py @@ -4,22 +4,22 @@ Finetune the CodonTransformer model. The pretrained model is loaded directly from Hugging Face. -The dataset is a JSON file. You can use prepare_training_data from CodonData to prepare the dataset. -The repository Readme has a guide on how to prepare the dataset and use this script. +The dataset is a JSON file. You can use prepare_training_data from CodonData to +prepare the dataset. The repository README has a guide on how to prepare the +dataset and use this script. """ -import os -import torch import argparse +import os import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader - from transformers import AutoTokenizer, BigBirdForMaskedLM from CodonTransformer.CodonUtils import ( - TOKEN2MASK, MAX_LEN, + TOKEN2MASK, IterableJSONData, ) diff --git a/pretrain.py b/pretrain.py index bafd738..1b253cc 100644 --- a/pretrain.py +++ b/pretrain.py @@ -3,23 +3,23 @@ ------------------- Pretrain the CodonTransformer model. -The dataset is a JSON file. You can use prepare_training_data from CodonData to prepare the dataset. -The repository Readme has a guide on how to prepare the dataset and use this script. +The dataset is a JSON file. You can use prepare_training_data from CodonData to +prepare the dataset. The repository README has a guide on how to prepare the +dataset and use this script. """ -import os -import torch import argparse +import os import pytorch_lightning as pl +import torch from torch.utils.data import DataLoader - -from transformers import PreTrainedTokenizerFast, BigBirdConfig, BigBirdForMaskedLM +from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast from CodonTransformer.CodonUtils import ( - TOKEN2MASK, - NUM_ORGANISMS, MAX_LEN, + NUM_ORGANISMS, + TOKEN2MASK, IterableJSONData, ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e74248e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,53 @@ +[tool.poetry] +name = "CodonTransformer" +version = "1.5.2" +description = "The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms." +authors = ["Adibvafa Fallahpour "] +license = "Apache-2.0" +readme = "README.md" +homepage = "https://github.com/adibvafa/CodonTransformer" +repository = "https://github.com/adibvafa/CodonTransformer" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] + +[tool.poetry.dependencies] +python = "^3.9" +biopython = "^1.83" +ipywidgets = "^7.0.0" +numpy = "^1.26.4" +onnxruntime = "^1.17.3" +pandas = "^2.0.0" +python_codon_tables = "^0.1.12" +pytorch_lightning = "^2.2.1" +scikit-learn = "^1.2.2" +scipy = "^1.13.1" +setuptools = "^70.0.0" +torch = "^2.0.0" +tqdm = "^4.66.2" +transformers = "^4.40.0" +CAI-PyPI = "^2.0.1" + +[tool.poetry.dev-dependencies] +coverage = "^7.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 88 +indent-width = 4 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/requirements.txt b/requirements.txt index b02f5d5..61b3643 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ biopython>=1.83,<2.0 +CAI-PyPI>=2.0.1,<3.0 ipywidgets>=7.0.0,<10.0 numpy>=1.26.4,<3.0 onnxruntime>=1.17.3,<3.0 @@ -11,4 +12,3 @@ setuptools>=70.0.0 torch>=2.0.0,<3.0 tqdm>=4.66.2,<5.0 transformers>=4.40.0,<5.0 -CAI-PyPI>=2.0.1,<3.0 \ No newline at end of file diff --git a/setup.py b/setup.py index e4e9648..6ca0e55 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # setup.py import os -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def read_requirements(): @@ -23,7 +24,11 @@ def read_readme(): install_requires=read_requirements(), author="Adibvafa Fallahpour", author_email="Adibvafa.fallahpour@mail.utoronto.ca", - description="The ultimate tool for codon optimization, transforming protein sequences into optimized DNA sequences specific for your target organisms.", + description=( + "The ultimate tool for codon optimization, " + "transforming protein sequences into optimized DNA sequences " + "specific for your target organisms." + ), long_description=read_readme(), long_description_content_type="text/markdown", url="https://github.com/adibvafa/CodonTransformer", diff --git a/slurm/finetune.sh b/slurm/finetune.sh index dbaffba..980be12 100644 --- a/slurm/finetune.sh +++ b/slurm/finetune.sh @@ -35,4 +35,4 @@ stdbuf -oL -eL srun python finetune.py \ --learning_rate 0.00005 \ --warmup_fraction 0.1 \ --save_every_n_steps 512 \ - --seed 123 \ No newline at end of file + --seed 123 diff --git a/slurm/pretrain.sh b/slurm/pretrain.sh index 0935499..50d31b2 100644 --- a/slurm/pretrain.sh +++ b/slurm/pretrain.sh @@ -34,4 +34,4 @@ stdbuf -oL -eL srun python pretrain.py \ --learning_rate 0.00005 \ --warmup_fraction 0.1 \ --save_interval 5 \ - --seed 123 \ No newline at end of file + --seed 123 diff --git a/src/CodonTransformerTokenizer.json b/src/CodonTransformerTokenizer.json index 0d27966..7db063d 100644 --- a/src/CodonTransformerTokenizer.json +++ b/src/CodonTransformerTokenizer.json @@ -1 +1 @@ -{"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "special": true, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 1, "special": true, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 2, "special": true, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 3, "special": true, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 4, "special": true, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}], "normalizer": {"type": "Sequence", "normalizers": [{"type": "Lowercase"}]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"String": " "}, "behavior": "Isolated", "invert": false}, {"type": "Whitespace"}]}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": null, "model": {"type": "WordPiece", "unk_token": "[UNK]", "continuing_subword_prefix": "##", "max_input_chars_per_word": 100, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "a_unk": 5, "c_unk": 6, "d_unk": 7, "e_unk": 8, "f_unk": 9, "g_unk": 10, "h_unk": 11, "i_unk": 12, "k_unk": 13, "l_unk": 14, "m_unk": 15, "n_unk": 16, "p_unk": 17, "q_unk": 18, "r_unk": 19, "s_unk": 20, "t_unk": 21, "v_unk": 22, "w_unk": 23, "y_unk": 24, "__unk": 25, "k_aaa": 26, "n_aac": 27, "k_aag": 28, "n_aat": 29, "t_aca": 30, "t_acc": 31, "t_acg": 32, "t_act": 33, "r_aga": 34, "s_agc": 35, "r_agg": 36, "s_agt": 37, "i_ata": 38, "i_atc": 39, "m_atg": 40, "i_att": 41, "q_caa": 42, "h_cac": 43, "q_cag": 44, "h_cat": 45, "p_cca": 46, "p_ccc": 47, "p_ccg": 48, "p_cct": 49, "r_cga": 50, "r_cgc": 51, "r_cgg": 52, "r_cgt": 53, "l_cta": 54, "l_ctc": 55, "l_ctg": 56, "l_ctt": 57, "e_gaa": 58, "d_gac": 59, "e_gag": 60, "d_gat": 61, "a_gca": 62, "a_gcc": 63, "a_gcg": 64, "a_gct": 65, "g_gga": 66, "g_ggc": 67, "g_ggg": 68, "g_ggt": 69, "v_gta": 70, "v_gtc": 71, "v_gtg": 72, "v_gtt": 73, "__taa": 74, "y_tac": 75, "__tag": 76, "y_tat": 77, "s_tca": 78, "s_tcc": 79, "s_tcg": 80, "s_tct": 81, "__tga": 82, "c_tgc": 83, "w_tgg": 84, "c_tgt": 85, "l_tta": 86, "f_ttc": 87, "l_ttg": 88, "f_ttt": 89}}} \ No newline at end of file +{"version": "1.0", "truncation": null, "padding": null, "added_tokens": [{"id": 0, "special": true, "content": "[UNK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 1, "special": true, "content": "[CLS]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 2, "special": true, "content": "[SEP]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 3, "special": true, "content": "[PAD]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}, {"id": 4, "special": true, "content": "[MASK]", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false}], "normalizer": {"type": "Sequence", "normalizers": [{"type": "Lowercase"}]}, "pre_tokenizer": {"type": "Sequence", "pretokenizers": [{"type": "Split", "pattern": {"String": " "}, "behavior": "Isolated", "invert": false}, {"type": "Whitespace"}]}, "post_processor": {"type": "TemplateProcessing", "single": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}], "pair": [{"SpecialToken": {"id": "[CLS]", "type_id": 0}}, {"Sequence": {"id": "A", "type_id": 0}}, {"SpecialToken": {"id": "[SEP]", "type_id": 0}}, {"Sequence": {"id": "B", "type_id": 1}}, {"SpecialToken": {"id": "[SEP]", "type_id": 1}}], "special_tokens": {"[CLS]": {"id": "[CLS]", "ids": [1], "tokens": ["[CLS]"]}, "[SEP]": {"id": "[SEP]", "ids": [2], "tokens": ["[SEP]"]}}}, "decoder": null, "model": {"type": "WordPiece", "unk_token": "[UNK]", "continuing_subword_prefix": "##", "max_input_chars_per_word": 100, "vocab": {"[UNK]": 0, "[CLS]": 1, "[SEP]": 2, "[PAD]": 3, "[MASK]": 4, "a_unk": 5, "c_unk": 6, "d_unk": 7, "e_unk": 8, "f_unk": 9, "g_unk": 10, "h_unk": 11, "i_unk": 12, "k_unk": 13, "l_unk": 14, "m_unk": 15, "n_unk": 16, "p_unk": 17, "q_unk": 18, "r_unk": 19, "s_unk": 20, "t_unk": 21, "v_unk": 22, "w_unk": 23, "y_unk": 24, "__unk": 25, "k_aaa": 26, "n_aac": 27, "k_aag": 28, "n_aat": 29, "t_aca": 30, "t_acc": 31, "t_acg": 32, "t_act": 33, "r_aga": 34, "s_agc": 35, "r_agg": 36, "s_agt": 37, "i_ata": 38, "i_atc": 39, "m_atg": 40, "i_att": 41, "q_caa": 42, "h_cac": 43, "q_cag": 44, "h_cat": 45, "p_cca": 46, "p_ccc": 47, "p_ccg": 48, "p_cct": 49, "r_cga": 50, "r_cgc": 51, "r_cgg": 52, "r_cgt": 53, "l_cta": 54, "l_ctc": 55, "l_ctg": 56, "l_ctt": 57, "e_gaa": 58, "d_gac": 59, "e_gag": 60, "d_gat": 61, "a_gca": 62, "a_gcc": 63, "a_gcg": 64, "a_gct": 65, "g_gga": 66, "g_ggc": 67, "g_ggg": 68, "g_ggt": 69, "v_gta": 70, "v_gtc": 71, "v_gtg": 72, "v_gtt": 73, "__taa": 74, "y_tac": 75, "__tag": 76, "y_tat": 77, "s_tca": 78, "s_tcc": 79, "s_tcg": 80, "s_tct": 81, "__tga": 82, "c_tgc": 83, "w_tgg": 84, "c_tgt": 85, "l_tta": 86, "f_ttc": 87, "l_ttg": 88, "f_ttt": 89}}} diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py index ca7d464..42342c9 100644 --- a/tests/test_CodonData.py +++ b/tests/test_CodonData.py @@ -1,13 +1,15 @@ import tempfile import unittest + import pandas as pd +from Bio.Data.CodonTable import TranslationError + from CodonTransformer.CodonData import ( - read_fasta_file, build_amino2codon_skeleton, get_amino_acid_sequence, is_correct_seq, + read_fasta_file, ) -from Bio.Data.CodonTable import TranslationError class TestCodonData(unittest.TestCase): diff --git a/tests/test_CodonJupyter.py b/tests/test_CodonJupyter.py index 0ff6afa..6816813 100644 --- a/tests/test_CodonJupyter.py +++ b/tests/test_CodonJupyter.py @@ -1,13 +1,15 @@ import unittest + import ipywidgets + from CodonTransformer.CodonJupyter import ( + DNASequencePrediction, UserContainer, - create_organism_dropdown, create_dropdown_options, + create_organism_dropdown, display_organism_dropdown, display_protein_input, format_model_output, - DNASequencePrediction, ) from CodonTransformer.CodonUtils import ORGANISM2ID diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py index 199c383..3617c2b 100644 --- a/tests/test_CodonPrediction.py +++ b/tests/test_CodonPrediction.py @@ -1,50 +1,88 @@ import unittest +import warnings + import torch + from CodonTransformer.CodonPrediction import ( + load_model, + load_tokenizer, predict_dna_sequence, - # add other imported functions or classes as needed ) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class TestCodonPrediction(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Suppress warnings about loading from HuggingFace + for message in [ + "Tokenizer path not provided. Loading from HuggingFace.", + "Model path not provided. Loading from HuggingFace.", + ]: + warnings.filterwarnings("ignore", message=message) + + cls.model = load_model(device=cls.device) + cls.tokenizer = load_tokenizer() + def test_predict_dna_sequence_valid_input(self): - # Test predict_dna_sequence with a valid protein sequence and organism code protein_sequence = "MWWMW" organism = "Escherichia coli general" - result = predict_dna_sequence(protein_sequence, organism, device) - # Test if the output is a string and contains only A, T, C, G. + result = predict_dna_sequence( + protein_sequence, + organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) self.assertIsInstance(result.predicted_dna, str) + self.assertTrue( + all(nucleotide in "ATCG" for nucleotide in result.predicted_dna) + ) self.assertEqual(result.predicted_dna, "ATGTGGTGGATGTGGTGA") - def test_predict_dna_sequence_invalid_protein_sequence(self): - # Test predict_dna_sequence with an invalid protein sequence - protein_sequence = "MKTZZFVLLL" # 'Z' is not a valid amino acid + def test_predict_dna_sequence_non_deterministic(self): + protein_sequence = "MFWY" organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) - - def test_predict_dna_sequence_invalid_organism_code(self): - # Test predict_dna_sequence with an invalid organism code - protein_sequence = "MKTFFVLLL" - organism = "Alien $%#@!" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) - - def test_predict_dna_sequence_empty_protein_sequence(self): - # Test predict_dna_sequence with an empty protein sequence - protein_sequence = "" - organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) + num_iterations = 64 + possible_outputs = set() + possible_encodings_wo_stop = { + "ATGTTTTGGTAT", + "ATGTTCTGGTAT", + "ATGTTTTGGTAC", + "ATGTTCTGGTAC", + } - def test_predict_dna_sequence_none_protein_sequence(self): - # Test predict_dna_sequence with None as protein sequence - protein_sequence = None - organism = "Escherichia coli general" - with self.assertRaises(ValueError): - predict_dna_sequence(protein_sequence, organism, device) + for _ in range(num_iterations): + result = predict_dna_sequence( + protein=protein_sequence, + organism=organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + deterministic=False, + ) + possible_outputs.add(result.predicted_dna[:-3]) # Remove stop codon + + self.assertEqual(possible_outputs, possible_encodings_wo_stop) + + def test_predict_dna_sequence_invalid_inputs(self): + test_cases = [ + ("MKTZZFVLLL", "Escherichia coli general", "invalid protein sequence"), + ("MKTFFVLLL", "Alien $%#@!", "invalid organism code"), + ("", "Escherichia coli general", "empty protein sequence"), + ] + + for protein_sequence, organism, error_type in test_cases: + with self.subTest(error_type=error_type): + with self.assertRaises(ValueError): + predict_dna_sequence( + protein_sequence, + organism, + device=self.device, + tokenizer=self.tokenizer, + model=self.model, + ) if __name__ == "__main__": diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py index c7faf11..1d9a94c 100644 --- a/tests/test_CodonUtils.py +++ b/tests/test_CodonUtils.py @@ -2,14 +2,15 @@ import pickle import tempfile import unittest + from CodonTransformer.CodonUtils import ( - load_python_object_from_disk, - save_python_object_to_disk, find_pattern_in_fasta, get_organism2id_dict, get_taxonomy_id, - sort_amino2codon_skeleton, load_pkl_from_url, + load_python_object_from_disk, + save_python_object_to_disk, + sort_amino2codon_skeleton, ) @@ -32,7 +33,10 @@ def test_save_python_object_to_disk(self): os.remove(temp_file_name) def test_find_pattern_in_fasta(self): - text = ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n>seq2 [keyword=value2]\nGGTACGATCGATCGAT" + text = ( + ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n" + ">seq2 [keyword=value2]\nGGTACGATCGATCGAT" + ) self.assertEqual(find_pattern_in_fasta("keyword", text), "value1") self.assertEqual(find_pattern_in_fasta("nonexistent", text), "")