ETC3250/5250 Tutorial 8

Explainability (XAI)

Author

Prof. Di Cook

Published

22 April 2024

Load the libraries and avoid conflicts, and prepare data
# Load libraries used everywhere
library(tidyverse)
library(tidymodels)
library(patchwork)
library(mulgar)
library(GGally)
library(tourr)
library(plotly)
library(randomForest)
library(colorspace)
library(ggthemes)
library(conflicted)
library(DALEXtra)
# devtools::install_github("dandls/counterfactuals")
# You need the GitHub version
library(counterfactuals)
library(kernelshap)
library(shapviz)
library(lime)
library(palmerpenguins)
conflicts_prefer(dplyr::filter)
conflicts_prefer(dplyr::select)
conflicts_prefer(dplyr::slice)
conflicts_prefer(palmerpenguins::penguins)

# Keep species plus the four size measurements, give the measurements
# short names, and drop any penguin with a missing value.
p_tidy <- penguins |>
  select(species, bill_length_mm:body_mass_g) |>
  rename(
    bl = bill_length_mm,
    bd = bill_depth_mm,
    fl = flipper_length_mm,
    bm = body_mass_g
  ) |>
  na.omit()

# Standardise every numeric column to mean 0, sd 1 (computed over all
# three species, before any subsetting).
# `id` variable added to ensure we know which case
# when investigating the models. It is added in a separate step, after
# standardising, so that it is not itself standardised.
p_std <- p_tidy |>
  mutate(across(where(is.numeric), \(x) (x - mean(x)) / sd(x))) |>
  mutate(id = row_number())

# Restrict to Adelie and Chinstrap: with just two groups the explainers
# are easy to calculate
p_sub <- p_std |>
  filter(species != "Gentoo") |>
  mutate(species = droplevels(species)) # Remove the now-empty Gentoo level

# Split into training and test sets (2/3 training, stratified by species)
set.seed(821)
p_split <- initial_split(
  select(p_sub, species:id),
  prop = 2/3,
  strata = species
)
p_train <- training(p_split)
p_test <- testing(p_split)

🎯 Objectives

The goal for this week is to learn to diagnose a model, and to understand variable importance and local explainers.

🔧 Preparation

  • Make sure you have all the necessary libraries installed. There are a few new ones this week!

Exercises:

Open your project for this unit called iml.Rproj.

CHALLENGE QUESTION: In the penguins data, find an observation where you think various models might differ in their prediction. Try to base your choice on the structure of the various models, not on that observation being in an overlap area between class clusters. (Code like that below will help you to identify observations by their row number.)

# Static scatterplot of standardised bill length against bill depth,
# coloured by species. `id` is mapped to `label` so that it is carried
# along with each point.
p_std |>
  ggplot(aes(x = bl, y = bd, colour = species, label = id)) +
  geom_point() +
  scale_color_discrete_divergingx(palette = "Zissou 1") +
  theme(
    legend.position = "none",
    axis.text = element_blank()
  )
# Convert the plot just drawn into an interactive one
ggplotly(p = ggplot2::last_plot())

1. Create and build - construct the (non-linear) model

Fit a random forest model to a subset of the penguins data containing only Adelie and Chinstrap. Report the summaries, and which variable(s) are globally important.

# Fit a random forest classifying species from the four standardised
# measurements in the training set.
# The seed is set immediately before fitting so the fit can be reproduced.
set.seed(857)
# Column 6 of `p_train` is `id` (a row label, not a predictor), so it is
# dropped before `species ~ .` expands to all remaining columns.
# NOTE(review): `iml::Predictor$new()` further down is given no `data`
# argument, so it presumably recovers the training data from this call --
# confirm before changing the `data` expression.
p_rf <- randomForest(species ~ ., 
                     data = p_train[,-6])

2. How does your model affect individuals?

Compute LIME, counterfactuals and SHAP for these cases: 19, 28, 37, 111, 122, 129, 281, 292, 295, 305. Report these values. (You can use this code to compute these.)

# Pull out the ten cases whose local explanations will be compared
p_new <- filter(
  p_sub,
  id %in% c(19, 28, 37, 111, 122, 129, 281, 292, 295, 305)
)

# Compute LIME and re-organise
# Wrap the forest in a DALEX explainer; columns 2:5 are the four
# predictors and the response is coded as "is this an Adelie?"
p_rf_exp <- DALEX::explain(
  model = p_rf,
  data = p_new[, 2:5],
  y = p_new$species == "Adelie",
  verbose = FALSE
)

