ETC3250/5250 Tutorial 8

Explainability (XAI)

Author

Prof. Di Cook

Published

22 April 2024

Load the libraries, avoid conflicts, and prepare the data
# Load libraries used everywhere
library(tidyverse)
library(tidymodels)
library(patchwork)
library(mulgar)
library(GGally)
library(tourr)
library(plotly)
library(randomForest)
library(colorspace)
library(ggthemes)
library(conflicted)
library(DALEXtra)
# devtools::install_github("dandls/counterfactuals")
# You need the GitHub version
library(counterfactuals)
library(kernelshap)
library(shapviz)
library(lime)
library(palmerpenguins)
conflicts_prefer(dplyr::filter)
conflicts_prefer(dplyr::select)
conflicts_prefer(dplyr::slice)
conflicts_prefer(palmerpenguins::penguins)

p_tidy <- penguins |>
  select(species, bill_length_mm:body_mass_g) |>
  rename(bl=bill_length_mm,
         bd=bill_depth_mm,
         fl=flipper_length_mm,
         bm=body_mass_g) |>
  na.omit()

# Standardise the numeric variables, then add an `id` variable
# so we know which case is which when investigating the models
p_std <- p_tidy |>
  mutate_if(is.numeric, function(x) (x-mean(x))/sd(x)) |>
  mutate(id = 1:nrow(p_tidy)) 

# Only use Adelie and Chinstrap, because explainers are easier to calculate with only two groups
p_sub <- p_std |>
  filter(species != "Gentoo") |>
  mutate(species = factor(species)) # Drop the unused Gentoo level

# Split into training and test sets
set.seed(821)
p_split <- p_sub |> 
  select(species:id) |>
  initial_split(prop = 2/3, 
                strata=species)
p_train <- training(p_split)
p_test <- testing(p_split)
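Setting strata = species in initial_split() samples within each species, so the training and test sets keep roughly the same Adelie/Chinstrap proportions. The base-R sketch below illustrates that idea. stratified_split() is a hypothetical helper and the labels are toy data. It will not reproduce rsample's exact rows, because rsample uses its own sampling and rounding rules.

```r
# Hypothetical helper: stratified random split that samples the same
# proportion of cases from each class (a sketch of what strata = does)
stratified_split <- function(y, prop = 2/3, seed = 821) {
  set.seed(seed)
  idx_by_class <- split(seq_along(y), y)
  train_idx <- lapply(idx_by_class, function(i) {
    n_train <- round(length(i) * prop)
    i[sample.int(length(i), n_train)]
  })
  sort(unlist(train_idx, use.names = FALSE))
}

# Toy labels with an imbalance similar to Adelie vs Chinstrap
y <- factor(rep(c("Adelie", "Chinstrap"), times = c(151, 68)))
train_idx <- stratified_split(y)
test_idx <- setdiff(seq_along(y), train_idx)

# Class proportions are almost identical in the two sets
prop.table(table(y[train_idx]))
prop.table(table(y[test_idx]))
```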

🎯 Objectives

The goal for this week is to learn how to diagnose a model, and to understand variable importance and local explainers.

🔧 Preparation

  • Make sure you have all the necessary libraries installed. There are a few new ones this week!

Exercises:

Open your project for this unit called iml.Rproj.

CHALLENGE QUESTION: In the penguins data, find an observation where you think various models might differ in their prediction. Try to base your choice on the structure of the various models, not on that observation being in an overlap area between class clusters. (Code like that below will help to identify observations by their row number.)

ggplot(p_std, aes(x=bl, y=bd, 
                  colour=species, 
                  label=id)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  theme(legend.position="none", 
        axis.text = element_blank())
ggplotly()
ggscatmat(p_std, 
          columns=2:5, 
          color="species", 
          alpha=0.5) +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  theme(legend.position="none", 
        axis.text = element_blank())

A scatterplot matrix is useful for getting an overall look at the data. Then plot two variables at a time: with plotly you can mouse over the points to show their row number (id), which helps to find particular observations.

These are the ones that I selected to investigate below:

  • 19, 28, 37, 111, 122, 129 Adelie
  • 185, 189, 250, 253 Gentoo
  • 281, 292, 295, 305 Chinstrap

Most of these were chosen because they are outliers in their group. In these regions, model fits could plausibly place boundaries differently: orthogonal to the axes for a tree or forest, but oblique for LDA, logistic regression and neural networks.

Did you find any others?
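The difference between axis-orthogonal and oblique boundaries is easy to see on simulated data where the true boundary is a diagonal line. The sketch below uses simulated data, not the penguins. It compares the best single axis-aligned split, the building block of a tree or forest, with a logistic regression fitted by glm(). best_stump() is a hypothetical helper written for this illustration.

```r
set.seed(1)
n <- 200
x1 <- runif(n, -1, 1)
x2 <- runif(n, -1, 1)
# The true class boundary is the diagonal x2 = x1, plus some noise
y <- factor(ifelse(x2 - x1 + rnorm(n, sd = 0.3) > 0, "A", "B"))
dat <- data.frame(x1, x2, y)

# Hypothetical helper: best single split "xj <= cut_val", predicting the
# majority class on each side; its boundary is orthogonal to the xj axis
best_stump <- function(X, y) {
  best <- list(err = Inf)
  for (j in names(X)) {
    for (cut_val in sort(unique(X[[j]]))) {
      left <- X[[j]] <= cut_val
      pred <- ifelse(left,
                     names(which.max(table(y[left]))),
                     names(which.max(table(y[!left]))))
      err <- mean(pred != y)
      if (err < best$err) best <- list(var = j, cut = cut_val, err = err)
    }
  }
  best
}
stump <- best_stump(dat[, c("x1", "x2")], dat$y)

# Logistic regression: the boundary b0 + b1*x1 + b2*x2 = 0 is an oblique line
logit <- glm(y ~ x1 + x2, family = binomial, data = dat)
b <- coef(logit)
stump                  # one variable and one cut: a vertical or horizontal line
-b["x1"] / b["x2"]     # slope of the logistic boundary, which should be near 1
```

A single axis-aligned cut leaves many errors near the diagonal. A forest can approximate the diagonal only with a staircase of many such cuts, which is where "boxy" boundaries come from.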

1. Create and build - construct the (non-linear) model

Fit a random forest model to a subset of the penguins data containing only Adelie and Chinstrap. Report the summaries, and which variable(s) are globally important.

set.seed(857)
p_rf <- randomForest(species ~ ., 
                     data = p_train[,-6])
p_rf

Call:
 randomForest(formula = species ~ ., data = p_train[, -6]) 
               Type of random forest: classification
                     Number of trees: 500
No. of variables tried at each split: 2

        OOB estimate of  error rate: 4.83%
Confusion matrix:
          Adelie Chinstrap class.error
Adelie        96         4  0.04000000
Chinstrap      3        42  0.06666667
p_rf$importance
   MeanDecreaseGini
bl        47.172447
bd         4.336294
fl         6.556190
bm         3.545455

Globally, bl is much more important than any of the other variables; fl is a distant second.
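According to the randomForest documentation, MeanDecreaseGini is the total decrease in Gini node impurity from splits on a variable, averaged over all trees. The sketch below computes that quantity for a single split. gini() and gini_decrease() are hypothetical helpers. Weighting each node's impurity by its number of cases is our assumption about the scale randomForest uses.

```r
# Gini impurity of a set of class labels: 1 - sum_k p_k^2
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Hypothetical helper: decrease in case-weighted Gini impurity when the
# cases are split into x <= cut_val and x > cut_val
gini_decrease <- function(x, y, cut_val) {
  left <- x <= cut_val
  length(y) * gini(y) -
    sum(left) * gini(y[left]) -
    sum(!left) * gini(y[!left])
}

# Toy example: a split that separates the two classes perfectly
x <- c(1, 2, 3, 4, 5, 6)
y <- c("A", "A", "A", "B", "B", "B")
gini_decrease(x, y, cut_val = 3)   # 6 * 0.5 - 3 * 0 - 3 * 0 = 3

# Case-weighted impurity at the root for the training set (100 Adelie, 45 Chinstrap)
y_train <- rep(c("Adelie", "Chinstrap"), times = c(100, 45))
length(y_train) * gini(y_train)    # about 62.07

# Sum of the MeanDecreaseGini values printed above
sum(c(bl = 47.172447, bd = 4.336294, fl = 6.556190, bm = 3.545455))  # about 61.61
```

If that weighting assumption holds, the impurity decreases in a tree grown to pure leaves add up to the root value. The four importances then share out roughly 62 units of impurity, and bl accounts for most of it. The match is only approximate because each tree is grown on a bootstrap sample.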

p_rf_tr_pred <- p_train |>
  mutate(pspecies = p_rf$predicted)

p_rf_ts_pred <- p_test |>
  mutate(pspecies = predict(p_rf, 
                            p_test, 
                            type="response")) 
accuracy(p_rf_ts_pred, species, pspecies)
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.973
p_rf_ts_pred |>
  count(species, pspecies) |>
  group_by(species) |>
  mutate(Accuracy = n[species==pspecies]/sum(n)) |>
  pivot_wider(names_from = "pspecies", 
              values_from = n, 
              values_fill = 0) |>
  select(species, Adelie, Chinstrap, Accuracy)
# A tibble: 2 × 4
# Groups:   species [2]
  species   Adelie Chinstrap Accuracy
  <fct>      <int>     <int>    <dbl>
1 Adelie        49         2    0.961
2 Chinstrap      0        23    1    
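These summaries can be checked by hand from the two confusion matrices. Overall accuracy is the proportion of cases on the diagonal, and per-class accuracy divides each diagonal count by its row total. The OOB error rate reported by randomForest is one minus the out-of-bag accuracy on the training set. A base-R sketch using the counts printed above:

```r
lev <- c("Adelie", "Chinstrap")

# Test-set confusion matrix copied from the output above
# (rows = true species, columns = predicted species)
test_cm <- matrix(c(49,  2,
                     0, 23),
                  nrow = 2, byrow = TRUE,
                  dimnames = list(truth = lev, predicted = lev))
overall_acc <- sum(diag(test_cm)) / sum(test_cm)   # 72 / 74
class_acc   <- diag(test_cm) / rowSums(test_cm)    # 49 / 51 and 23 / 23

# Out-of-bag confusion matrix for the training set, from the p_rf output
oob_cm <- matrix(c(96,  4,
                    3, 42),
                 nrow = 2, byrow = TRUE,
                 dimnames = list(truth = lev, predicted = lev))
oob_error <- 1 - sum(diag(oob_cm)) / sum(oob_cm)   # 7 / 145

round(overall_acc, 3)      # 0.973
round(class_acc, 3)        # 0.961 for Adelie, 1 for Chinstrap
round(100 * oob_error, 2)  # 4.83
```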

2. How does your model affect individuals?

Compute LIME, counterfactuals and SHAP for these cases: 19, 28, 37, 111, 122, 129, 281, 292, 295, 305. Report these values. (You can use this code to compute these.)

# Filter the selected observations
p_new <- p_sub |>
  filter(id %in% c(19, 28, 37, 111, 122, 129, 281, 292, 295, 305))
# Compute LIME and re-organise
p_rf_exp <- DALEX::explain(model = p_rf,  
                        data = p_new[, 2:5],
                        y = p_new$species == "Adelie",
                        verbose = FALSE)
model_type.dalex_explainer <-
  DALEXtra::model_type.dalex_explainer
predict_model.dalex_explainer <-
  DALEXtra::predict_model.dalex_explainer
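p_rf_lime <- predict_surrogate(
  explainer = p_rf_exp, 
  new_observation = p_new[, 2:5], 
  n_features = 4, 
  n_permutations = 100,
  type = "lime")
# Re-format the output: one row per case, one column per variable.
# LIME returns 4 rows per case (one per feature), 40 rows in total.
# The indexing below assumes they appear in the order bl, bd, fl, bm,
# so rows 1, 5, 9, ... hold bl, rows 2, 6, 10, ... hold bd, and so on
p_rf_lime_coef <- p_rf_lime |>
  select(case, model_intercept, feature_weight) |>
  as_tibble() |>
  rename(bl = feature_weight) |>
  mutate(bd = 0, 
         fl = 0,
         bm = 0)
p_rf_lime_coef$bd[seq(1, 39, 4)] <-
  p_rf_lime_coef$bl[seq(2, 40, 4)]
p_rf_lime_coef$fl[seq(1, 39, 4)] <-
  p_rf_lime_coef$bl[seq(3, 40, 4)]
p_rf_lime_coef$bm[seq(1, 39, 4)] <-
  p_rf_lime_coef$bl[seq(4, 40, 4)]
# Keep the first row of each case, which now holds all four weights
p_rf_lime_coef <- p_rf_lime_coef[seq(1, 39, 4),]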
# Compute counterfactuals
predictor_rf <- iml::Predictor$new(p_rf, 
                                   type = "prob")
p_classif <- counterfactuals::NICEClassif$new(
  predictor_rf)

# Desired class for each case: the opposite of its true species
p_new_cf <- p_new
p_new_cf$species <- ifelse(p_new$species == "Adelie", 
                           "Chinstrap", "Adelie")
# The desired class must not match the model's current prediction.
# Case 6 (id 129) is an Adelie that the model predicts as Chinstrap,
# so its desired class is Adelie
p_new_cf$species[6] <- as.character(p_new$species[6])
for (i in 1:nrow(p_new)) {
  p_cf <- p_classif$find_counterfactuals(
    x_interest = p_new[i, 2:5], 
    desired_class = p_new_cf$species[i],
    desired_prob = c(0.5, 1)
  )
  # Replace the original values with the counterfactual values
  p_new_cf[i, 2] <- p_cf$data$bl
  p_new_cf[i, 3] <- p_cf$data$bd
  p_new_cf[i, 4] <- p_cf$data$fl
  p_new_cf[i, 5] <- p_cf$data$bm
}
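# Compute SHAP values. The ten selected cases are used both as the
# cases to explain and as the background data
p_explain <- kernelshap(
    p_rf,
    p_new[, 2:5], 
    p_new[, 2:5],
    verbose = FALSE
  )
# Keep the SHAP values for the Adelie prediction
p_shap <- p_new |>
  mutate(shapAbl = p_explain$S$Adelie[, 1],
         shapAbd = p_explain$S$Adelie[, 2],
         shapAfl = p_explain$S$Adelie[, 3],
         shapAbm = p_explain$S$Adelie[, 4])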
# Check misclassified cases (for the training set, these are out-of-bag predictions)
p_rf_tr_pred |> 
  filter(species != pspecies)
# A tibble: 7 × 7
  species        bl      bd      fl       bm    id pspecies 
  <fct>       <dbl>   <dbl>   <dbl>    <dbl> <int> <fct>    
1 Adelie     0.381   2.20   -0.492  -0.00219    19 Chinstrap
2 Adelie    -0.132   0.683  -0.634  -0.127      99 Chinstrap
3 Adelie     0.307   1.59   -0.705   0.497     111 Chinstrap
4 Adelie    -0.151   1.04   -0.278  -0.875     131 Chinstrap
5 Chinstrap -0.279   0.0754 -1.42   -0.750     295 Adelie   
6 Chinstrap -0.554  -0.279  -0.990  -1.25      305 Adelie   
7 Chinstrap -0.0773  0.480   0.0771 -1.00      339 Adelie   
p_rf_ts_pred |> 
  filter(species != pspecies)
# A tibble: 2 × 7
  species     bl    bd     fl      bm    id pspecies 
  <fct>    <dbl> <dbl>  <dbl>   <dbl> <int> <fct>    
1 Adelie  0.344  0.886 -0.278 -0.0645    73 Chinstrap
2 Adelie  0.0326 0.430  0.646 -0.252    129 Chinstrap

Observations 19, 111, 295 and 305 are misclassified in the training set, according to the out-of-bag predictions.

Observation 129 is misclassified in the test set.
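The reshaping of the LIME output in the code above relies on each case's four rows appearing in the order bl, bd, fl, bm. A sketch that instead matches values by feature name is below. It assumes, as in the lime package's output, that a feature column holds the variable names. It runs on a small made-up table with the same long layout, not on real LIME results.

```r
# Made-up long-format table shaped like LIME output (NOT real results):
# one row per case and feature, with features deliberately out of order
lime_long <- data.frame(
  case            = rep(c("1", "2"), each = 4),
  model_intercept = rep(c(0.4, 0.7), each = 4),
  feature         = c("bd", "bl", "bm", "fl", "bl", "fl", "bd", "bm"),
  feature_weight  = c(-0.2, 0.5, -0.1, 0.03, -0.6, 0.02, -0.08, 0.01)
)

# One row per case and one column per feature, matched by name, not position
lime_wide <- reshape(lime_long,
                     idvar = c("case", "model_intercept"),
                     timevar = "feature",
                     direction = "wide")
names(lime_wide) <- sub("^feature_weight\\.", "", names(lime_wide))
lime_wide <- lime_wide[, c("case", "model_intercept", "bl", "bd", "fl", "bm")]
rownames(lime_wide) <- NULL
lime_wide
```

With the tidyverse loaded, the same idea could be written as p_rf_lime |> as_tibble() |> select(case, model_intercept, feature, feature_weight) |> pivot_wider(names_from = feature, values_from = feature_weight). That version also assumes the feature column holds the plain variable names.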

p_rf_lime_coef
# A tibble: 10 × 6
   case  model_intercept      bl      bd       fl       bm
   <chr>           <dbl>   <dbl>   <dbl>    <dbl>    <dbl>
 1 1               0.426  0.456  -0.166   0.0359  -0.0568 
 2 2               0.703 -0.627  -0.0843  0.0180   0.0130 
 3 3               0.519 -0.223  -0.119   0.0621   0.0281 
 4 4               0.473  0.336  -0.171   0.00202 -0.206  
 5 5               0.701 -0.671   0.0235 -0.00609  0.0523 
 6 6               0.334  0.288   0.128   0.0712  -0.0506 
 7 7               0.339  0.480   0.0657  0.0248   0.00233
 8 8               0.376  0.511   0.0263 -0.00937 -0.0214 
 9 9               0.445 -0.0888  0.162  -0.0538   0.0316 
10 10              0.564 -0.617   0.0968  0.0826   0.152  
  • For most observations, bl is the most important variable.
  • For observation 4 (111), bd and bm also have some importance.
  • For observations 6 (129) and 9 (295), bd has some importance. For 295, its weight (0.162) is larger in magnitude than that of bl (-0.0888).
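LIME explains one prediction with a simple model that mimics the black-box model only near that observation. The steps are: perturb the observation, ask the black box for predictions, weight the perturbed points by proximity, and fit a weighted linear model. They can be sketched in base R as below. This is a simplified illustration, not the lime package's algorithm, which by default also bins continuous features and selects a subset of them. lime_sketch(), the kernel width and the toy model f are assumptions made for the example.

```r
# Hypothetical helper: a simplified LIME-style local surrogate
lime_sketch <- function(f, x, X_ref, n = 1000, kernel_width = 1, seed = 1) {
  set.seed(seed)
  p <- length(x)
  sds <- apply(X_ref, 2, sd)
  # 1. Perturb: sample points around x, scaled by each variable's spread
  Z <- matrix(rnorm(n * p), n, p) %*% diag(sds, p) +
    matrix(x, n, p, byrow = TRUE)
  colnames(Z) <- names(x)
  # 2. Ask the black-box model for its predictions at those points
  yhat <- f(Z)
  # 3. Weight each point by its proximity to x (Gaussian kernel on scaled distance)
  d2 <- rowSums(sweep(Z, 2, x)^2 / matrix(sds^2, n, p, byrow = TRUE))
  w <- exp(-d2 / kernel_width^2)
  # 4. Fit a weighted linear model; its slopes are the local explanation
  fit <- lm(yhat ~ ., data = data.frame(yhat = yhat, Z), weights = w)
  coef(fit)
}

# Toy black box: the probability depends strongly on x1 and weakly on x2
f <- function(Z) plogis(3 * Z[, "x1"] - 0.5 * Z[, "x2"]^2)
set.seed(42)
X_ref <- matrix(rnorm(200), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
coefs <- lime_sketch(f, x = c(x1 = 0, x2 = 1), X_ref = X_ref)
coefs
```

The slopes depend on the kernel width and the number of perturbations, which is one reason LIME weights can vary between runs and between implementations.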
p_new |> mutate(bl = bl-p_new_cf$bl,
                bd = bd-p_new_cf$bd,
                fl = fl-p_new_cf$fl,
                bm = bm-p_new_cf$bm)
# A tibble: 10 × 6
   species        bl    bd     fl    bm    id
   <fct>       <dbl> <dbl>  <dbl> <dbl> <int>
 1 Adelie     0      1.32   0         0    19
 2 Adelie    -1.50   0      0         0    28
 3 Adelie    -0.0366 0      0         0    37
 4 Adelie     0      0.709 -0.284     0   111
 5 Adelie    -0.403  0      0         0   122
 6 Adelie     0.549  0      0         0   129
 7 Chinstrap  0.916  0      0         0   281
 8 Chinstrap  3.50   0      0         0   292
 9 Chinstrap  0.641  0      0         0   295
10 Chinstrap  0.0733 0      0         0   305

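To read the counterfactuals, look at the difference between the original values and the counterfactual values. The variables that had to change to flip the prediction are the ones that matter for that observation.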

  • For most observations, bl is the only variable that changes, so it is the only important one.
  • For observation 1 (19), only bd changes, so bd is the most important.
  • For observation 4 (111), both bd and fl change, so both are important.
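NICE (nearest instance counterfactual explanations) produced these counterfactuals. Roughly, as we understand it, NICE first finds the nearest training case that the model assigns to the desired class. It then copies that neighbour's values into the observation one feature at a time, each time choosing the change that most increases the desired-class probability, until the prediction flips. This may explain why many counterfactuals above differ from the original in only one variable. The sketch below is a greatly simplified base-R version of that greedy search. nice_sketch() and the toy model prob_fun are hypothetical, and the real implementation offers more options, such as other reward functions and distance measures.

```r
# Hypothetical helper: greedy NICE-style counterfactual search.
# prob_fun(X) returns the desired-class probability for each row of matrix X.
nice_sketch <- function(prob_fun, x, X_train, y_train, desired_class) {
  as_row <- function(v) matrix(v, nrow = 1, dimnames = list(NULL, names(x)))
  # 1. Nearest training case of the desired class that the model also
  #    predicts as the desired class (the "nearest unlike neighbour")
  ok <- y_train == desired_class & prob_fun(X_train) >= 0.5
  cand <- X_train[ok, , drop = FALSE]
  nun <- cand[which.min(rowSums(sweep(cand, 2, x)^2)), ]
  # 2. Copy features from the neighbour one at a time, choosing the swap
  #    that raises the desired-class probability most, until it flips
  cf <- x
  while (prob_fun(as_row(cf)) < 0.5) {
    diffs <- which(cf != nun)
    gains <- sapply(diffs, function(j) {
      trial <- cf
      trial[j] <- nun[j]
      prob_fun(as_row(trial))
    })
    j_best <- diffs[which.max(gains)]
    cf[j_best] <- nun[j_best]
  }
  cf
}

# Toy model: only x1 matters for the probability of class "B"
prob_fun <- function(X) plogis(4 * X[, "x1"])
set.seed(2)
X_train <- matrix(rnorm(300), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
y_train <- ifelse(prob_fun(X_train) >= 0.5, "B", "A")
x <- c(x1 = -1, x2 = 0, x3 = 0)
cf <- nice_sketch(prob_fun, x, X_train, y_train, desired_class = "B")
cf - x   # only x1 changes
```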
p_shap[,6:10]
# A tibble: 10 × 5
      id shapAbl  shapAbd  shapAfl shapAbm
   <int>   <dbl>    <dbl>    <dbl>   <dbl>
 1    19  -0.281  0.419    0.0187   0.344 
 2    28   0.566  0.0156   0.00312 -0.0844
 3    37   0.475  0.0750   0.0250  -0.0750
 4   111  -0.150  0.138    0.0875   0.425 
 5   122   0.559  0.00937  0.00937 -0.0781
 6   129  -0.397 -0.0594  -0.0594   0.0156
 7   281  -0.366 -0.0156  -0.00312 -0.116 
 8   292  -0.416 -0.0156  -0.00312 -0.0656
 9   295  -0.213 -0.0875  -0.0500  -0.15  
10   305   0.222 -0.478   -0.0281  -0.216 
  • SHAP values suggest that bl is the most important variable for most of these observations.
  • For observations 1 (19) and 10 (305), bl, bd and bm are the most important.
  • For observation 4 (111), the same three variables matter, but bm is more important than bl and bd.
  • For observations 7 (281) and 9 (295), bm is important, along with bl.
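SHAP values are Shapley values. Each variable's contribution is its marginal effect on the prediction, averaged over all orders in which the variables could be added, with "absent" variables filled in from background data. With only four variables, all 2^4 = 16 coalitions can be enumerated. The kernelshap documentation indicates that for a small number of features it computes exact Kernel SHAP values, which should agree with an exact calculation like the sketch below. One useful consequence, the efficiency property, is that a case's SHAP values sum to its prediction minus the average prediction over the background data. For example, the Adelie values for id 19 above sum to about 0.50, and those for id 129 to about -0.50. shapley_exact() and the toy model are hypothetical.

```r
# Hypothetical helper: exact interventional Shapley values for one case x,
# using the rows of bg as background data (2^p coalitions, so small p only)
shapley_exact <- function(f, x, bg) {
  p <- length(x)
  # Value of coalition S: average prediction with the features in S fixed
  # at their values in x and the others taken from each background row
  v <- function(S) {
    Z <- bg
    if (length(S) > 0) Z[, S] <- matrix(x[S], nrow(bg), length(S), byrow = TRUE)
    mean(f(Z))
  }
  phi <- numeric(p)
  for (j in seq_len(p)) {
    others <- setdiff(seq_len(p), j)
    for (k in 0:(p - 1)) {
      # All subsets of the other features of size k
      subsets <- if (k == 0) list(integer(0)) else
        lapply(combn(length(others), k, simplify = FALSE), function(i) others[i])
      w <- factorial(k) * factorial(p - k - 1) / factorial(p)
      for (S in subsets) phi[j] <- phi[j] + w * (v(c(S, j)) - v(S))
    }
  }
  setNames(phi, names(x))
}

set.seed(3)
bg <- matrix(rnorm(40), ncol = 4, dimnames = list(NULL, paste0("x", 1:4)))
x <- c(x1 = 1, x2 = -1, x3 = 0.5, x4 = 2)

# Toy non-linear model with an interaction between x1 and x2
f <- function(Z) Z[, "x1"] * Z[, "x2"] + Z[, "x3"]^2 + 0.5 * Z[, "x4"]
phi <- shapley_exact(f, x, bg)
phi
# Efficiency: the values sum to f(x) minus the mean background prediction
sum(phi)
f(matrix(x, nrow = 1, dimnames = list(NULL, names(x)))) - mean(f(bg))
```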

3. Putting the pieces back together

Explain what you learn about the fitted model by studying the local explainers for the selected cases. (You will want to compare the suggested variable importance of the local explainers, for an observation, and then make plots of those variables with the observation of interest marked.)

For some observations, the local variable importance differs from the global variable importance. This suggests that the model has non-linear boundaries, so we have something to investigate locally!

Here’s a summary of what is learned about each observation from the local explainers.

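id    species     predicted   LIME     CF       SHAP
19    Adelie      Chinstrap   -        bd       bd, bm
28    Adelie      Adelie      -        -        -
37    Adelie      Adelie      -        -        -
111   Adelie      Chinstrap   bd, bm   bd, fl   bm
122   Adelie      Adelie      -        -        -
129   Adelie      Chinstrap   bd       -        -
281   Chinstrap   Chinstrap   -        -        bm
292   Chinstrap   Chinstrap   -        -        -
295   Chinstrap   Adelie      bd       fl       bm
305   Chinstrap   Adelie      -        -        bd, bm

The LIME, CF and SHAP columns list variables other than bl flagged by each explainer; a dash means nothing beyond bl stood out.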

Plotting pairs of variables, with the observation of interest marked, can help us understand the model at this point.

  • Focus on the misclassified observations, because they help us understand where the model goes wrong. We choose 19 here, an observation in the training set, which was therefore used to build the model.
  • Plot the important variables, with the observation of interest marked. bl was always important. For observation 19, two explainers suggested bd, and one of them also suggested bm.
p1 <- ggplot(data=p_sub, aes(x=bl, y=bd, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[1,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p2 <- ggplot(data=p_sub, aes(x=bl, y=bm, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[1,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p3 <- ggplot(data=p_sub, aes(x=bd, y=bm, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[1,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p1 + p2 + p3 + plot_layout(ncol=3)

Penguin 19 is slightly unusual. It has a high value of bd, slightly higher than that of any other Adelie (blue) penguin. It also has slightly higher bm than other Adelie penguins with similar bl values. We suspect that the model, because it makes “boxy” boundaries, has carved out this region of bd vs bl as a prediction region for Chinstrap (red). (Note that bd and bm together show nothing interesting.)

p1 <- ggplot(data=p_sub, aes(x=bl, y=bd, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[6,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p2 <- ggplot(data=p_sub, aes(x=bl, y=bm, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[6,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p3 <- ggplot(data=p_sub, aes(x=bd, y=bm, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[6,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p1 + p2 + p3 + plot_layout(ncol=3)

Penguin 129 is in the test set. Only LIME suggested that any variables other than bl were important, and these were bd and bm. When each of these is plotted against bl, this penguin sits in the region where Adelie and Chinstrap overlap. The error is likely because the training set had no other Adelie penguins with similar bl, bd and bm values: all penguins in this region of the training set were Chinstrap.
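One way to check the suggestion that no training Adelie resembles penguin 129 is to list its nearest neighbours in the training set. A small sketch follows. nearest_classes() is a hypothetical helper. Using Euclidean distance on the standardised variables is an assumption; it is not necessarily how the forest "sees" closeness.

```r
# Hypothetical helper: classes and distances of the k training cases
# closest to x (Euclidean distance)
nearest_classes <- function(x, X_train, y_train, k = 5) {
  X_train <- as.matrix(X_train)
  d <- sqrt(rowSums(sweep(X_train, 2, x)^2))
  ord <- order(d)[seq_len(k)]
  data.frame(class = y_train[ord], distance = d[ord])
}

# Toy data: three "A" cases near the origin, two "B" cases far away
X_toy <- rbind(c(0, 0), c(0.1, 0), c(0, 0.1), c(5, 5), c(5.1, 5))
y_toy <- c("A", "A", "A", "B", "B")
res <- nearest_classes(c(0.05, 0.05), X_toy, y_toy, k = 3)
res
```

With the tutorial's objects, this might be called as nearest_classes(unlist(p_new[6, 2:5]), p_train[, 2:5], p_train$species, k = 10), where row 6 of p_new is penguin 129. This call has not been run on the real data here.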

p1 <- ggplot(data=p_sub, aes(x=bl, y=fl, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[9,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p2 <- ggplot(data=p_sub, aes(x=bl, y=bm, 
                  colour=species)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  geom_point(data=p_new[9,], shape=1, size=3) +
  theme(legend.position="none", 
        axis.text = element_blank())
p1 + p2 + plot_layout(ncol=2)

Penguin 295 is a Chinstrap and was in the training set. It is unusual in both fl and bm relative to bl. The error most likely occurs because of the “boxy” boundaries of random forests: most of the penguins with these characteristics are Adelie, hence the misclassification.

👋 Finishing up

Make sure you say thanks and good-bye to your tutor. This is a time to also report what you enjoyed and what you found difficult.