From ad907821e88cbff7fc89e5ddc0e23c48545d4ce9 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Fri, 10 Jan 2025 10:16:54 -0500 Subject: [PATCH] docs: :memo: Update tutorials --- docs/source/tutorials.rst | 43 ++++++-- docs/source/tutorials/generate.rst | 12 +- docs/source/tutorials/visualize.rst | 165 +++++++++++++++++++++++++++- 3 files changed, 197 insertions(+), 23 deletions(-) diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst index 57c9bf0..01845e4 100644 --- a/docs/source/tutorials.rst +++ b/docs/source/tutorials.rst @@ -14,7 +14,7 @@ QuAC, or Quantitative Attributions with Counterfactuals, is a method for generat Let's assume, for instance, that you have images of cells grown in two different conditions. To your eye, the phenotypic difference between the two conditions is hidden within the cell-to-cell variability of the dataset, but you know it is there because you've trained a classifier to differentiate the two conditions and it works. So how do you pull out the differences? -We begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. +Assuming that you already have a classifier that does your task, we begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image. Using information learned from **reference** images, the StarGAN is trained in such a way that the **generated** image will have a different class! While very powerful, these generative networks *can potentially* make some changes that are not necessary to the classification. @@ -28,33 +28,56 @@ It is as close as possible to the original image, with only the necessary change :width: 800 :align: center -Before you begin, download [the data]() and [the pre-trained models]() for an example. -Then, make sure you've installed QuAC by following the :doc:`Installation guide `. +Before you begin, make sure you've installed QuAC by following the :doc:`Installation guide `. + +The classifier +============== +To use QuAC, we assume that you have a classifier trained on your data. +There are many different packages already to help you do that, for this code-base we will need you to have the weights to your classifier as a JIT-compiled `pytorch` model. + +If you just want to try QuAC as a learning experience, you can use one of the datasets `in this collection `_, and the pre-trained models we provide. The conversion network =============================== +Once you've set up your data and your classifier, you can move on to training the conversion network. +We'll use a StarGAN for this. +There are two options for training the StarGAN, but we recommend :doc:`training it using a YAML file `. +This will make it easier to keep track of your experiments! +If you prefer to define parameters directly in Python, however, you can follow the :doc:`alternative training tutorial ` instead. +Note that in both cases, you will need the JIT-compiled classifier model! -You have two options for training the StarGAN, you can either :doc:`define parameters directly in Python ` or :doc:`train it using a YAML file `. -We recommend the latter, which will make it easier to keep track of your experiments! -Once you've trained a decent model, generate a set of images using the :doc:`image generation tutorial ` before moving on to the next steps. +Once you've trained a decent model, you can generate a set of images using the :doc:`image generation tutorial `. +We recommend taking a look at your generated images, to make sure that they look like what you expect. +If that is the case, you can move on to the next steps! .. toctree:: :maxdepth: 1 - tutorials/train - tutorials/train_yaml - Generating images Attribution and evaluation ========================== -With the generated images in hand, we can now run the attribution and evaluation steps. +With the generated images in hand, we can now run the :doc:`attribution ` step, then the :doc:`evaluation ` step. These two steps allow us to overcome the limitations of the generative network to create *truly* minimal counterfactual images, and to score the query-counterfactual pairs based on how well they explain the classifier. +Visualizing results +=================== + +Finally, we can visualize the results of the attribution and evaluation steps using the :doc:`visualization tutorial `. +This will allow you to see the quantification results, in the form of QuAC curves. +It will also help you choose the best attribution method for each example, and load the counterfactual visual explanations for these examples. + +Table of Contents +================= +Here's a list of all available tutorials, in case you want to navigate directly to one of them. + .. toctree:: :maxdepth: 1 + Training the generator (recommended) + Training the generator (alternative) + Generating images Attribution Evaluation Visualizing results diff --git a/docs/source/tutorials/generate.rst b/docs/source/tutorials/generate.rst index 0775941..c3d2a49 100644 --- a/docs/source/tutorials/generate.rst +++ b/docs/source/tutorials/generate.rst @@ -4,10 +4,6 @@ How to generate images from a pre-trained network ================================================= -.. attention:: - This tutorial is still under construction. Come back soon for updates! - - Defining the dataset ==================== @@ -22,7 +18,7 @@ For example, below, we are going to be using the validation data, and our source from quac.generate import load_data img_size = 224 - data_directory = Path("root_directory/val/0_No_DR") + data_directory = Path("/path/to/directory/holding/the/data/source_class") dataset = load_data(data_directory, img_size, grayscale=False) @@ -86,7 +82,7 @@ Finally, we can run the image generation. from quac.generate import get_counterfactual from torchvision.utils import save_image - output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/latent/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( @@ -117,7 +113,7 @@ The first thing we need to do is to get the reference images. .. code-block:: python :linenos: - reference_data_directory = Path(f"{root_directory}/val/1_Mild") + reference_data_directory = Path("/path/to/directory/holding/the/data/target_class") reference_dataset = load_data(reference_data_directory, img_size, grayscale=False) Loading the StarGAN @@ -148,7 +144,7 @@ Finally, we combine the two by changing the `kind` in our counterfactual generat from torchvision.utils import save_image - output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/") + output_directory = Path("/path/to/output/reference/source_class/target_class/") for x, name in tqdm(dataset): xcf = get_counterfactual( diff --git a/docs/source/tutorials/visualize.rst b/docs/source/tutorials/visualize.rst index 3bd9022..52936d0 100644 --- a/docs/source/tutorials/visualize.rst +++ b/docs/source/tutorials/visualize.rst @@ -2,9 +2,164 @@ Visualizing the results ======================= -.. attention:: - This tutorial is still under construction. Come back soon for updates! +In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps. +Make sure to modify the paths to the reports and the classifier to match your setup! - .. image:: ../assets/quac.png - :width: 100 - :align: center +Obtaining the QuAC curves +========================= +Let's start by loading the reports obtained in the previous step. + +.. code-block:: python + :linenos: + report_directory = "/path/to/report/directory/" + + + from quac.report import Report + + reports = { + method: Report(name=method) + for method in [ + "DDeepLift", + "DIntegratedGradients", + ] + } + + for method, report in reports.items(): + report.load(report_directory + method + "/default.json") + +Next, we can plot the QuAC curves for each method. +This allows us to get an idea of how well each method is performing, overall. + +.. code-block:: python + :linenos: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + for method, report in reports.items(): + report.plot_curve(ax=ax) + # Add the legend + plt.legend() + plt.show() + + +Choosing the best attribution method for each sample +==================================================== + +While one attribution method may be better than another on average, it is possible that the best method for a given example is different. +Therefore, we will make a list of the best method for each example by comparing the quac scores. +.. code-block:: python + :linenos: + quac_scores = pd.DataFrame( + {method: report.quac_scores for method, report in reports.items()} + ) + best_methods = quac_scores.idxmax(axis=1) + best_quac_scores = quac_scores.max(axis=1) + +We'll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images. + +.. code-block:: python + :linenos: + import torch + + classifier = torch.jit.load("/path/to/classifier/model.pt") + + +Choosing the best examples +========================== +Next we want to choose the best example, given the best method. +This is done by ordering the examples by the QuAC score, and then choosing the one with the highest score. + +.. code-block:: python + :linenos: + order = best_quac_scores[::-1].argsort() + + # For example, choose the 10th best example + idx = 10 + # Get the corresponding report + report = reports[best_methods[order[idx]]] + +We will then load that example and its counterfactual from its path, and visualize it. +We also want to see the classification of both the original and the counterfactual. + +.. code-block:: python + :linenos: + # Transform to apply to the images so they match each other + # loading + from PIL import Image + + image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]] + image, generated_image = Image.open(image_path), Image.open(generated_path) + + prediction = report.predictions[order[idx]] + target_prediction = report.target_predictions[order[idx]] + + image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]] + image, generated_image = Image.open(image_path), Image.open(generated_path) + + prediction = report.predictions[order[idx]] + target_prediction = report.target_predictions[order[idx]] + +Loading the attribution +======================= +We next want to load the attribution for the example, and visualize it. + +.. code-block:: python + :linenos: + + attribution_path = report.attribution_paths[order[idx]] + attribution = np.load(attribution_path) + +Getting the processor +===================== +We want to see the specific mask that was optimal in this case. +To do this, we will need to get the optimal threshold, and get the processor used for masking. + +.. code-block:: python + :linenos: + from quac.evaluation import Processor + + gaussian_kernel_size = 11 + struc = 10 + thresh = report.optimal_thresholds()[order[idx]] + print(thresh) + processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc) + + mask, _ = processor.create_mask(attribution, thresh) + rgb_mask = mask.transpose(1, 2, 0) + # zero-out the green and blue channels + rgb_mask[:, :, 1] = 0 + rgb_mask[:, :, 2] = 0 + counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask) + +Let's also get the classifier output for the counterfactual image. + +.. code-block:: python + :linenos: + + classifier_output = classifier( + torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device) + ) + counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy()) + +Visualizing the results +======================= +Finally, we can visualize the results. + +.. code-block:: python + :linenos: + + fig, axes = plt.subplots(2, 4) + axes[1, 0].imshow(image) + axes[0, 0].bar(np.arange(len(prediction)), prediction) + axes[1, 1].imshow(generated_image) + axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction) + axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction) + axes[1, 2].imshow(counterfactual) + axes[1, 3].imshow(rgb_mask) + axes[0, 3].axis("off") + fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}") + plt.show() + +You can now see the original image, the generated image, the counterfactual image, and the mask. +From here, you can choose to visualize other examples, of save the images for later use.