# Make the DALEXtra methods for `dalex_explainer` objects available under
# these names in the current environment (presumably so that lime's
# generics can dispatch to them -- confirm against the DALEXtra docs)
model_type.dalex_explainer <- DALEXtra::model_type.dalex_explainer
predict_model.dalex_explainer <- DALEXtra::predict_model.dalex_explainer

# Local surrogate (LIME) fit for each selected case, using all four
# predictors
p_rf_lime <- predict_surrogate(
  explainer = p_rf_exp,
  new_observation = p_new[, 2:5],
  n_features = 4,
  n_permutations = 100,
  type = "lime"
)
# Re-format the output: one row per case, holding the local model's
# intercept and one coefficient column per predictor (bl, bd, fl, bm).
# The LIME output is long (one row per case-feature pair) and its
# `feature` column names the predictor each weight belongs to, so the
# reshape is done by name. This stays correct whatever order the features
# come back in within a case, and for any number of explained cases.
p_rf_lime_coef <- p_rf_lime |>
  select(case, model_intercept, feature, feature_weight) |>
  as_tibble() |>
  pivot_wider(
    id_cols = c(case, model_intercept),
    names_from = feature,
    values_from = feature_weight
  ) |>
  # Fix the column order regardless of the order features first appear in
  select(case, model_intercept, bl, bd, fl, bm)
# Compute counterfactuals
# Wrap the forest in an iml predictor that is asked for class
# probabilities, and build the NICE counterfactual classifier on top of it.
predictor_rf <- iml::Predictor$new(p_rf, type = "prob")
p_classif <- counterfactuals::NICEClassif$new(predictor_rf)

# `p_new_cf` starts as a copy of the selected cases. Its `species` column
# is replaced by the class each counterfactual should reach (the opposite
# of the observed species); columns 2:5 are overwritten in the loop below
# with the counterfactual's predictor values.
p_new_cf <- p_new
p_new_cf$species <- ifelse(p_new$species == "Adelie",
                           "Chinstrap", "Adelie")
# Must not match prediction
# Row 6 keeps its observed species as the target class.
# NOTE(review): presumably the forest misclassifies this case, so the
# "opposite" class is already what is predicted. This index is tied to the
# seed and to the row order of `p_new` -- re-check if either changes.
p_new_cf$species[6] <- as.character(p_new$species[6])

for (i in seq_len(nrow(p_new))) {
  # Search for a point classified as the desired class with
  # probability in [0.5, 1]
  p_cf <- p_classif$find_counterfactuals(
    x_interest = p_new[i, 2:5],
    desired_class = p_new_cf$species[i],
    desired_prob = c(0.5, 1)
  )
  # Store the counterfactual's predictor values for this case
  p_new_cf[i, 2] <- p_cf$data$bl
  p_new_cf[i, 3] <- p_cf$data$bd
  p_new_cf[i, 4] <- p_cf$data$fl
  p_new_cf[i, 5] <- p_cf$data$bm
}
# Compute SHAP values
# The selected cases are both the observations being explained (`X`) and
# the background data (`bg_X`).
# `type = "prob"` is forwarded to predict(): kernelshap needs numeric
# predictions, and a per-class probability matrix is what gives `S` one
# matrix per class (`S$Adelie` is used below). Without it, predict() on a
# classification forest returns class labels.
p_explain <- kernelshap(
  p_rf,
  X = p_new[, 2:5],
  bg_X = p_new[, 2:5],
  type = "prob",
  verbose = FALSE
)
# Attach each case's SHAP contributions towards the Adelie class, one
# column per predictor, selected by predictor name.
p_shap <- p_new |>
  mutate(
    shapAbl = p_explain$S$Adelie[, "bl"],
    shapAbd = p_explain$S$Adelie[, "bd"],
    shapAfl = p_explain$S$Adelie[, "fl"],
    shapAbm = p_explain$S$Adelie[, "bm"]
  )

3. Putting the pieces back together

Explain what you learn about the fitted model by studying the local explainers for the selected cases. (You will want to compare the suggested variable importance of the local explainers, for an observation, and then make plots of those variables with the observation of interest marked.)

👋 Finishing up

Make sure you say thanks and good-bye to your tutor. This is also a time to report what you enjoyed and what you found difficult.