docs: 📝 Update tutorials
adjavon committed Jan 10, 2025
1 parent f5de7f4 commit ad90782
Showing 3 changed files with 197 additions and 23 deletions.
docs/source/tutorials.rst: 33 additions & 10 deletions
QuAC, or Quantitative Attributions with Counterfactuals, is a method for generating counterfactual visual explanations.
Let's assume, for instance, that you have images of cells grown in two different conditions.
To your eye, the phenotypic difference between the two conditions is hidden within the cell-to-cell variability of the dataset, but you know it is there because you've trained a classifier to differentiate the two conditions and it works. So how do you pull out the differences?

Assuming that you already have a classifier that does your task, we begin by training a generative neural network to convert your images from one class to another. Here, we'll use a StarGAN. This allows us to go from our real, **query** image, to our **generated** image.
Using information learned from **reference** images, the StarGAN is trained in such a way that the **generated** image will have a different class!

While very powerful, these generative networks *can potentially* make some changes that are not necessary to the classification.
It is as close as possible to the original image, with only the necessary changes.


Before you begin, make sure you've installed QuAC by following the :doc:`Installation guide <install>`.

The classifier
==============
To use QuAC, we assume that you have a classifier trained on your data.
There are many different packages already available to help you do that; for this codebase, we need you to provide the weights of your classifier as a JIT-compiled `pytorch` model.

If you just want to try QuAC as a learning experience, you can use one of the datasets `in this collection <https://doi.org/10.25378/janelia.c.7620737.v1>`_, and the pre-trained models we provide.
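
For reference, here is a minimal sketch of how a trained classifier can be exported as a JIT-compiled model; the torchvision ResNet and the file names are only placeholders, not something QuAC requires.

.. code-block:: python
:linenos:
import torch
import torchvision

# Hypothetical classifier: any trained torch.nn.Module can be exported the same way.
model = torchvision.models.resnet18(num_classes=2)
model.load_state_dict(torch.load("classifier_weights.pt"))
model.eval()

# Trace the model with an example input and save the JIT archive for QuAC to load.
example_input = torch.randn(1, 3, 224, 224)
scripted = torch.jit.trace(model, example_input)
scripted.save("classifier_jit.pt")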

The conversion network
===============================
Once you've set up your data and your classifier, you can move on to training the conversion network.
We'll use a StarGAN for this.
There are two options for training the StarGAN, but we recommend :doc:`training it using a YAML file <tutorials/train_yaml>`.
This will make it easier to keep track of your experiments!
If you prefer to define parameters directly in Python, however, you can follow the :doc:`alternative training tutorial <tutorials/train>` instead.
Note that in both cases, you will need the JIT-compiled classifier model!

Once you've trained a decent model, you can generate a set of images using the :doc:`image generation tutorial <tutorials/generate>`.
We recommend taking a look at your generated images, to make sure that they look like what you expect.
If that is the case, you can move on to the next steps!
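
For instance, here is a minimal sketch of how you could inspect a few generated images, assuming they were saved as PNG files under the output directory used in the image generation tutorial; the directory path is a placeholder.

.. code-block:: python
:linenos:
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image

# Placeholder path: point this at wherever your generated images were written.
generated_dir = Path("/path/to/output/latent/source_class/target_class/")
samples = sorted(generated_dir.glob("*.png"))[:5]

# Show the first few generated images side by side.
fig, axes = plt.subplots(1, len(samples), figsize=(3 * len(samples), 3), squeeze=False)
for ax, sample in zip(axes[0], samples):
    ax.imshow(Image.open(sample))
    ax.set_title(sample.name, fontsize=8)
    ax.axis("off")
plt.show()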

.. toctree::
:maxdepth: 1

tutorials/train
tutorials/train_yaml
Generating images <tutorials/generate>

Attribution and evaluation
==========================

With the generated images in hand, we can now run the :doc:`attribution <tutorials/attribute>` step, then the :doc:`evaluation <tutorials/evaluate>` step.
These two steps allow us to overcome the limitations of the generative network to create *truly* minimal counterfactual images, and to score the query-counterfactual pairs based on how well they explain the classifier.

Visualizing results
===================

Finally, we can visualize the results of the attribution and evaluation steps using the :doc:`visualization tutorial <tutorials/visualize>`.
This will allow you to see the quantification results, in the form of QuAC curves.
It will also help you choose the best attribution method for each example, and load the counterfactual visual explanations for these examples.

Table of Contents
=================
Here's a list of all available tutorials, in case you want to navigate directly to one of them.

.. toctree::
:maxdepth: 1

Training the generator (recommended) <tutorials/train_yaml>
Training the generator (alternative) <tutorials/train>
Generating images <tutorials/generate>
Attribution <tutorials/attribute>
Evaluation <tutorials/evaluate>
Visualizing results <tutorials/visualize>
docs/source/tutorials/generate.rst: 4 additions & 8 deletions
How to generate images from a pre-trained network
=================================================

Defining the dataset
====================

For example, below, we are going to be using the validation data for our source class.
from quac.generate import load_data
img_size = 224
data_directory = Path("root_directory/val/0_No_DR")
data_directory = Path("/path/to/directory/holding/the/data/source_class")
dataset = load_data(data_directory, img_size, grayscale=False)
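
As a quick sanity check, you can also look at a single sample; this sketch assumes, based on how the dataset is used below, that it yields image tensors paired with file names.

.. code-block:: python
:linenos:
# Peek at one (image, name) pair to confirm the data was loaded as expected.
x, name = next(iter(dataset))
print(name, x.shape)
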
Finally, we can run the image generation.
from quac.generate import get_counterfactual
from torchvision.utils import save_image
output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/")
output_directory = Path("/path/to/output/latent/source_class/target_class/")
for x, name in tqdm(dataset):
    xcf = get_counterfactual(
The first thing we need to do is to get the reference images.
.. code-block:: python
:linenos:
reference_data_directory = Path(f"{root_directory}/val/1_Mild")
reference_data_directory = Path("/path/to/directory/holding/the/data/target_class")
reference_dataset = load_data(reference_data_directory, img_size, grayscale=False)
Loading the StarGAN
Finally, we combine the two by changing the `kind` in our counterfactual generation.
from torchvision.utils import save_image
output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/")
output_directory = Path("/path/to/output/reference/source_class/target_class/")
for x, name in tqdm(dataset):
    xcf = get_counterfactual(
docs/source/tutorials/visualize.rst: 160 additions & 5 deletions
Visualizing the results
=======================

In this tutorial, we will show you how to visualize the results of the attribution and evaluation steps.
Make sure to modify the paths to the reports and the classifier to match your setup!

Obtaining the QuAC curves
=========================
Let's start by loading the reports obtained in the previous step.

.. code-block:: python
:linenos:
report_directory = "/path/to/report/directory/"
from quac.report import Report
reports = {
method: Report(name=method)
for method in [
"DDeepLift",
"DIntegratedGradients",
]
}
for method, report in reports.items():
    report.load(report_directory + method + "/default.json")
Next, we can plot the QuAC curves for each method.
This allows us to get an idea of how well each method is performing, overall.

.. code-block:: python
:linenos:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
for method, report in reports.items():
    report.plot_curve(ax=ax)
# Add the legend
plt.legend()
plt.show()
Choosing the best attribution method for each sample
====================================================

While one attribution method may be better than another on average, it is possible that the best method for a given example is different.
Therefore, we will determine the best method for each example by comparing the QuAC scores.
.. code-block:: python
:linenos:
import pandas as pd

quac_scores = pd.DataFrame(
{method: report.quac_scores for method, report in reports.items()}
)
best_methods = quac_scores.idxmax(axis=1)
best_quac_scores = quac_scores.max(axis=1)
We'll also want to load the classifier at this point, so we can look at the classifications of the counterfactual images.

.. code-block:: python
:linenos:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = torch.jit.load("/path/to/classifier/model.pt").to(device).eval()
Choosing the best examples
==========================
Next, we want to choose the best examples, given the best method for each.
This is done by ordering the examples by their QuAC scores, from highest to lowest, and then picking one from near the top.

.. code-block:: python
:linenos:
import numpy as np

# Order the examples by QuAC score, from highest to lowest
order = np.argsort(best_quac_scores.values)[::-1]
# For example, choose the 10th best example
idx = 10
# Get the corresponding report
report = reports[best_methods[order[idx]]]
We will then load that example and its counterfactual from its path, and visualize it.
We also want to see the classification of both the original and the counterfactual.

.. code-block:: python
:linenos:
# Load the query image and its generated counterpart, along with the
# classifier predictions stored in the report
from PIL import Image
image_path, generated_path = report.paths[order[idx]], report.target_paths[order[idx]]
image, generated_image = Image.open(image_path), Image.open(generated_path)
prediction = report.predictions[order[idx]]
target_prediction = report.target_predictions[order[idx]]
Loading the attribution
=======================
We next want to load the attribution for the example, and visualize it.

.. code-block:: python
:linenos:
attribution_path = report.attribution_paths[order[idx]]
attribution = np.load(attribution_path)
Getting the processor
=====================
We want to see the specific mask that was optimal in this case.
To do this, we need the optimal threshold and the processor used for masking.

.. code-block:: python
:linenos:
from quac.evaluation import Processor
gaussian_kernel_size = 11
struc = 10
thresh = report.optimal_thresholds()[order[idx]]
print(thresh)
processor = Processor(gaussian_kernel_size=gaussian_kernel_size, struc=struc)
mask, _ = processor.create_mask(attribution, thresh)
rgb_mask = mask.transpose(1, 2, 0)
# zero-out the green and blue channels
rgb_mask[:, :, 1] = 0
rgb_mask[:, :, 2] = 0
counterfactual = np.array(generated_image) / 255 * rgb_mask + np.array(image) / 255 * (1.0 - rgb_mask)
Let's also get the classifier output for the counterfactual image.

.. code-block:: python
:linenos:
# `softmax` needs to be imported; scipy's version works directly on numpy arrays
from scipy.special import softmax

classifier_output = classifier(
torch.tensor(counterfactual).permute(2, 0, 1).float().unsqueeze(0).to(device)
)
counterfactual_prediction = softmax(classifier_output[0].detach().cpu().numpy())
Visualizing the results
=======================
Finally, we can visualize the results.

.. code-block:: python
:linenos:
fig, axes = plt.subplots(2, 4)
axes[1, 0].imshow(image)
axes[0, 0].bar(np.arange(len(prediction)), prediction)
axes[1, 1].imshow(generated_image)
axes[0, 1].bar(np.arange(len(target_prediction)), target_prediction)
axes[0, 2].bar(np.arange(len(counterfactual_prediction)), counterfactual_prediction)
axes[1, 2].imshow(counterfactual)
axes[1, 3].imshow(rgb_mask)
axes[0, 3].axis("off")
fig.suptitle(f"QuAC Score: {report.quac_scores[order[idx]]}")
plt.show()
You can now see the original image, the generated image, the counterfactual image, and the mask.
From here, you can choose to visualize other examples, or save the images for later use.
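
If you would like to keep the results, here is a minimal sketch for saving the summary figure and the composited counterfactual image; the file names are placeholders.

.. code-block:: python
:linenos:
# Save the full summary figure and the composited counterfactual image.
fig.savefig("quac_example.png", dpi=300, bbox_inches="tight")
plt.imsave("counterfactual.png", counterfactual)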
