local methods #12 · Merged · Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
exercises/08_iml/08-01_Local_Interpretations.Rmd · 269 additions, 0 deletions
---
title: "Local Interpretation Methods"
subtitle: "In-class exercise 08-1"
output:
prettydoc::html_pretty:
theme: hpstr
toc: yes
toc_depth: 2
css: ../_template/prettydoc-hpstr.css
self_contained: yes
---

```{r, include=FALSE}
sys.source("setup.R", envir = knitr::knit_global())
```

# Goal

Apply what you have learned about local interpretation methods to real data.

# Prerequisites

We need the following libraries:

```{r library}
library("randomForest")
library("iml")
library("rchallenge")
library("counterfactuals")
```

# Data

### German Credit

The data set contains 1000 observations of 20 features and the binary target variable `credit_risk`. For illustrative purposes, we only consider 7 of the 20 features in the following:

```{r}
data(german, package = "rchallenge")
credit = german[, c("duration", "amount", "purpose", "age",
                    "employment_duration", "housing", "number_credits", "credit_risk")]
```

### Wine

The data set contains quality ratings (assessed by blind tasting, on a scale from 0 to 10) for roughly 6,500 red and white wines, together with their physicochemical properties:

```{r}
wine = read.csv("wine.csv")
wine = na.omit(wine)
wine$type = as.factor(wine$type)
str(wine)
```
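
For orientation, a quick look at the distribution of the target variable (not required for the exercises):

```{r}
# Frequency of each quality rating in the data
table(wine$quality)
```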

# Fit a Random Forest

We fit a random forest to predict wine quality and use a train-test split, reserving 500 observations for testing:

```{r}
set.seed(124)
ids = sample(1:nrow(wine), size = 500, replace = FALSE)
train = wine[-ids,]
test = wine[ids,]

# Fit random forest:
forest = randomForest(quality ~ ., data = train)
forest
```
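
As a quick sanity check, not part of the exercises, we can compute the forest's test-set error; a minimal sketch:

```{r}
# Test-set RMSE as a rough performance check
pred_test = predict(forest, newdata = test)
sqrt(mean((pred_test - test$quality)^2))
```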

# Exercises

In the following exercises, we shall always compute interpretations on the test data.

## Exercise 1: Predictor Objects in `iml`

The interpretation methods in `iml` require the model and data to be wrapped in a `Predictor` object. Create a `Predictor` object for the `wine` (test) data and the `forest` model.

<details>
<summary>**Hint 1:**</summary>

Check out the help page `?Predictor` to understand how to construct the `Predictor` object.

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
rfpred = Predictor$new(forest, data = test, y = "quality")
```
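
A quick way to verify the wrapper works is to predict for a few test observations via the `Predictor` object's `$predict()` method:

```{r, eval = show.solution}
# The Predictor wraps the model's predict function
rfpred$predict(head(test, 3))
```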

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

## Exercise 2: LIME

LIME (Local Interpretable Model-agnostic Explanations) explains individual predictions of any machine learning model by approximating the model locally with an interpretable surrogate. This simplification helps identify which features contribute most to a specific prediction, providing valuable insight into model behavior. Create a `LocalModel` object, an R6 class in `iml`, to explain the random forest prediction for the first observation in the test data with a linear regression model. The results are stored in the field `$results`.
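
To make the idea concrete before using `iml`, here is a minimal hand-rolled sketch of the LIME principle, an illustration under simplifying assumptions rather than the `iml` implementation: sample points around the instance, weight them by proximity, and fit a weighted linear model.

```{r}
# Hand-rolled LIME sketch (illustration only; iml's LocalModel differs in detail)
x0 = test[1, ]
# Use other test observations as local samples
samp = test[sample(nrow(test), 200), ]
preds = predict(forest, newdata = samp)
# Gaussian proximity weights on the 'alcohol' feature (a simplification)
w = exp(-(samp$alcohol - x0$alcohol)^2 / (2 * sd(samp$alcohol)^2))
# Weighted linear surrogate: its coefficient approximates the local effect
coef(lm(preds ~ alcohol, data = samp, weights = w))
```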

<details>
<summary>**Hint 1:**</summary>

Check out the help page `?LocalModel` to understand how to construct the `LocalModel` object. The Details section explains the exact method used.

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
lime <- LocalModel$new(rfpred, x.interest = test[1, ])
lime$results
```

This can also be visualized:

```{r, eval = show.solution}
plot(lime)
```
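
As a quick check, the surrogate's own prediction can be compared with the random forest's, assuming `LocalModel`'s `$predict()` method behaves as documented:

```{r, eval = show.solution}
# Compare the local surrogate's prediction with the black-box prediction
lime$predict(test[1, ])
rfpred$predict(test[1, ])
```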

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

## Exercise 3: ICE

Individual Conditional Expectation (ICE) plots show how each instance's prediction changes when a feature changes: the dependence of the prediction on the feature is visualized separately for every instance, resulting in one line per instance rather than the single averaged line of a partial dependence plot (PDP). Create an ICE plot for the feature `alcohol` with the class `FeatureEffect`. Describe the feature-target relationship you can observe from it.
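
Conceptually, an ICE curve is nothing more than the model's prediction for one observation while a single feature is swept over a grid. A minimal manual sketch (for illustration only; the exercise itself uses `iml`):

```{r}
# Manual ICE curves for three observations (illustration only)
grid = seq(min(test$alcohol), max(test$alcohol), length.out = 30)
ice_manual = sapply(1:3, function(i) {
  tmp = test[rep(i, length(grid)), ]
  tmp$alcohol = grid   # sweep only the feature of interest
  predict(forest, newdata = tmp)
})
matplot(grid, ice_manual, type = "l", lty = 1,
        xlab = "alcohol", ylab = "predicted quality")
```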

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
ice <- FeatureEffect$new(predictor = rfpred,
feature = "alcohol",
method = "ice",
grid.size = 30)
plot(ice)
```

For most (but not all) observations, alcohol seems to positively influence predicted quality.
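
To see the average trend on top of the individual curves, `iml` can overlay the PDP via `method = "pdp+ice"`:

```{r, eval = show.solution}
# ICE curves with the partial dependence (average) curve overlaid
pdp_ice <- FeatureEffect$new(rfpred, feature = "alcohol", method = "pdp+ice")
plot(pdp_ice)
```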

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

## Exercise 4: Counterfactual Explanations

Counterfactual explanations identify the smallest possible changes to the input features of a given observation that would lead to a different prediction. In other words, a counterfactual explanation answers the question: "What changes in the current feature values are necessary to achieve a different prediction?" To this end, we first train a model to predict whether a credit is good or bad, omitting one observation from the training data: the individual for which we want to find such explanations later on.

```{r}
forest2 = randomForest(credit_risk ~ ., data = credit[-998L,])
```

Let's use Multi-Objective Counterfactual (MOC) Explanations [Dandl et al. 2020](https://link.springer.com/chapter/10.1007/978-3-030-58112-1_31) to compute counterfactuals. The corresponding `counterfactuals` package builds upon `iml`, e.g. by using the `Predictor` objects:

```{r}
pred = Predictor$new(forest2, type = "prob")
x_interest = credit[998L, ]
pred$predict(x_interest)
```

For the individual of interest, the model predicts a probability of 0.38 that the credit is bad.

### 4a) Create a MOC object

Now, we want to examine which factors need to be changed to increase the predicted probability of being a good credit risk to more than 60%. Since we want to apply MOC to a classification model, we can initialize a `MOCClassif` object. Use the documentation to learn about the arguments required and construct a suitable object. We want the following things:

* Penalization: candidates whose prediction is farther away from the desired prediction than epsilon should be penalized.
* A person's `age` and `employment_duration` are non-actionable features, meaning that they cannot be changed; the search should hold them fixed.
* Set `quiet = TRUE`, `termination_crit = "genstag"` and `n_generations = 10L`.

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
moc_classif = MOCClassif$new(
pred, epsilon = 0, fixed_features = c("age", "employment_duration"),
quiet = TRUE, termination_crit = "genstag", n_generations = 10L
)
```

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

### 4b) Find counterfactuals

Now, use the `$find_counterfactuals()` method to find counterfactuals for `x_interest`. Remember that we aim to find counterfactuals with a predicted probability of being a good credit of at least 60%.

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
cfactuals = moc_classif$find_counterfactuals(
x_interest, desired_class = "good", desired_prob = c(0.6, 1)
)
```

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

### 4c) Analyze counterfactuals

The resulting `Counterfactuals` object holds the counterfactuals in the `data` field and possesses several methods for their evaluation and visualization. Printing a `Counterfactuals` object gives an overview of the results:

```{r}
print(cfactuals)
```

The `$predict()` method returns the predictions for the counterfactuals:

```{r}
head(cfactuals$predict(), 3L)
```

By design, not all counterfactuals generated with MOC have a prediction equal to the desired prediction. Inspect the documentation for the `Counterfactuals` class to find a method that subsets the counterfactuals, omitting all those that do not achieve the desired prediction.

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("<!--")
```

### Solution:

<details>
<summary>Click me:</summary>

```{r, eval = show.solution}
cfactuals$subset_to_valid()
head(cfactuals$data)
```
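
The `Counterfactuals` object also offers evaluation and visualization helpers; for example, assuming the current `counterfactuals` API:

```{r, eval = show.solution}
# Evaluate quality criteria of the counterfactuals (e.g., distance to x_interest)
head(cfactuals$evaluate(), 3L)
# How often each feature was changed across the counterfactuals
cfactuals$plot_freq_of_feature_changes()
```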

</details>

```{r, eval = !show.solution, echo = FALSE, results='asis'}
cat("-->")
```

# Summary

We learned how to apply local interpretation methods to real data: LIME, ICE, and counterfactual explanations.