def _check_task_and_classes(self, task: str | None, classes: str | None = None) -> None:
    """Resolve and store the learning task and the class labels for this dataset.

    Sets ``self.task``, ``self.classes`` and ``self.classes_to_index`` from
    ``self.target`` and the user-supplied arguments.

    Args:
        task: Task requested by the user, or None to derive it from
            ``self.target``. For targets with a known task the derived task
            always wins; a conflicting value here only triggers a warning.
        classes: Class labels to use for classification, or None to fall back
            to the defaults. NOTE(review): the signature annotates this as
            ``str``, but it is iterated with ``enumerate`` and the defaults
            are lists of ints, so it is presumably a sequence of labels.

    Raises:
        ValueError: If ``self.target`` is set but the resolved task is neither
            ``targets.CLASSIF`` nor ``targets.REGRESS``.
    """
    # Targets whose task is fixed by definition.
    target_to_task_map = {
        targets.IRMSD: targets.REGRESS,
        targets.LRMSD: targets.REGRESS,
        targets.FNAT: targets.REGRESS,
        targets.DOCKQ: targets.REGRESS,
        targets.BINARY: targets.CLASSIF,
        targets.CAPRI: targets.CLASSIF,
    }
    expected_task = target_to_task_map.get(self.target)

    if expected_task is None:
        # User-defined (or missing) target: the user's task is the only source.
        self.task = task
    else:
        # Known target: the expected task is authoritative. Warn *before*
        # overriding so the message truthfully reports that the user's value
        # is being ignored (previously this warning could never fire).
        if task is not None and task != expected_task:
            warnings.warn(
                f"Target {self.target} expects {expected_task}, but was set to task {task} by user. "
                f"User set task is ignored and {expected_task} will be used.",
            )
        self.task = expected_task

    # Validate the task
    if self.task not in [targets.CLASSIF, targets.REGRESS] and self.target is not None:
        msg = f"User target detected: {self.target} -> The task argument must be 'classif' or 'regress', currently set as {self.task}"
        raise ValueError(msg)

    # Handle classification task
    if self.task == targets.CLASSIF:
        if classes is None:
            # Six default labels for the CAPRI target, binary labels otherwise.
            self.classes = [0, 1, 2, 3, 4, 5] if self.target == targets.CAPRI else [0, 1]
        else:
            # Honour user-supplied classes; without this branch self.classes
            # would be left unset (or stale) when classes is provided.
            self.classes = classes
        self.classes_to_index = {class_: index for index, class_ in enumerate(self.classes)}
        _log.info(f"Target classes set to: {self.classes}")
    else:
        self.classes = None
        self.classes_to_index = None
+ Example for plotting training loss curves using [Plotly Express](https://plotly.com/python/plotly-express/): ```python diff --git a/tutorials/training.ipynb b/tutorials/training.ipynb index ae8870c4..401c3403 100644 --- a/tutorials/training.ipynb +++ b/tutorials/training.ipynb @@ -420,12 +420,8 @@ "metadata": {}, "outputs": [], "source": [ - "output_train = pd.read_hdf(\n", - " os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n", - ")\n", - "output_test = pd.read_hdf(\n", - " os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n", - ")\n", + "output_train = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", + "output_test = pd.read_hdf(os.path.join(output_path, f\"gnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", "output_train.head()" ] }, @@ -436,7 +432,9 @@ "source": [ "The dataframes contain `phase`, `epoch`, `entry`, `output`, `target`, and `loss` columns, and can be easily used to visualize the results.\n", "\n", - "For example, the loss across the epochs can be plotted for the training and the validation sets:\n" + "For classification tasks, the `output` column contains a list of probabilities that each class occurs, and each list sums to 1 (for more details, please see documentation on the [softmax function](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)). Note that the order of the classes in the list depends on the `classes` attribute of the DeeprankDataset instances. 
For classification tasks, if `classes` is not specified (as in this example case), it defaults to [0, 1].\n", + "\n", + "The loss across the epochs can be plotted for the training and the validation sets:\n" ] }, { @@ -671,12 +669,8 @@ "metadata": {}, "outputs": [], "source": [ - "output_train = pd.read_hdf(\n", - " os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\"\n", - ")\n", - "output_test = pd.read_hdf(\n", - " os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\"\n", - ")\n", + "output_train = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"training\")\n", + "output_test = pd.read_hdf(os.path.join(output_path, f\"cnn_{task}\", \"output_exporter.hdf5\"), key=\"testing\")\n", "output_train.head()" ] }, @@ -767,7 +761,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.12" }, "orig_nbformat": 4 },