diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..080d1e7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,39 @@ +# .github/workflows/ci.yml + +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install coverage + + - name: Run tests with coverage + run: | + make test_with_coverage + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 68bc17f..0a03dca 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,20 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Coverage reports +coverage.xml + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Temporary files +*.tmp +*.temp + +# PyTorch Lightning checkpoints +lightning_logs/ + +# PyTorch model weights +*.pth +*.pt \ No newline at end of file diff --git a/CodonTransformer/CodonData.py b/CodonTransformer/CodonData.py index 102d7b2..5932ff8 100644 --- a/CodonTransformer/CodonData.py +++ b/CodonTransformer/CodonData.py @@ -355,24 +355,29 @@ def get_amino_acid_sequence( def read_fasta_file( input_file: str, - output_path: str, + save_to_file: Optional[str] = None, organism: str = "", - return_dataframe: bool = True, buffer_size: int = 50000, ) -> pd.DataFrame: """ - Read a FASTA file of DNA sequences and save it to a Pandas DataFrame. + Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame. + Optionally, save the DataFrame to a CSV file. Args: input_file (str): Path to the input FASTA file. - output_path (str): Path to save the output DataFrame. - organism (str): Name of the organism. - return_dataframe (bool): Whether to return the DataFrame. - buffer_size (int): Buffer size for reading the file. + save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned. + organism (str): Name of the organism. If empty, it will be extracted from the FASTA description. + buffer_size (int): Number of records to process before writing to file. Returns: - pd.DataFrame: DataFrame containing the DNA sequences. + pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None. + + Raises: + FileNotFoundError: If the input file does not exist. """ + if not os.path.exists(input_file): + raise FileNotFoundError(f"Input file not found: {input_file}") + buffer = [] columns = [ "dna", @@ -384,20 +389,25 @@ def read_fasta_file( "tokenized", ] - # Read the FASTA file and process each sequence record + # Initialize DataFrame to store all data if return_dataframe is True + all_data = pd.DataFrame(columns=columns) + with open(input_file, "r") as fasta_file: for record in tqdm( - SeqIO.parse(fasta_file, "fasta"), desc=f"{organism}", unit=" Rows" + SeqIO.parse(fasta_file, "fasta"), + desc=f"Processing {organism}", + unit=" Records", ): - dna = str(record.seq).strip() + dna = str(record.seq).strip().upper() # Ensure uppercase DNA sequence # Determine the organism from the record if not provided - if not organism: - organism = find_pattern_in_fasta("organism", record.description) - GeneID = find_pattern_in_fasta("GeneID", record.description) + current_organism = organism or find_pattern_in_fasta( + "organism", record.description + ) + gene_id = find_pattern_in_fasta("GeneID", record.description) # Get the appropriate codon table for the organism - codon_table = get_codon_table(organism) + codon_table = get_codon_table(current_organism) # Translate DNA to protein sequence protein, correct_seq = get_amino_acid_sequence( @@ -406,44 +416,46 @@ def read_fasta_file( codon_table=codon_table, return_correct_seq=True, ) - description = str(record.description[: record.description.find("[")]) - tokenized = get_merged_seq(protein, dna, seperator=STOP_SYMBOL) + description = record.description.split("[", 1)[0].strip() + tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL) # Create a data row for the current sequence data_row = { "dna": dna, "protein": protein, "correct_seq": correct_seq, - "organism": organism, - "GeneID": GeneID, + "organism": current_organism, + "GeneID": gene_id, "description": description, "tokenized": tokenized, } buffer.append(data_row) # Write buffer to CSV file when buffer size is reached - if len(buffer) >= buffer_size: - buffer_df = pd.DataFrame(buffer, columns=columns) - buffer_df.to_csv( - output_path, - mode="a", - header=(not os.path.exists(output_path)), - index=True, - ) + if save_to_file and len(buffer) >= buffer_size: + write_buffer_to_csv(buffer, save_to_file, columns) buffer = [] - # Write remaining buffer to CSV file - if buffer: - buffer_df = pd.DataFrame(buffer, columns=columns) - buffer_df.to_csv( - output_path, - mode="a", - header=(not os.path.exists(output_path)), - index=True, + all_data = pd.concat( + [all_data, pd.DataFrame([data_row])], ignore_index=True ) - if return_dataframe: - return pd.read_csv(output_path, index_col=0) + # Write remaining buffer to CSV file + if save_to_file and buffer: + write_buffer_to_csv(buffer, save_to_file, columns) + + return all_data + + +def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]): + """Helper function to write buffer to CSV file.""" + buffer_df = pd.DataFrame(buffer, columns=columns) + buffer_df.to_csv( + output_path, + mode="a", + header=(not os.path.exists(output_path)), + index=True, + ) def download_codon_frequencies_from_kazusa( diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 2240b53..7df8520 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -15,12 +15,13 @@ PreTrainedTokenizerFast, BigBirdConfig, AutoTokenizer, - BigBirdForMaskedLM + BigBirdForMaskedLM, ) import numpy as np from CodonTransformer.CodonData import get_merged_seq from CodonTransformer.CodonUtils import ( + AMINO_ACIDS, ORGANISM2ID, TOKEN2INDEX, INDEX2TOKEN, @@ -76,18 +77,18 @@ def predict_dna_sequence( >>> from transformers import AutoTokenizer, BigBirdForMaskedLM >>> from CodonTransformer.CodonPrediction import predict_dna_sequence >>> from CodonTransformer.CodonJupyter import format_model_output - >>> + >>> >>> # Set up device >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - >>> + >>> >>> # Load tokenizer and model >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer") >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device) - >>> + >>> >>> # Define protein sequence and organism >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG" >>> organism = "Escherichia coli general" - >>> + >>> >>> # Predict DNA sequence >>> output = predict_dna_sequence( ... protein=protein, @@ -97,12 +98,19 @@ def predict_dna_sequence( ... model=model, ... attention_type="original_full" ... ) - >>> + >>> >>> print(format_model_output(output)) """ if not protein: raise ValueError("Protein sequence cannot be empty.") + if not isinstance(protein, str): + raise ValueError("Protein sequence must be a string.") + + # Test that the input protein sequence contains only valid amino acids + if not all(aminoacid in AMINO_ACIDS for aminoacid in protein): + raise ValueError("Invalid amino acid found in protein sequence.") + # Load tokenizer if not isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = load_tokenizer(tokenizer) @@ -127,9 +135,7 @@ def predict_dna_sequence( "codons": merged_seq, "organism": organism_id, } - tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to( - device - ) + tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device) # Get the model predictions output_dict = model(**tokenized_input, return_dict=True) @@ -202,7 +208,9 @@ def load_model( model.load_state_dict(state_dict) else: - raise ValueError("Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace.") + raise ValueError( + "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace." + ) # Prepare model for evaluation model.bert.set_attention_type(attention_type) @@ -386,7 +394,7 @@ def get_high_frequency_choice_sequence( def precompute_most_frequent_codons( - codon_frequencies: Dict[str, Tuple[List[str], List[float]]] + codon_frequencies: Dict[str, Tuple[List[str], List[float]]], ) -> Dict[str, str]: """ Precompute the most frequent codon for each amino acid. @@ -449,7 +457,7 @@ def get_background_frequency_choice_sequence( def precompute_cdf( - codon_frequencies: Dict[str, Tuple[List[str], List[float]]] + codon_frequencies: Dict[str, Tuple[List[str], List[float]]], ) -> Dict[str, Tuple[List[str], Any]]: """ Precompute the cumulative distribution function (CDF) for each amino acid. diff --git a/CodonTransformer/CodonUtils.py b/CodonTransformer/CodonUtils.py index 8595ac1..cf84b24 100644 --- a/CodonTransformer/CodonUtils.py +++ b/CodonTransformer/CodonUtils.py @@ -593,6 +593,10 @@ def get_organism2id_dict(organism_reference: str) -> Dict[str, int]: Args: organism_reference (str): Path to a CSV file containing a list of all organisms. + The format of the CSV file should be as follows: + 0,Escherichia coli + 1,Homo sapiens + 2,Mus musculus Returns: Dict[str, int]: A dictionary mapping organism names to their respective indices. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2527ef0 --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +# Makefile + +.PHONY: test +test: + python -m unittest discover -s tests + +.PHONY: test_with_coverage +test_with_coverage: + coverage run -m unittest discover -s tests diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_CodonData.py b/tests/test_CodonData.py new file mode 100644 index 0000000..ca7d464 --- /dev/null +++ b/tests/test_CodonData.py @@ -0,0 +1,85 @@ +import tempfile +import unittest +import pandas as pd +from CodonTransformer.CodonData import ( + read_fasta_file, + build_amino2codon_skeleton, + get_amino_acid_sequence, + is_correct_seq, +) +from Bio.Data.CodonTable import TranslationError + + +class TestCodonData(unittest.TestCase): + def test_read_fasta_file(self): + fasta_content = ">sequence1\n" "ATGATGATGATGATG\n" ">sequence2\n" "TGATGATGATGA" + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".fasta" + ) as temp_file: + temp_file.write(fasta_content) + temp_file_name = temp_file.name + + try: + sequences = read_fasta_file(temp_file_name, save_to_file=None) + self.assertIsInstance(sequences, pd.DataFrame) + self.assertEqual(len(sequences), 2) + self.assertEqual(sequences.iloc[0]["dna"], "ATGATGATGATGATG") + self.assertEqual(sequences.iloc[1]["dna"], "TGATGATGATGA") + finally: + import os + + os.unlink(temp_file_name) + + def test_build_amino2codon_skeleton(self): + organism = "Homo sapiens" + codon_skeleton = build_amino2codon_skeleton(organism) + + expected_amino_acids = "ARNDCQEGHILKMFPSTWYV_" + + for amino_acid in expected_amino_acids: + self.assertIn(amino_acid, codon_skeleton) + codons, frequencies = codon_skeleton[amino_acid] + self.assertIsInstance(codons, list) + self.assertIsInstance(frequencies, list) + self.assertEqual(len(codons), len(frequencies)) + self.assertTrue(all(isinstance(codon, str) for codon in codons)) + self.assertTrue(all(freq == 0 for freq in frequencies)) + + all_codons = set( + codon for codons, _ in codon_skeleton.values() for codon in codons + ) + self.assertEqual(len(all_codons), 64) # There should be 64 unique codons + + def test_get_amino_acid_sequence(self): + dna = "ATGGCCTGA" + protein, is_correct = get_amino_acid_sequence(dna, return_correct_seq=True) + self.assertEqual(protein, "MA_") + self.assertTrue(is_correct) + + def test_is_correct_seq(self): + dna = "ATGGCCTGA" + protein = "MA_" + self.assertTrue(is_correct_seq(dna, protein)) + + def test_read_fasta_file_raises_exception_for_non_dna(self): + non_dna_content = ">sequence1\nATGATGATGXYZATG\n>sequence2\nTGATGATGATGA" + + with tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".fasta" + ) as temp_file: + temp_file.write(non_dna_content) + temp_file_name = temp_file.name + + try: + with self.assertRaises(TranslationError) as context: + read_fasta_file(temp_file_name) + self.assertIn("Codon 'XYZ' is invalid", str(context.exception)) + finally: + import os + + os.unlink(temp_file_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonJupyter.py b/tests/test_CodonJupyter.py new file mode 100644 index 0000000..0ff6afa --- /dev/null +++ b/tests/test_CodonJupyter.py @@ -0,0 +1,116 @@ +import unittest +import ipywidgets +from CodonTransformer.CodonJupyter import ( + UserContainer, + create_organism_dropdown, + create_dropdown_options, + display_organism_dropdown, + display_protein_input, + format_model_output, + DNASequencePrediction, +) +from CodonTransformer.CodonUtils import ORGANISM2ID + + +class TestCodonJupyter(unittest.TestCase): + def test_UserContainer(self): + user_container = UserContainer() + self.assertEqual(user_container.organism, -1) + self.assertEqual(user_container.protein, "") + + def test_create_organism_dropdown(self): + container = UserContainer() + dropdown = create_organism_dropdown(container) + + self.assertIsInstance(dropdown, ipywidgets.Dropdown) + self.assertGreater(len(dropdown.options), 0) + self.assertEqual(dropdown.description, "") + self.assertEqual(dropdown.layout.width, "40%") + self.assertEqual(dropdown.layout.margin, "0 0 10px 0") + self.assertEqual(dropdown.style.description_width, "initial") + + # Test the dropdown options + options = dropdown.options + self.assertIn("", options) + self.assertIn("Selected Organisms", options) + self.assertIn("All Organisms", options) + + def test_create_dropdown_options(self): + options = create_dropdown_options(ORGANISM2ID) + self.assertIsInstance(options, list) + self.assertGreater(len(options), 0) + + def test_display_organism_dropdown(self): + container = UserContainer() + with unittest.mock.patch( + "CodonTransformer.CodonJupyter.display" + ) as mock_display: + display_organism_dropdown(container) + + # Check that display was called twice (for container_widget and HTML) + self.assertEqual(mock_display.call_count, 2) + + # Check that the first call to display was with a VBox widget + self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) + + # Check that the VBox contains a Dropdown + dropdown = mock_display.call_args_list[0][0][0].children[1] + self.assertIsInstance(dropdown, ipywidgets.Dropdown) + self.assertGreater(len(dropdown.options), 0) + + def test_display_protein_input(self): + container = UserContainer() + with unittest.mock.patch( + "CodonTransformer.CodonJupyter.display" + ) as mock_display: + display_protein_input(container) + + # Check that display was called twice (for container_widget and HTML) + self.assertEqual(mock_display.call_count, 2) + + # Check that the first call to display was with a VBox widget + self.assertIsInstance(mock_display.call_args_list[0][0][0], ipywidgets.VBox) + + # Check that the VBox contains a Textarea + textarea = mock_display.call_args_list[0][0][0].children[1] + self.assertIsInstance(textarea, ipywidgets.Textarea) + + # Verify the properties of the Textarea + self.assertEqual(textarea.value, "") + self.assertEqual(textarea.placeholder, "Enter here...") + self.assertEqual(textarea.description, "") + self.assertEqual(textarea.layout.width, "100%") + self.assertEqual(textarea.layout.height, "100px") + self.assertEqual(textarea.layout.margin, "0 0 10px 0") + self.assertEqual(textarea.style.description_width, "initial") + + def test_format_model_output(self): + output = DNASequencePrediction( + organism="Escherichia coli", + protein="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + processed_input="MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + predicted_dna="ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA", + ) + formatted_output = format_model_output(output) + self.assertIsInstance(formatted_output, str) + self.assertIn("Organism", formatted_output) + self.assertIn("Escherichia coli", formatted_output) + self.assertIn("Input Protein", formatted_output) + self.assertIn( + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + formatted_output, + ) + self.assertIn("Processed Input", formatted_output) + self.assertIn( + "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG", + formatted_output, + ) + self.assertIn("Predicted DNA", formatted_output) + self.assertIn( + "ATGAAAACTGTTCGTCAGGAACGTCTGAAATCTATTGTTCGTATTCTGGAACGTTCTAAAGAACCGGTTTCTGGTGCTCAACTGGCTGAAGAACTGTCTGTTTCTCGTCAGGTTATTGTTCAGGACATTGCTTACCTGCGTTCTCTGGGTTATAA", + formatted_output, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonPrediction.py b/tests/test_CodonPrediction.py new file mode 100644 index 0000000..199c383 --- /dev/null +++ b/tests/test_CodonPrediction.py @@ -0,0 +1,51 @@ +import unittest +import torch +from CodonTransformer.CodonPrediction import ( + predict_dna_sequence, + # add other imported functions or classes as needed +) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class TestCodonPrediction(unittest.TestCase): + def test_predict_dna_sequence_valid_input(self): + # Test predict_dna_sequence with a valid protein sequence and organism code + protein_sequence = "MWWMW" + organism = "Escherichia coli general" + result = predict_dna_sequence(protein_sequence, organism, device) + # Test if the output is a string and contains only A, T, C, G. + self.assertIsInstance(result.predicted_dna, str) + self.assertEqual(result.predicted_dna, "ATGTGGTGGATGTGGTGA") + + def test_predict_dna_sequence_invalid_protein_sequence(self): + # Test predict_dna_sequence with an invalid protein sequence + protein_sequence = "MKTZZFVLLL" # 'Z' is not a valid amino acid + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_invalid_organism_code(self): + # Test predict_dna_sequence with an invalid organism code + protein_sequence = "MKTFFVLLL" + organism = "Alien $%#@!" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_empty_protein_sequence(self): + # Test predict_dna_sequence with an empty protein sequence + protein_sequence = "" + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + def test_predict_dna_sequence_none_protein_sequence(self): + # Test predict_dna_sequence with None as protein sequence + protein_sequence = None + organism = "Escherichia coli general" + with self.assertRaises(ValueError): + predict_dna_sequence(protein_sequence, organism, device) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_CodonUtils.py b/tests/test_CodonUtils.py new file mode 100644 index 0000000..c7faf11 --- /dev/null +++ b/tests/test_CodonUtils.py @@ -0,0 +1,89 @@ +import os +import pickle +import tempfile +import unittest +from CodonTransformer.CodonUtils import ( + load_python_object_from_disk, + save_python_object_to_disk, + find_pattern_in_fasta, + get_organism2id_dict, + get_taxonomy_id, + sort_amino2codon_skeleton, + load_pkl_from_url, +) + + +class TestCodonUtils(unittest.TestCase): + def test_load_python_object_from_disk(self): + test_obj = {"key1": "value1", "key2": 2} + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(test_obj, temp_file_name) + loaded_obj = load_python_object_from_disk(temp_file_name) + self.assertEqual(test_obj, loaded_obj) + os.remove(temp_file_name) + + def test_save_python_object_to_disk(self): + test_obj = [1, 2, 3, 4, 5] + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(test_obj, temp_file_name) + self.assertTrue(os.path.exists(temp_file_name)) + os.remove(temp_file_name) + + def test_find_pattern_in_fasta(self): + text = ">seq1 [keyword=value1]\nATGCGTACGTAGCTAG\n>seq2 [keyword=value2]\nGGTACGATCGATCGAT" + self.assertEqual(find_pattern_in_fasta("keyword", text), "value1") + self.assertEqual(find_pattern_in_fasta("nonexistent", text), "") + + def test_get_organism2id_dict(self): + with tempfile.NamedTemporaryFile( + mode="w", delete=True, suffix=".csv" + ) as temp_file: + temp_file.write("0,Escherichia coli\n1,Homo sapiens\n2,Mus musculus") + temp_file.flush() + organism2id = get_organism2id_dict(temp_file.name) + self.assertEqual( + organism2id, + {"Escherichia coli": 0, "Homo sapiens": 1, "Mus musculus": 2}, + ) + + def test_get_taxonomy_id(self): + taxonomy_dict = { + "Escherichia coli": 562, + "Homo sapiens": 9606, + "Mus musculus": 10090, + } + with tempfile.NamedTemporaryFile(suffix=".pkl", delete=True) as temp_file: + temp_file_name = temp_file.name + save_python_object_to_disk(taxonomy_dict, temp_file_name) + self.assertEqual(get_taxonomy_id(temp_file_name, "Escherichia coli"), 562) + self.assertEqual( + get_taxonomy_id(temp_file_name, return_dict=True), taxonomy_dict + ) + + def test_sort_amino2codon_skeleton(self): + amino2codon = { + "A": (["GCT", "GCC", "GCA", "GCG"], [0.0, 0.0, 0.0, 0.0]), + "C": (["TGT", "TGC"], [0.0, 0.0]), + } + sorted_amino2codon = sort_amino2codon_skeleton(amino2codon) + self.assertEqual( + sorted_amino2codon, + { + "A": (["GCA", "GCC", "GCG", "GCT"], [0.0, 0.0, 0.0, 0.0]), + "C": (["TGC", "TGT"], [0.0, 0.0]), + }, + ) + + def test_load_pkl_from_url(self): + url = "https://example.com/test.pkl" + expected_obj = {"key": "value"} + with unittest.mock.patch("requests.get") as mock_get: + mock_get.return_value.content = pickle.dumps(expected_obj) + loaded_obj = load_pkl_from_url(url) + self.assertEqual(loaded_obj, expected_obj) + + +if __name__ == "__main__": + unittest.main()