Skip to content

Commit

Permalink
ci: lint _generate_testdata.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
DaniBodor committed Aug 5, 2024
1 parent 7de3af9 commit 8df424c
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions tests/data/hdf5/_generate_testdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
"PATH_TEST = ROOT / \"tests\"\n",
"import glob\n",
"import os\n",
"import re\n",
"import sys\n",
"\n",
"import h5py\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from deeprank2.dataset import save_hdf5_keys\n",
Expand Down Expand Up @@ -79,7 +76,7 @@
" chain_ids=[chain_id1, chain_id2],\n",
" targets=targets,\n",
" pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n",
" )\n",
" ),\n",
" )\n",
"\n",
" # Generate graphs and save them in hdf5 files\n",
Expand Down Expand Up @@ -128,8 +125,8 @@
"csv_data = pd.read_csv(csv_file_path)\n",
"csv_data.cluster = csv_data.cluster.fillna(-1)\n",
"pdb_ids_csv = [pdb_file.split(\"/\")[-1].split(\".\")[0].replace(\"-\", \"_\") for pdb_file in pdb_files]\n",
"clusters = [csv_data[pdb_id == csv_data.ID].cluster.values[0] for pdb_id in pdb_ids_csv]\n",
"bas = [csv_data[pdb_id == csv_data.ID].measurement_value.values[0] for pdb_id in pdb_ids_csv]\n",
"clusters = [csv_data[pdb_id == csv_data.ID].cluster.to_numpy()[0] for pdb_id in pdb_ids_csv]\n",
"bas = [csv_data[pdb_id == csv_data.ID].measurement_value.to_numpy()[0] for pdb_id in pdb_ids_csv]\n",
"\n",
"queries = QueryCollection()\n",
"print(f\"Adding {len(pdb_files)} queries to the query collection ...\")\n",
Expand All @@ -147,7 +144,7 @@
" \"cluster\": clusters[i],\n",
" },\n",
" pssm_paths={\"M\": pssm_m[i], \"P\": pssm_p[i]},\n",
" )\n",
" ),\n",
" )\n",
"print(\"Queries created and ready to be processed.\\n\")\n",
"\n",
Expand Down Expand Up @@ -183,7 +180,7 @@
"test_ids = []\n",
"\n",
"with h5py.File(hdf5_path, \"r\") as hdf5:\n",
" for key in hdf5.keys():\n",
" for key in hdf5:\n",
" feature_value = float(hdf5[key][target][feature][()])\n",
" if feature_value in train_clusters:\n",
" train_ids.append(key)\n",
Expand All @@ -192,7 +189,7 @@
" elif feature_value in test_clusters:\n",
" test_ids.append(key)\n",
"\n",
" if feature_value in clusters.keys():\n",
" if feature_value in clusters:\n",
" clusters[int(feature_value)] += 1\n",
" else:\n",
" clusters[int(feature_value)] = 1\n",
Expand Down Expand Up @@ -278,8 +275,12 @@
" targets = compute_ppi_scores(pdb_path, ref_path)\n",
" queries.add(\n",
" ProteinProteinInterfaceQuery(\n",
" pdb_path=pdb_path, resolution=\"atom\", chain_ids=[chain_id1, chain_id2], targets=targets, pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2}\n",
" )\n",
" pdb_path=pdb_path,\n",
" resolution=\"atom\",\n",
" chain_ids=[chain_id1, chain_id2],\n",
" targets=targets,\n",
" pssm_paths={chain_id1: pssm_path1, chain_id2: pssm_path2},\n",
" ),\n",
" )\n",
"\n",
"# Generate graphs and save them in hdf5 files\n",
Expand All @@ -303,7 +304,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
},
"orig_nbformat": 4,
"vscode": {
Expand Down

0 comments on commit 8df424c

Please sign in to comment.