Add test code and Codecov integration
- Added test cases for various modules.
- Integrated Codecov for test coverage reporting.
- Improved .gitignore to exclude additional files and directories.
gui11aume committed Sep 18, 2024
1 parent a89b980 commit 7e935e7
Showing 11 changed files with 479 additions and 49 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,39 @@
+# .github/workflows/ci.yml
+
+name: CI
+
+on: [push, pull_request]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.10'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install coverage
+      - name: Run tests with coverage
+        run: |
+          make test_with_coverage
+          coverage report
+          coverage xml
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v2
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          file: coverage.xml
+          flags: unittests
+          name: codecov-umbrella
+          fail_ci_if_error: true
17 changes: 17 additions & 0 deletions .gitignore
@@ -158,3 +158,20 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Coverage reports
+coverage.xml
+
+# Jupyter Notebook checkpoints
+.ipynb_checkpoints/
+
+# Temporary files
+*.tmp
+*.temp
+
+# PyTorch Lightning checkpoints
+lightning_logs/
+
+# PyTorch model weights
+*.pth
+*.pt
86 changes: 49 additions & 37 deletions CodonTransformer/CodonData.py
@@ -355,24 +355,29 @@ def get_amino_acid_sequence(
 
 def read_fasta_file(
     input_file: str,
-    output_path: str,
+    save_to_file: Optional[str] = None,
     organism: str = "",
     return_dataframe: bool = True,
     buffer_size: int = 50000,
 ) -> pd.DataFrame:
     """
-    Read a FASTA file of DNA sequences and save it to a Pandas DataFrame.
+    Read a FASTA file of DNA sequences and convert it to a Pandas DataFrame.
+    Optionally, save the DataFrame to a CSV file.
 
     Args:
         input_file (str): Path to the input FASTA file.
-        output_path (str): Path to save the output DataFrame.
-        organism (str): Name of the organism.
-        return_dataframe (bool): Whether to return the DataFrame.
-        buffer_size (int): Buffer size for reading the file.
+        save_to_file (Optional[str]): Path to save the output DataFrame. If None, data is only returned.
+        organism (str): Name of the organism. If empty, it will be extracted from the FASTA description.
+        buffer_size (int): Number of records to process before writing to file.
 
     Returns:
-        pd.DataFrame: DataFrame containing the DNA sequences.
+        pd.DataFrame: DataFrame containing the DNA sequences if return_dataframe is True, else None.
+
+    Raises:
+        FileNotFoundError: If the input file does not exist.
     """
+    if not os.path.exists(input_file):
+        raise FileNotFoundError(f"Input file not found: {input_file}")
+
     buffer = []
     columns = [
         "dna",
@@ -384,20 +389,25 @@ def read_fasta_file(
         "tokenized",
     ]
 
-    # Read the FASTA file and process each sequence record
+    # Initialize DataFrame to store all data if return_dataframe is True
+    all_data = pd.DataFrame(columns=columns)
+
     with open(input_file, "r") as fasta_file:
         for record in tqdm(
-            SeqIO.parse(fasta_file, "fasta"), desc=f"{organism}", unit=" Rows"
+            SeqIO.parse(fasta_file, "fasta"),
+            desc=f"Processing {organism}",
+            unit=" Records",
         ):
-            dna = str(record.seq).strip()
+            dna = str(record.seq).strip().upper()  # Ensure uppercase DNA sequence
 
             # Determine the organism from the record if not provided
-            if not organism:
-                organism = find_pattern_in_fasta("organism", record.description)
-            GeneID = find_pattern_in_fasta("GeneID", record.description)
+            current_organism = organism or find_pattern_in_fasta(
+                "organism", record.description
+            )
+            gene_id = find_pattern_in_fasta("GeneID", record.description)
 
             # Get the appropriate codon table for the organism
-            codon_table = get_codon_table(organism)
+            codon_table = get_codon_table(current_organism)
 
             # Translate DNA to protein sequence
             protein, correct_seq = get_amino_acid_sequence(
@@ -406,44 +416,46 @@ def read_fasta_file(
                 codon_table=codon_table,
                 return_correct_seq=True,
             )
-            description = str(record.description[: record.description.find("[")])
-            tokenized = get_merged_seq(protein, dna, seperator=STOP_SYMBOL)
+            description = record.description.split("[", 1)[0].strip()
+            tokenized = get_merged_seq(protein, dna, separator=STOP_SYMBOL)
 
             # Create a data row for the current sequence
             data_row = {
                 "dna": dna,
                 "protein": protein,
                 "correct_seq": correct_seq,
-                "organism": organism,
-                "GeneID": GeneID,
+                "organism": current_organism,
+                "GeneID": gene_id,
                 "description": description,
                 "tokenized": tokenized,
             }
             buffer.append(data_row)
 
             # Write buffer to CSV file when buffer size is reached
-            if len(buffer) >= buffer_size:
-                buffer_df = pd.DataFrame(buffer, columns=columns)
-                buffer_df.to_csv(
-                    output_path,
-                    mode="a",
-                    header=(not os.path.exists(output_path)),
-                    index=True,
-                )
+            if save_to_file and len(buffer) >= buffer_size:
+                write_buffer_to_csv(buffer, save_to_file, columns)
                 buffer = []
 
-    # Write remaining buffer to CSV file
-    if buffer:
-        buffer_df = pd.DataFrame(buffer, columns=columns)
-        buffer_df.to_csv(
-            output_path,
-            mode="a",
-            header=(not os.path.exists(output_path)),
-            index=True,
-        )
+            all_data = pd.concat(
+                [all_data, pd.DataFrame([data_row])], ignore_index=True
+            )
 
-    if return_dataframe:
-        return pd.read_csv(output_path, index_col=0)
+    # Write remaining buffer to CSV file
+    if save_to_file and buffer:
+        write_buffer_to_csv(buffer, save_to_file, columns)
+
+    return all_data
+
+
+def write_buffer_to_csv(buffer: List[Dict], output_path: str, columns: List[str]):
+    """Helper function to write buffer to CSV file."""
+    buffer_df = pd.DataFrame(buffer, columns=columns)
+    buffer_df.to_csv(
+        output_path,
+        mode="a",
+        header=(not os.path.exists(output_path)),
+        index=True,
+    )
 
 
 def download_codon_frequencies_from_kazusa(
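As a quick illustration of the reworked `read_fasta_file`, here is a usage sketch based on the signature above. The file paths and organism name are placeholders, not files from this repository; only `input_file` is required.

```python
# Hypothetical usage of the reworked read_fasta_file; paths are made up.
from CodonTransformer.CodonData import read_fasta_file

# In-memory only: parse the FASTA file and get back a DataFrame with the
# dna, protein, correct_seq, organism, GeneID, description, and tokenized columns.
df = read_fasta_file(input_file="genes.fasta", organism="Escherichia coli")

# Buffered CSV export: every `buffer_size` records are appended to disk via
# write_buffer_to_csv, so large FASTA files are not held in memory twice.
df = read_fasta_file(
    input_file="genes.fasta",
    save_to_file="genes.csv",
    buffer_size=50000,
)
```

Factoring the CSV append into `write_buffer_to_csv` also removes the duplicated `to_csv` block that the old function carried in two places.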
32 changes: 20 additions & 12 deletions CodonTransformer/CodonPrediction.py
@@ -15,12 +15,13 @@
     PreTrainedTokenizerFast,
     BigBirdConfig,
     AutoTokenizer,
-    BigBirdForMaskedLM
+    BigBirdForMaskedLM,
 )
 import numpy as np
 
 from CodonTransformer.CodonData import get_merged_seq
 from CodonTransformer.CodonUtils import (
+    AMINO_ACIDS,
     ORGANISM2ID,
     TOKEN2INDEX,
     INDEX2TOKEN,
@@ -76,18 +77,18 @@ def predict_dna_sequence(
     >>> from transformers import AutoTokenizer, BigBirdForMaskedLM
     >>> from CodonTransformer.CodonPrediction import predict_dna_sequence
     >>> from CodonTransformer.CodonJupyter import format_model_output
     >>>
     >>> # Set up device
     >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     >>>
     >>> # Load tokenizer and model
     >>> tokenizer = AutoTokenizer.from_pretrained("adibvafa/CodonTransformer")
     >>> model = BigBirdForMaskedLM.from_pretrained("adibvafa/CodonTransformer").to(device)
     >>>
     >>> # Define protein sequence and organism
     >>> protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
     >>> organism = "Escherichia coli general"
     >>>
     >>> # Predict DNA sequence
     >>> output = predict_dna_sequence(
     ...     protein=protein,
@@ -97,12 +98,19 @@ def predict_dna_sequence(
     ...     model=model,
     ...     attention_type="original_full"
     ... )
     >>>
     >>> print(format_model_output(output))
     """
+    if not protein:
+        raise ValueError("Protein sequence cannot be empty.")
+
+    if not isinstance(protein, str):
+        raise ValueError("Protein sequence must be a string.")
+
+    # Test that the input protein sequence contains only valid amino acids
+    if not all(aminoacid in AMINO_ACIDS for aminoacid in protein):
+        raise ValueError("Invalid amino acid found in protein sequence.")
+
     # Load tokenizer
     if not isinstance(tokenizer, PreTrainedTokenizerFast):
         tokenizer = load_tokenizer(tokenizer)
@@ -127,9 +135,7 @@ def predict_dna_sequence(
         "codons": merged_seq,
         "organism": organism_id,
     }
-    tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(
-        device
-    )
+    tokenized_input = tokenize([input_dict], tokenizer=tokenizer).to(device)
 
     # Get the model predictions
     output_dict = model(**tokenized_input, return_dict=True)
@@ -202,7 +208,9 @@ def load_model(
         model.load_state_dict(state_dict)
 
     else:
-        raise ValueError("Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace.")
+        raise ValueError(
+            "Unsupported file type. Please provide a .ckpt or .pt file, or None to load from HuggingFace."
+        )
 
     # Prepare model for evaluation
     model.bert.set_attention_type(attention_type)
@@ -386,7 +394,7 @@ def get_high_frequency_choice_sequence(
 
 
 def precompute_most_frequent_codons(
-    codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
+    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
 ) -> Dict[str, str]:
     """
     Precompute the most frequent codon for each amino acid.
@@ -449,7 +457,7 @@ def get_background_frequency_choice_sequence(
 
 
 def precompute_cdf(
-    codon_frequencies: Dict[str, Tuple[List[str], List[float]]]
+    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
 ) -> Dict[str, Tuple[List[str], Any]]:
     """
     Precompute the cumulative distribution function (CDF) for each amino acid.
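The bodies of `precompute_most_frequent_codons` and `precompute_cdf` are collapsed in this diff; only a trailing-comma change to their signatures is visible. For readers unfamiliar with the pattern, here is a minimal sketch of precomputing the most frequent codon per amino acid — an illustration under the signature shown above, not the repository's implementation.

```python
# Illustration of the precompute pattern (not the repository's code): map each
# amino acid to its highest-frequency codon once, so sequence generation can do
# O(1) lookups instead of recomputing the argmax for every residue.
from typing import Dict, List, Tuple


def most_frequent_codon_table(
    codon_frequencies: Dict[str, Tuple[List[str], List[float]]],
) -> Dict[str, str]:
    return {
        amino_acid: codons[frequencies.index(max(frequencies))]
        for amino_acid, (codons, frequencies) in codon_frequencies.items()
    }


# Toy frequencies: lysine (K) is encoded by AAA 74% of the time in this example.
toy = {"M": (["ATG"], [1.0]), "K": (["AAA", "AAG"], [0.74, 0.26])}
assert most_frequent_codon_table(toy) == {"M": "ATG", "K": "AAA"}
```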
4 changes: 4 additions & 0 deletions CodonTransformer/CodonUtils.py
@@ -593,6 +593,10 @@ def get_organism2id_dict(organism_reference: str) -> Dict[str, int]:
     Args:
         organism_reference (str): Path to a CSV file containing a list of all organisms.
             The format of the CSV file should be as follows:
+                0,Escherichia coli
+                1,Homo sapiens
+                2,Mus musculus
+
     Returns:
         Dict[str, int]: A dictionary mapping organism names to their respective indices.
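The documented CSV layout is one `index,name` pair per line. As a rough sketch of how such a file maps to an organism-to-index dictionary — assuming only that format, since the actual body of `get_organism2id_dict` is not shown here:

```python
# Sketch of loading the documented "index,name" CSV into a name -> index dict;
# illustrative only, not the code behind get_organism2id_dict.
import csv
from typing import Dict


def load_organism2id(organism_reference: str) -> Dict[str, int]:
    organism2id: Dict[str, int] = {}
    with open(organism_reference, newline="") as handle:
        for index, name in csv.reader(handle):
            organism2id[name] = int(index)
    return organism2id


# Given the three example rows above, this returns
# {"Escherichia coli": 0, "Homo sapiens": 1, "Mus musculus": 2}.
```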
9 changes: 9 additions & 0 deletions Makefile
@@ -0,0 +1,9 @@
+# Makefile
+
+.PHONY: test
+test:
+	python -m unittest discover -s tests
+
+.PHONY: test_with_coverage
+test_with_coverage:
+	coverage run -m unittest discover -s tests
Empty file added tests/__init__.py
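The remaining changed files — the new test modules themselves — are not expanded in this view. As a hypothetical example of the kind of case that `make test`'s discovery over `tests/` would pick up, the following exercises the input validation added to `predict_dna_sequence` in this commit; the file name and test cases are illustrative, not the commit's actual tests.

```python
# tests/test_codon_prediction.py — hypothetical, not the commit's actual tests.
# The validation added in this commit runs before the tokenizer or model is
# touched, so passing None for both still reaches the ValueError checks.
import unittest

from CodonTransformer.CodonPrediction import predict_dna_sequence


class TestPredictDnaSequenceValidation(unittest.TestCase):
    def test_empty_protein_raises(self):
        with self.assertRaises(ValueError):
            predict_dna_sequence(
                protein="",
                organism="Escherichia coli general",
                device=None,
                tokenizer=None,
                model=None,
                attention_type="original_full",
            )

    def test_invalid_amino_acid_raises(self):
        with self.assertRaises(ValueError):
            predict_dna_sequence(
                protein="MKT7",  # "7" is not a valid amino acid
                organism="Escherichia coli general",
                device=None,
                tokenizer=None,
                model=None,
                attention_type="original_full",
            )


if __name__ == "__main__":
    unittest.main()
```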