Skip to content

Commit

Permalink
ci: fix linting for tutorial notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Aug 5, 2024
1 parent 8df424c commit 66456a4
Show file tree
Hide file tree
Showing 3 changed files with 828 additions and 826 deletions.
56 changes: 28 additions & 28 deletions tutorials/data_generation_ppi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import glob\n",
"import os\n",
"\n",
"import h5py\n",
"import matplotlib.image as img\n",
"import matplotlib.pyplot as plt\n",
"from deeprank2.query import QueryCollection\n",
"from deeprank2.query import ProteinProteinInterfaceQuery, ProteinProteinInterfaceQuery\n",
"import pandas as pd\n",
"\n",
"from deeprank2.dataset import GraphDataset\n",
"from deeprank2.features import components, contact\n",
"from deeprank2.utils.grid import GridSettings, MapMethod\n",
"from deeprank2.dataset import GraphDataset"
"from deeprank2.query import ProteinProteinInterfaceQuery, QueryCollection\n",
"from deeprank2.utils.grid import GridSettings, MapMethod"
]
},
{
Expand Down Expand Up @@ -131,14 +132,15 @@
"metadata": {},
"outputs": [],
"source": [
"def get_pdb_files_and_target_data(data_path):\n",
"def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[float]]:\n",
" csv_data = pd.read_csv(os.path.join(data_path, \"BA_values.csv\"))\n",
" pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.pdb\"))\n",
" pdb_files.sort()\n",
" pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0] for pdb_file in pdb_files]\n",
" csv_data_indexed = csv_data.set_index(\"ID\")\n",
" csv_data_indexed = csv_data_indexed.loc[pdb_ids_csv]\n",
" bas = csv_data_indexed.measurement_value.values.tolist()\n",
" bas = csv_data_indexed.measurement_value.tolist()\n",
"\n",
" return pdb_files, bas\n",
"\n",
"\n",
Expand Down Expand Up @@ -192,9 +194,9 @@
"\n",
"influence_radius = 8 # max distance in Å between two interacting residues/atoms of two proteins\n",
"max_edge_length = 8\n",
"binary_target_value = 500\n",
"\n",
"print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
"count = 0\n",
"for i in range(len(pdb_files)):\n",
" queries.add(\n",
" ProteinProteinInterfaceQuery(\n",
Expand All @@ -204,16 +206,15 @@
" influence_radius=influence_radius,\n",
" max_edge_length=max_edge_length,\n",
" targets={\n",
" \"binary\": int(float(bas[i]) <= 500), # binary target value\n",
" \"binary\": int(float(bas[i]) <= binary_target_value),\n",
" \"BA\": bas[i], # continuous target value\n",
" },\n",
" )\n",
" ),\n",
" )\n",
" count += 1\n",
" if count % 20 == 0:\n",
" print(f\"{count} queries added to the collection.\")\n",
"    if (i + 1) % 20 == 0:\n",
" print(f\"{i+1} queries added to the collection.\")\n",
"\n",
"print(\"Queries ready to be processed.\\n\")"
"print(f\"{i+1} queries ready to be processed.\\n\")"
]
},
{
Expand Down Expand Up @@ -340,8 +341,8 @@
"source": [
"processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n",
"dataset = GraphDataset(processed_data, target=\"binary\")\n",
"df = dataset.hdf5_to_pandas()\n",
"df.head()"
"dataset_df = dataset.hdf5_to_pandas()\n",
"dataset_df.head()"
]
},
{
Expand All @@ -358,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n",
"fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n",
"dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n",
"\n",
"im = img.imread(fname + \".png\")\n",
Expand Down Expand Up @@ -429,9 +430,9 @@
"\n",
"influence_radius = 5 # max distance in Å between two interacting residues/atoms of two proteins\n",
"max_edge_length = 5\n",
"binary_target_value = 500\n",
"\n",
"print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
"count = 0\n",
"for i in range(len(pdb_files)):\n",
" queries.add(\n",
" ProteinProteinInterfaceQuery(\n",
Expand All @@ -441,16 +442,15 @@
" influence_radius=influence_radius,\n",
" max_edge_length=max_edge_length,\n",
" targets={\n",
" \"binary\": int(float(bas[i]) <= 500), # binary target value\n",
" \"binary\": int(float(bas[i]) <= binary_target_value),\n",
" \"BA\": bas[i], # continuous target value\n",
" },\n",
" )\n",
" ),\n",
" )\n",
" count += 1\n",
" if count % 20 == 0:\n",
" print(f\"{count} queries added to the collection.\")\n",
"    if (i + 1) % 20 == 0:\n",
" print(f\"{i+1} queries added to the collection.\")\n",
"\n",
"print(\"Queries ready to be processed.\\n\")"
"print(f\"{i+1} queries ready to be processed.\\n\")"
]
},
{
Expand Down Expand Up @@ -495,8 +495,8 @@
"source": [
"processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n",
"dataset = GraphDataset(processed_data, target=\"binary\")\n",
"df = dataset.hdf5_to_pandas()\n",
"df.head()"
"dataset_df = dataset.hdf5_to_pandas()\n",
"dataset_df.head()"
]
},
{
Expand Down Expand Up @@ -540,7 +540,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
},
"orig_nbformat": 4
},
Expand Down
59 changes: 29 additions & 30 deletions tutorials/data_generation_srv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,19 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import glob\n",
"import os\n",
"\n",
"import h5py\n",
"import matplotlib.image as img\n",
"import matplotlib.pyplot as plt\n",
"from deeprank2.query import QueryCollection\n",
"from deeprank2.query import SingleResidueVariantQuery, SingleResidueVariantQuery\n",
"import pandas as pd\n",
"\n",
"from deeprank2.dataset import GraphDataset\n",
"from deeprank2.domain.aminoacidlist import amino_acids_by_code\n",
"from deeprank2.features import components, contact\n",
"from deeprank2.utils.grid import GridSettings, MapMethod\n",
"from deeprank2.dataset import GraphDataset"
"from deeprank2.query import QueryCollection, SingleResidueVariantQuery\n",
"from deeprank2.utils.grid import GridSettings, MapMethod"
]
},
{
Expand Down Expand Up @@ -132,19 +133,20 @@
"metadata": {},
"outputs": [],
"source": [
"def get_pdb_files_and_target_data(data_path):\n",
"def get_pdb_files_and_target_data(data_path: str) -> tuple[list[str], list[int], list[str], list[str], list[float]]:\n",
" csv_data = pd.read_csv(os.path.join(data_path, \"srv_target_values.csv\"))\n",
" pdb_files = glob.glob(os.path.join(data_path, \"pdb\", \"*.ent\"))\n",
" pdb_files.sort()\n",
" pdb_file_names = [os.path.basename(pdb_file) for pdb_file in pdb_files]\n",
" csv_data_indexed = csv_data.set_index(\"pdb_file\")\n",
" csv_data_indexed = csv_data_indexed.loc[pdb_file_names]\n",
" res_numbers = csv_data_indexed.res_number.values.tolist()\n",
" res_wildtypes = csv_data_indexed.res_wildtype.values.tolist()\n",
" res_variants = csv_data_indexed.res_variant.values.tolist()\n",
" targets = csv_data_indexed.target.values.tolist()\n",
" pdb_names = csv_data_indexed.index.values.tolist()\n",
" res_numbers = csv_data_indexed.res_number.tolist()\n",
" res_wildtypes = csv_data_indexed.res_wildtype.tolist()\n",
" res_variants = csv_data_indexed.res_variant.tolist()\n",
" targets = csv_data_indexed.target.tolist()\n",
" pdb_names = csv_data_indexed.index.tolist()\n",
" pdb_files = [data_path + \"/pdb/\" + pdb_name for pdb_name in pdb_names]\n",
"\n",
" return pdb_files, res_numbers, res_wildtypes, res_variants, targets\n",
"\n",
"\n",
Expand Down Expand Up @@ -204,7 +206,6 @@
"max_edge_length = 4.5 # ??\n",
"\n",
"print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
"count = 0\n",
"for i in range(len(pdb_files)):\n",
" queries.add(\n",
" SingleResidueVariantQuery(\n",
Expand All @@ -218,13 +219,12 @@
" targets={\"binary\": targets[i]},\n",
" influence_radius=influence_radius,\n",
" max_edge_length=max_edge_length,\n",
" )\n",
" ),\n",
" )\n",
" count += 1\n",
" if count % 20 == 0:\n",
" print(f\"{count} queries added to the collection.\")\n",
"    if (i + 1) % 20 == 0:\n",
" print(f\"{i+1} queries added to the collection.\")\n",
"\n",
"print(f\"Queries ready to be processed.\\n\")"
"print(f\"{i+1} queries ready to be processed.\\n\")"
]
},
{
Expand Down Expand Up @@ -358,8 +358,8 @@
"source": [
"processed_data = glob.glob(os.path.join(processed_data_path, \"residue\", \"*.hdf5\"))\n",
"dataset = GraphDataset(processed_data, target=\"binary\")\n",
"df = dataset.hdf5_to_pandas()\n",
"df.head()"
"dataset_df = dataset.hdf5_to_pandas()\n",
"dataset_df.head()"
]
},
{
Expand All @@ -376,7 +376,8 @@
"metadata": {},
"outputs": [],
"source": [
"fname = os.path.join(processed_data_path, \"residue\", \"_\".join([\"res_mass\", \"distance\", \"electrostatic\"]))\n",
"fname = os.path.join(processed_data_path, \"residue\", \"res_mass_distance_electrostatic\")\n",
"\n",
"dataset.save_hist(features=[\"res_mass\", \"distance\", \"electrostatic\"], fname=fname)\n",
"\n",
"im = img.imread(fname + \".png\")\n",
Expand Down Expand Up @@ -450,7 +451,6 @@
"max_edge_length = 4.5 # ??\n",
"\n",
"print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
"count = 0\n",
"for i in range(len(pdb_files)):\n",
" queries.add(\n",
" SingleResidueVariantQuery(\n",
Expand All @@ -464,13 +464,12 @@
" targets={\"binary\": targets[i]},\n",
" influence_radius=influence_radius,\n",
" max_edge_length=max_edge_length,\n",
" )\n",
" ),\n",
" )\n",
" count += 1\n",
" if count % 20 == 0:\n",
" print(f\"{count} queries added to the collection.\")\n",
"    if (i + 1) % 20 == 0:\n",
" print(f\"{i+1} queries added to the collection.\")\n",
"\n",
"print(\"Queries ready to be processed.\\n\")"
"print(f\"{i+1} queries ready to be processed.\\n\")"
]
},
{
Expand Down Expand Up @@ -515,8 +514,8 @@
"source": [
"processed_data = glob.glob(os.path.join(processed_data_path, \"atomic\", \"*.hdf5\"))\n",
"dataset = GraphDataset(processed_data, target=\"binary\")\n",
"df = dataset.hdf5_to_pandas()\n",
"df.head()"
"dataset_df = dataset.hdf5_to_pandas()\n",
"dataset_df.head()"
]
},
{
Expand Down Expand Up @@ -565,7 +564,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
},
"orig_nbformat": 4
},
Expand Down
Loading

0 comments on commit 66456a4

Please sign in to comment.