Skip to content

Commit

Permalink
Feat/logging cif support (#402)
Browse files Browse the repository at this point in the history
* added cif support

* add cif tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove prints

* make cath metadata download optional

* modify cif tests

* pdb download formats

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bump biopandas pinned version to latest

* update torch index url

* add deprecation decorator to regnetwork

* ignore GRN tutorial notebook in tests due to RegNetwork going offline

* configure torch==1.13.0 install index url

* rm +cpu flag for torch >=2.0 install

* try removing pyg lib from CI

* use latest miniconda setup actions

* rm unused and deprecated 'U'flag in file read

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arian Jamasb <arian.jamasb@roche.com>
Co-authored-by: Arian Jamasb <arjamasb@gmail.com>
  • Loading branch information
4 people authored Aug 2, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f6d9d72 commit 848a3f8
Showing 17 changed files with 11,229 additions and 50 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ jobs:
uses: actions/checkout@v3
# See: https://github.com/marketplace/actions/setup-miniconda
- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
miniforge-variant: Mambaforge
@@ -48,15 +48,18 @@ jobs:
run: conda install dssp -c salilab
- name: Install mmseqs
run: mamba install -c conda-forge -c bioconda mmseqs2
- name: Install PyTorch
#run: mamba install -c pytorch pytorch==${{matrix.torch}} cpuonly
run: pip install torch==${{matrix.torch}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch (1.13.0)
if: matrix.torch == '1.13.0'
run: pip install torch==${{matrix.torch}}+cpu --extra-index-url https://download.pytorch.org/whl/cpu
- name: Install PyTorch (2.0+)
if: matrix.torch != '1.13.0'
run: pip install torch==${{matrix.torch}} -f https://download.pytorch.org/whl/cpu
- name: Install PyG
#run: mamba install -c pyg pyg
run: pip install torch_geometric
- name: Install torch-cluster
#run: mamba install pytorch-cluster -c pyg
run: pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{matrix.torch}}+cpu.html
run: pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${{matrix.torch}}+cpu.html
- name: Install BLAST
run: sudo apt install ncbi-blast+
- name: Install Graphein
@@ -70,4 +73,4 @@ jobs:
- name: Run unit tests and generate coverage report
run: pytest .
- name: Test notebook execution
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb"
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/grn_tutorial.ipynb"
4 changes: 2 additions & 2 deletions .github/workflows/minimal__install.yaml
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ jobs:
uses: actions/checkout@v3
# See: https://github.com/marketplace/actions/setup-miniconda
- name: Setup miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
miniforge-variant: Mambaforge
@@ -50,4 +50,4 @@ jobs:
- name: Run unit tests and generate coverage report
run: pytest . --ignore-glob="tests/protein/tensor" --ignore="tests/ml/test_conversion.py" --ignore="tests/ml/test_torch_geometric_dataset.py"
- name: Test notebook execution
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/higher_order_graphs.ipynb" --ignore-glob="notebooks/protein_graph_analytics.ipynb" --ignore-glob="notebooks/subgraphing_tutorial.ipynb" --ignore-glob="notebooks/splitting_a_dataset.ipynb" --ignore-glob="notebooks/protein_tensors.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/creating_datasets_from_the_pdb.ipynb"
run: pytest --nbval-lax notebooks/ --current-env --ignore-glob="notebooks/dataloader_tutorial.ipynb" --ignore-glob="notebooks/higher_order_graphs.ipynb" --ignore-glob="notebooks/protein_graph_analytics.ipynb" --ignore-glob="notebooks/subgraphing_tutorial.ipynb" --ignore-glob="notebooks/splitting_a_dataset.ipynb" --ignore-glob="notebooks/protein_tensors.ipynb" --ignore-glob="notebooks/datasets_and_dataloaders.ipynb" --ignore-glob="notebooks/foldcomp.ipynb" --ignore-glob="notebooks/creating_datasets_from_the_pdb.ipynb" --ignore-glob="notebooks/grn_tutorial.ipynb"
2 changes: 1 addition & 1 deletion .requirements/base.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pandas<2.0.0
biopandas>=0.5.0.dev0
biopandas>=0.5.1
biopython
bioservices>=1.10.0
deepdiff
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
* Fix cluster file loading bug in `pdb_data.py` [#396](https://github.com/a-r-j/graphein/pull/396)

#### Misc
* set logging to false by default and added mmcif support [#402](https://github.com/a-r-j/graphein/pull/402)
* add metadata options for uniprot, ecnumber and CATH code to pdb manager [#398](https://github.com/a-r-j/graphein/pull/398)
* bumped logging level down from `INFO` to `DEBUG` at several places to reduced output length [#391](https://github.com/a-r-j/graphein/pull/391)
* exposed `fill_value` and `bfactor` option to `protein_to_pyg` function. [#385](https://github.com/a-r-j/graphein/pull/385) and [#388](https://github.com/a-r-j/graphein/pull/388)
2 changes: 2 additions & 0 deletions graphein/__init__.py
Original file line number Diff line number Diff line change
@@ -21,6 +21,8 @@
]
)

logger.disable("graphein")


def verbose(enabled: bool = False):
"""Enable/Disable logging.
20 changes: 19 additions & 1 deletion graphein/grn/parse_regnetwork.py
Original file line number Diff line number Diff line change
@@ -18,9 +18,12 @@
import wget
from loguru import logger as log

from graphein.utils.utils import filter_dataframe, ping
from graphein.utils.utils import deprecated, filter_dataframe, ping


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def _download_RegNetwork(
root_dir: Optional[Path] = None, network_type: str = "human"
) -> str:
@@ -86,6 +89,9 @@ def _download_RegNetwork(
return file


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
"""
Downloads RegNetwork regulatory interactions types to the root directory.
@@ -124,6 +130,9 @@ def _download_RegNetwork_regtypes(root_dir: Optional[Path] = None) -> str:
return file


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
@functools.lru_cache()
def load_RegNetwork_interactions(
root_dir: Optional[Path] = None,
@@ -144,6 +153,9 @@ def load_RegNetwork_interactions(
)


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
@functools.lru_cache()
def load_RegNetwork_regulation_types(
root_dir: Optional[Path] = None,
@@ -168,6 +180,9 @@ def load_RegNetwork_regulation_types(
)


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def parse_RegNetwork(
gene_list: List[str], root_dir: Optional[Path] = None
) -> pd.DataFrame:
@@ -244,6 +259,9 @@ def standardise_RegNetwork(df: pd.DataFrame) -> pd.DataFrame:
return df


@deprecated(
"RegNetwork appears to be down. This warning will be removed in a future release if the service is restored."
)
def RegNetwork_df(
gene_list: List[str],
root_dir: Optional[Path] = None,
2 changes: 1 addition & 1 deletion graphein/ml/clustering.py
Original file line number Diff line number Diff line change
@@ -85,7 +85,7 @@ def get_seq_records(
"Alphabet given. Only checking for terminating *!\n"
)
check_sequences = False
with open(filename, "rU") as handle:
with open(filename, "r") as handle:
records = list(SeqIO.parse(handle, file_format, alphabet=alphabet))
del handle
if check_sequences:
4 changes: 0 additions & 4 deletions graphein/ml/datasets/foldcomp_dataset.py
Original file line number Diff line number Diff line change
@@ -236,10 +236,6 @@ def download(self):
asyncio.run(_)
os.chdir(curr_dir)
log.info("Download complete.")
# log.info("Moving files to raw directory...")

# for f in self._database_files:
# shutil.move(f, self.root)
else:
log.info(f"FoldComp database already downloaded: {self.root}.")

11 changes: 5 additions & 6 deletions graphein/ml/datasets/pdb_data.py
Original file line number Diff line number Diff line change
@@ -120,6 +120,7 @@ def __init__(
).name

self.list_columns = ["ligands"]
self.labels = labels

# Data
self.download_metadata()
@@ -165,9 +166,10 @@ def download_metadata(self):
self._download_entry_metadata()
self._download_exp_type()
self._download_pdb_availability()
self._download_pdb_chain_cath_uniprot_map()
self._download_cath_id_cath_code_map()
self._download_pdb_chain_ec_number_map()
if self.labels:
self._download_pdb_chain_cath_uniprot_map()
self._download_cath_id_cath_code_map()
self._download_pdb_chain_ec_number_map()

def get_unavailable_pdb_files(
self, splits: Optional[List[str]] = None
@@ -643,15 +645,12 @@ def _parse_cath_code(self) -> Dict[str, str]:
with gzip.open(
self.root_dir / self.cath_id_cath_code_filename, "rt"
) as f:
print(f)
for line in f:
print(line)
try:
cath_id, cath_version, cath_code, cath_segment = (
line.strip().split()
)
cath_mapping[cath_id] = cath_code
print(cath_id, cath_code)
except ValueError:
continue
return cath_mapping
14 changes: 11 additions & 3 deletions graphein/protein/graphs.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import networkx as nx
import numpy as np
import pandas as pd
from biopandas.mmcif import PandasMmcif
from biopandas.mmtf import PandasMmtf
from biopandas.pdb import PandasPdb
from loguru import logger as log
@@ -111,21 +112,28 @@ def read_pdb_to_dataframe(
atomic_df = PandasPdb().read_pdb(path)
elif path.endswith(".mmtf") or path.endswith(".mmtf.gz"):
atomic_df = PandasMmtf().read_mmtf(path)
elif (
path.endswith(".cif")
or path.endswith(".cif.gz")
or path.endswith(".mmcif")
or path.endswith(".mmcif.gz")
):
atomic_df = PandasMmcif().read_mmcif(path)
else:
raise ValueError(
f"File {path} must be either .pdb(.gz), .mmtf(.gz) or .ent, not {path.split('.')[-1]}"
f"File {path} must be either .pdb(.gz), .mmtf(.gz), .(mm)cif(.gz) or .ent, not {path.split('.')[-1]}"
)
elif uniprot_id is not None:
atomic_df = PandasPdb().fetch_pdb(
uniprot_id=uniprot_id, source="alphafold2-v3"
)
else:
atomic_df = PandasPdb().fetch_pdb(pdb_code)

atomic_df = atomic_df.get_model(model_index)
if len(atomic_df.df["ATOM"]) == 0:
raise ValueError(f"No model found for index: {model_index}")

if isinstance(atomic_df, PandasMmcif):
atomic_df = atomic_df.convert_to_pandas_pdb()
return pd.concat([atomic_df.df["ATOM"], atomic_df.df["HETATM"]])


10 changes: 7 additions & 3 deletions graphein/protein/tensor/io.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@
conda_channel="pyg",
pip_install=True,
)
log.debug(message)
log.warning(message)

try:
import torch
@@ -60,7 +60,7 @@
conda_channel="pytorch",
pip_install=True,
)
log.debug(message)
log.warning(message)


def get_protein_length(df: pd.DataFrame, insertions: bool = True) -> int:
@@ -246,7 +246,9 @@ def protein_to_pyg(

out = Data(
coords=protein_df_to_tensor(
df, atoms_to_keep=atom_types, fill_value=fill_value_coords
df,
atoms_to_keep=atom_types,
fill_value=fill_value_coords,
),
residues=get_sequence(
df,
@@ -259,6 +261,7 @@ def protein_to_pyg(
residue_type=residue_type_tensor(df),
chains=protein_df_to_chain_tensor(df),
)

if store_het:
out.hetatms = [het_coords]

@@ -360,6 +363,7 @@ def protein_df_to_tensor(
positions[residue_indices, atom_indices] = torch.tensor(
df[["x_coord", "y_coord", "z_coord"]].values
).float()

return positions


14 changes: 7 additions & 7 deletions graphein/protein/utils.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from functools import lru_cache, partial
from multiprocessing import Pool
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
from urllib.error import HTTPError
from urllib.request import urlopen

@@ -96,7 +96,7 @@ def read_fasta(file_path: str) -> Dict[str, str]:
def download_pdb_multiprocessing(
pdb_codes: List[str],
out_dir: Union[str, Path], # type: ignore
format: str = "pdb",
format: Literal["pdb", "mmtf", "mmcif", "cif", "bcif"] = "pdb",
overwrite: bool = False,
strict: bool = False,
max_workers: int = 16,
@@ -108,7 +108,7 @@ def download_pdb_multiprocessing(
:type pdb_codes: List[str]
:param out_dir: Path to directory to download PDB structures to.
:type out_dir: Union[str, Path]
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif`` or ``bcif``.
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif``/``cif`` or ``bcif``.
:type format: str
:param overwrite: Whether to overwrite existing files, defaults to
``False``.
@@ -146,7 +146,7 @@ def download_pdb_multiprocessing(
def download_pdb(
pdb_code: str,
out_dir: Optional[Union[str, Path]] = None,
format: str = "pdb",
format: Literal["pdb", "mmtf", "mmcif", "cif", "bcif"] = "pdb",
check_obsolete: bool = False,
overwrite: bool = False,
strict: bool = True,
@@ -162,7 +162,7 @@ def download_pdb(
:param out_dir: Path to directory to download PDB structure to. If ``None``,
will download to a temporary directory.
:type out_dir: Optional[Union[str, Path]]
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif`` or ``bcif``.
:param format: Filetype to download. ``pdb``, ``mmtf``, ``mmcif``/``cif`` or ``bcif``.
:type format: str
:param check_obsolete: Whether to check for obsolete PDB codes,
defaults to ``False``. If an obsolete PDB code is found, the updated PDB
@@ -183,15 +183,15 @@ def download_pdb(
elif format == "mmtf":
BASE_URL = "https://mmtf.rcsb.org/v1.0/full/"
extension = ".mmtf.gz"
elif format == "mmcif":
elif format == "cif" or format == "mmcif":
BASE_URL = "https://files.rcsb.org/download/"
extension = ".cif.gz"
elif format == "bcif":
BASE_URL = "https://models.rcsb.org/"
extension = ".bcif.gz"
else:
raise ValueError(
f"Invalid format: {format}. Must be 'pdb', 'mmtf', 'mmcif' or 'bcif'."
f"Invalid format: {format}. Must be 'pdb', 'mmtf', '(mm)cif' or 'bcif'."
)

# Make output directory if it doesn't exist or set it to tempdir if None
Loading

0 comments on commit 848a3f8

Please sign in to comment.