ETC3250/5250 Tutorial 6

Trees and forests

Author

Prof. Di Cook

Published

8 April 2024

Load the libraries and avoid conflicts
# Load libraries used everywhere
library(tidyverse)
library(tidymodels)
library(patchwork)
library(mulgar)
library(palmerpenguins)
library(GGally)
library(tourr)
library(MASS)
library(discrim)
library(classifly)
library(detourr)
library(crosstalk)
library(plotly)
library(viridis)
library(colorspace)
library(randomForest)
library(geozoo)
library(ggbeeswarm)
library(conflicted)
conflicts_prefer(dplyr::filter)
conflicts_prefer(dplyr::select)
conflicts_prefer(dplyr::slice)
conflicts_prefer(palmerpenguins::penguins)
conflicts_prefer(viridis::viridis_pal)

options(digits=2)
p_tidy <- penguins |>
  select(species, bill_length_mm:body_mass_g) |>
  rename(bl=bill_length_mm,
         bd=bill_depth_mm,
         fl=flipper_length_mm,
         bm=body_mass_g) |>
  filter(!is.na(bl)) |>
  arrange(species) |>
  na.omit()
p_tidy_std <- p_tidy |>
  mutate(across(where(is.numeric), function(x) (x-mean(x))/sd(x)))

🎯 Objectives

The goal for this week is to learn how to fit classification tree and random forest models, diagnose them, assess their assumptions, and predict from them.

🔧 Preparation

  • Make sure you have all the necessary libraries installed. There are a few new ones this week!

Exercises:

Open your project for this unit, called iml.Rproj. The first two exercises use the penguins data. Start by splitting it into training and test sets, as follows. This split keeps only the Adelie and Chinstrap penguins and the variables bl and bm.

set.seed(1156)
p_sub <- p_tidy_std |>
  filter(species != "Gentoo") |>
  mutate(species = factor(species)) |>
  select(species, bl, bm)
p_split <- initial_split(p_sub, 2/3, strata = species)
p_tr <- training(p_split)
p_ts <- testing(p_split)
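Because `initial_split()` was called with `strata = species`, the class proportions should be nearly the same in the training and test sets. The following is a quick check, written as a sketch. The object name `p_props` is only illustrative.

```r
library(tidyverse)

# Class counts and proportions within each set (p_props is an illustrative name)
p_props <- bind_rows(train = p_tr, test = p_ts, .id = "set") |>
  count(set, species) |>
  group_by(set) |>
  mutate(prop = n / sum(n)) |>
  ungroup()
p_props
```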

1. Becoming a car mechanic - looking under the hood at the tree algorithm

  1. Write down the equation for the Gini measure of impurity for two groups, in terms of the parameter \(p\), the proportion of observations in class 1. Specify the domain of the function, determine the value of \(p\) that gives the maximum, and report that maximum value.
  1. For two groups, how would the impurity of a split be measured? Give the equation.
  1. Below is an R function to compute the Gini impurity for a particular split on a single variable. Work through the code and document what each step does. Make sure to include a note on what the minsplit parameter does: it prevents splits that would leave fewer than the specified number of observations on either side.
# This works for two classes, and one variable
mygini <- function(p) {
  g <- 0
  if (p>0 && p<1) {
    g <- 2*p*(1-p)
  }

  return(g)
}

mysplit <- function(x, spl, cl, minsplit=5) {
  # Assumes x is sorted
  # Count number of observations
  n <- length(x)
  
  # Identify the class labels; the first one found is treated as class 1
  cl_unique <- unique(cl)
  
  # Split into two subsets on the given value
  left <- x[x<spl]
  cl_left <- cl[x<spl]
  n_l <- length(left)

  right <- x[x>=spl]
  cl_right <- cl[x>=spl]
  n_r <- length(right)
  
  # Don't calculate if either side has fewer than minsplit observations
  if ((n_l < minsplit) || (n_r < minsplit))
    impurity <- NA
  else {
    # Compute the Gini value for the split
    p_l <- length(cl_left[cl_left == cl_unique[1]])/n_l
    p_r <- length(cl_right[cl_right == cl_unique[1]])/n_r
    if (is.na(p_l)) p_l<-0.5
    if (is.na(p_r)) p_r<-0.5
    impurity <- (n_l/n)*mygini(p_l) + (n_r/n)*mygini(p_r)
  }
  return(impurity)
}
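The quantities asked for in parts 1 and 2, which `mygini()` and `mysplit()` compute, can be written as follows. This is for reference when checking your answers. The symbols \(n_L, n_R, p_L, p_R\) are notation chosen here to mirror `n_l`, `n_r`, `p_l`, `p_r` in the code. They are the number of observations and the proportion in class 1 on each side of the split.

```latex
G(p) = 1 - p^2 - (1 - p)^2 = 2p(1 - p), \qquad 0 \le p \le 1
\frac{dG}{dp} = 2 - 4p = 0 \;\Rightarrow\; p = \tfrac{1}{2}, \qquad G\!\left(\tfrac{1}{2}\right) = \tfrac{1}{2}
I_{\text{split}} = \frac{n_L}{n}\, G(p_L) + \frac{n_R}{n}\, G(p_R), \qquad n = n_L + n_R
```

Because \(d^2G/dp^2 = -4 < 0\), the stationary point is a maximum. \(G\) is 0 at both ends of the domain, where the node is pure.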
  1. Apply the function to compute the impurity of every possible split on body mass (bm), setting minsplit to 1 so that all splits are evaluated. Plot these impurity values against the split value.
  1. Use your function to compute the first two steps of a classification tree model separating Adelie from Chinstrap penguins, with minsplit set to 5. Make a scatterplot of the two variables used in the splits, with points coloured by species and the splits drawn as line segments.
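Below is a minimal sketch for parts 4 and 5. It assumes `mysplit()` as defined above and uses each distinct observed value as a candidate split point. The helper `scan_splits()` and the objects `bm_splits` and `first_step` are hypothetical names, not part of the tutorial.

```r
library(tidyverse)

# Hypothetical helper: impurity of every candidate split of x, using each
# distinct observed value as the split point (left: x < spl, right: x >= spl)
scan_splits <- function(x, cl, minsplit) {
  ord <- order(x)
  xs <- x[ord]
  cs <- cl[ord]
  tibble(spl = unique(xs)) |>
    mutate(impurity = map_dbl(spl, function(s)
      as.numeric(mysplit(xs, s, cs, minsplit = minsplit))))
}

# Part 4: all splits on bm. The split at the smallest value leaves the
# left side empty, so it returns NA even with minsplit = 1
bm_splits <- scan_splits(p_tr$bm, p_tr$species, minsplit = 1)
ggplot(bm_splits, aes(x = spl, y = impurity)) +
  geom_line() +
  geom_point() +
  labs(x = "bm (standardised split value)", y = "Gini impurity of split")

# Part 5, first step: the best split on each variable with minsplit = 5;
# the tree splits on whichever variable gives the lower impurity
first_step <- bind_rows(
  bl = scan_splits(p_tr$bl, p_tr$species, minsplit = 5),
  bm = scan_splits(p_tr$bm, p_tr$species, minsplit = 5),
  .id = "var") |>
  group_by(var) |>
  slice_min(impurity, n = 1, with_ties = FALSE) |>
  ungroup()
first_step
```

For the second step, filter `p_tr` to each side of the chosen split and run the same scan on each subset.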

2. Digging deeper into diagnosing an error

  1. Fit the random forest model to the full penguins data.
  1. Report the confusion matrix.
  1. Use linked brushing to identify the Gentoo penguin that the model was confused about. When we looked at the data in a tour, one Gentoo penguin was an outlier, lying away from the other Gentoos and closer to the Chinstrap group. We would expect this to be the penguin that the forest model is confused about. Is it?
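The linked-brushing code that follows uses `p_tr2` and `p_fit_rf`, which are not defined earlier. Below is a minimal sketch for parts 1 and 2, built on several assumptions:

- `p_tr2` is a stratified two-thirds training split of all three species. The names `p_split2` and `p_ts2` are hypothetical.
- The model uses the `randomForest` engine, because the brushing code reads its out-of-bag predictions from `p_fit_rf$fit$predicted`.
- The settings `mtry = 2` and `trees = 1000` are also assumptions.

```r
library(tidymodels)
library(randomForest)

# Assumption: a stratified 2/3 training split of all three species
# (p_split2 and p_ts2 are hypothetical names)
set.seed(1156)
p_split2 <- initial_split(p_tidy_std, 2/3, strata = species)
p_tr2 <- training(p_split2)
p_ts2 <- testing(p_split2)

# Assumed settings: randomForest engine, mtry = 2, 1000 trees
rf_spec <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
p_fit_rf <- rf_spec |>
  fit(species ~ ., data = p_tr2)

# Out-of-bag confusion matrix, with per-class error rates
p_fit_rf$fit$confusion

# Confusion matrix on the test set (rows = predicted, columns = true)
p_ts_pred <- p_ts2 |>
  mutate(pspecies = predict(p_fit_rf, p_ts2)$.pred_class)
p_ts_cm <- conf_mat(p_ts_pred, truth = species, estimate = pspecies)
p_ts_cm
```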

Also look at the other misclassifications, to judge whether they are penguins we would expect to be misclassified or whether they point to a poorly constructed model.

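# Combine the training data with the out-of-bag predictions. Jittered
# numeric versions of the true and predicted species give a
# scatterplot version of the confusion matrix
p_cl <- p_tr2 |>
  mutate(pspecies = p_fit_rf$fit$predicted) |>
  dplyr::select(bl:bm, species, pspecies) |>
  mutate(sp_jit = jitter(as.numeric(species)),
         psp_jit = jitter(as.numeric(pspecies)))
# Shared data object, so the two plots are linked
p_cl_shared <- SharedData$new(p_cl)

# Grand tour of the four variables, coloured by true species
detour_plot <- detour(p_cl_shared, tour_aes(
  projection = bl:bm,
  colour = species)) |>
  tour_path(grand_tour(2),
            max_bases=50, fps = 60) |>
  show_scatter(alpha = 0.9, axes = FALSE,
               width = "100%", height = "450px")

# Jittered confusion matrix: predicted (x) vs true (y) species.
# Selecting points here highlights them in the tour
conf_plot <- plot_ly(p_cl_shared,
                     x = ~psp_jit,
                     y = ~sp_jit,
                     color = ~species,
                     colors = viridis_pal(option = "D")(3),
                     height = 450) |>
  highlight(on = "plotly_selected",
            off = "plotly_doubleclick") |>
  add_trace(type = "scatter",
            mode = "markers")

# Show the two linked views side by side
bscols(
  detour_plot, conf_plot,
  widths = c(5, 6)
)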
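To complement the linked plots, you can also list the misclassified penguins directly. This is a sketch using the `p_cl` data built above. `p_miss` is an illustrative name.

```r
library(tidyverse)

# Penguins whose out-of-bag prediction differs from their true species
p_miss <- p_cl |>
  filter(species != pspecies) |>
  select(species, pspecies, bl:bm)
p_miss
```

Comparing each row's standardised measurements with the typical values for its true and predicted species helps judge whether the error is one we would expect.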

3. Deciding on variables in a large data problem

  1. Fit a random forest to the bushfire data. You can read more about the bushfire data at https://dicook.github.io/mulgar_book/A2-data.html. Examine the votes matrix using a tour. What do you learn about the confusion between fire causes?

This code might help:

data(bushfires)

bushfires_sub <- bushfires[,c(5, 8:45, 48:55, 57:60)] |>
  mutate(cause = factor(cause))

set.seed(1239)
bf_split <- initial_split(bushfires_sub, 3/4, strata=cause)
bf_tr <- training(bf_split)
bf_ts <- testing(bf_split)

rf_spec <- rand_forest(mtry=5, trees=1000) |>
  set_mode("classification") |>
  set_engine("ranger", probability = TRUE, 
             importance="permutation")
bf_fit_rf <- rf_spec |> 
  fit(cause~., data = bf_tr)

# Create votes matrix data
bf_rf_votes <- bf_fit_rf$fit$predictions |>
  as_tibble() |>
  mutate(cause = bf_tr$cause)

# Project 4D into 3D
proj <- t(geozoo::f_helmert(4)[-1,])
bf_rf_v_p <- as.matrix(bf_rf_votes[,1:4]) %*% proj
colnames(bf_rf_v_p) <- c("x1", "x2", "x3")
bf_rf_v_p <- bf_rf_v_p |>
  as.data.frame() |>
  mutate(cause = bf_tr$cause)
  
# Add simplex
simp <- simplex(p=3)
sp <- data.frame(simp$points)
colnames(sp) <- c("x1", "x2", "x3")
sp$cause <- ""
bf_rf_v_p_s <- bind_rows(sp, bf_rf_v_p) |>
  mutate(cause = factor(cause))
# Label only the four simplex vertices, one per cause
labels <- c("accident", "arson",
            "burning_off", "lightning",
            rep("", nrow(bf_rf_v_p)))
# Examine votes matrix with bounding simplex
animate_xy(bf_rf_v_p_s[,1:3], col = bf_rf_v_p_s$cause, 
           axes = "off", half_range = 1.3,
           edges = as.matrix(simp$edges),
           obs_labels = labels)
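A table can back up what the tour shows about confusion between causes. Below is a sketch that assigns each training fire the cause with the largest out-of-bag vote proportion, then cross-tabulates against the true cause. The names `vote_mat`, `bf_rf_oob` and `bf_oob_cm` are illustrative, and the code assumes the vote columns are named by cause.

```r
library(tidymodels)

# Predicted cause = the cause with the largest out-of-bag vote proportion
# (matched by column name rather than column position)
vote_mat <- as.matrix(bf_rf_votes[, 1:4])
bf_rf_oob <- bf_rf_votes |>
  mutate(pcause = factor(colnames(vote_mat)[max.col(vote_mat, ties.method = "first")],
                         levels = levels(cause)))

# Rows = predicted cause, columns = true cause
bf_oob_cm <- conf_mat(bf_rf_oob, truth = cause, estimate = pcause)
bf_oob_cm
```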
  1. Check the variable importance. Plot the most important variables.

This code might help:

# ranger stores the importance as a named vector, so enframe() keeps
# each value with its variable name (no reliance on column positions)
bf_fit_rf$fit$variable.importance |>
  enframe(name = "var", value = "imp") |>
  arrange(desc(imp)) |>
  print(n = 50)
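One way to plot the most important variables is to show their distributions for each cause. Below is a sketch that takes the top four variables, which is an arbitrary choice. `bf_imp` and `top_vars` are illustrative names.

```r
library(tidyverse)
library(ggbeeswarm)

# Permutation importance as a tibble, most important first
bf_imp <- bf_fit_rf$fit$variable.importance |>
  enframe(name = "var", value = "imp") |>
  arrange(desc(imp))

# Distributions of the top four variables (an arbitrary cut-off) by cause
top_vars <- bf_imp$var[1:4]
bf_tr |>
  select(cause, all_of(top_vars)) |>
  pivot_longer(-cause, names_to = "var", values_to = "value") |>
  ggplot(aes(x = cause, y = value, colour = cause)) +
  geom_quasirandom(alpha = 0.5) +
  facet_wrap(~var, scales = "free_y") +
  theme(legend.position = "none",
        axis.text.x = element_text(angle = 45, hjust = 1))
```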

4. Can boosting better detect bushfire cause?

Fit a boosted tree model to the bushfires data using xgboost; you can use the code below. To compare the two models, compute the confusion tables and the balanced accuracy on the test data for both the forest and the boosted tree.

set.seed(121)
bf_spec2 <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")
bf_fit_bt <- bf_spec2 |> 
  fit(cause~., data = bf_tr)
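Below is a sketch of the comparison, using `bf_fit_rf` and `bf_fit_bt` from the code above. The names `bf_ts_rf`, `bf_ts_bt` and `bf_bal` are illustrative.

```r
library(tidymodels)

# Test-set predictions; augment() appends .pred_class and class probabilities
bf_ts_rf <- augment(bf_fit_rf, new_data = bf_ts)
bf_ts_bt <- augment(bf_fit_bt, new_data = bf_ts)

# Confusion tables (rows = predicted, columns = true cause)
conf_mat(bf_ts_rf, truth = cause, estimate = .pred_class)
conf_mat(bf_ts_bt, truth = cause, estimate = .pred_class)

# Balanced accuracy; with more than two classes yardstick macro-averages by default
bf_bal <- bind_rows(
  forest = bal_accuracy(bf_ts_rf, truth = cause, estimate = .pred_class),
  boosted = bal_accuracy(bf_ts_bt, truth = cause, estimate = .pred_class),
  .id = "model")
bf_bal
```

Balanced accuracy averages the per-class recall. This makes it more informative than overall accuracy when some causes are much rarer than others.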

👋 Finishing up

Make sure you say thanks and goodbye to your tutor. This is also a time to report what you enjoyed and what you found difficult.