ETC3250/5250 Tutorial 6

Trees and forests

Author

Prof. Di Cook

Published

8 April 2024

Load the libraries and avoid conflicts
# Load libraries used everywhere
library(tidyverse)
library(tidymodels)
library(patchwork)
library(mulgar)
library(palmerpenguins)
library(GGally)
library(tourr)
library(MASS)
library(discrim)
library(classifly)
library(detourr)
library(crosstalk)
library(plotly)
library(viridis)
library(colorspace)
library(randomForest)
library(geozoo)
library(ggbeeswarm)
library(conflicted)
conflicts_prefer(dplyr::filter)
conflicts_prefer(dplyr::select)
conflicts_prefer(dplyr::slice)
conflicts_prefer(palmerpenguins::penguins)
conflicts_prefer(viridis::viridis_pal)

# Print numbers with two significant digits throughout the tutorial
options(digits = 2)

# Keep species plus the four size measurements, give the measurements
# short names, and drop penguins with missing values
p_tidy <- penguins |>
  select(species, bill_length_mm:body_mass_g) |>
  rename(bl = bill_length_mm,
         bd = bill_depth_mm,
         fl = flipper_length_mm,
         bm = body_mass_g) |>
  filter(!is.na(bl)) |>
  arrange(species) |>
  na.omit()

# Standardise every numeric column to mean 0 and standard deviation 1
# (across() replaces the superseded mutate_if())
p_tidy_std <- p_tidy |>
  mutate(across(where(is.numeric), \(x) (x - mean(x)) / sd(x)))

🎯 Objectives

The goal for this week is to learn to fit, diagnose, assess assumptions, and predict from classification tree and random forest models.

🔧 Preparation

  • Make sure you have all the necessary libraries installed. There are a few new ones this week!

Exercises:

Open your project for this unit called iml.Rproj. For all the work we will use the penguins data. Start with splitting it into a training and test set, as follows.

# Fix the seed so the training/test partition is reproducible
set.seed(1156)

# Two-species subset (Gentoo removed) on the two variables bl and bm;
# the unused Gentoo level is dropped from the species factor
p_sub <- p_tidy_std |>
  select(species, bl, bm) |>
  filter(species != "Gentoo") |>
  mutate(species = droplevels(species))

# Two thirds for training, one third for testing, stratified by species
p_split <- initial_split(p_sub, prop = 2/3, strata = species)
p_tr <- training(p_split)
p_ts <- testing(p_split)

1. Becoming a car mechanic - looking under the hood at the tree algorithm

  1. Write down the equation for the Gini measure of impurity, for two groups, and the parameter \(p\) which is the proportion of observations in class 1. Specify the domain of the function, and determine the value of \(p\) which gives the maximum value, and report what that maximum function value is.

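\(G = p(1-p)\) where \(p\) is the proportion of class 1 in the subset of data. The domain is \([0, 1]\) and the maximum value of \(0.25\) is at \(p=0.5\). (Summing this term over both classes, as is done in the split formula and in the `mygini` function below, gives \(2p(1-p)\), which has its maximum of \(0.5\) at the same value \(p=0.5\).)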

  1. For two groups, how would the impurity of a split be measured? Give the equation.

\[p_L(p_{L1}(1-p_{L1})+p_{L2}(1-p_{L2})) + p_R(p_{R1}(1-p_{R1})+p_{R2}(1-p_{R2}))\] where \(p_L\) is the proportion of observations to the left of the split, \(p_{L1}\) and \(p_{L2}\) are the proportions of observations of class 1 and class 2 among those to the left of the split, and \(p_R\), \(p_{R1}\) and \(p_{R2}\) are the equivalent quantities for observations to the right of the split.

  1. Below is an R function to compute the Gini impurity for a particular split on a single variable. Work through the code of the function, and document what each step does. Make sure to include a note on what the minsplit parameter does: it prevents splits near the edges that would leave fewer than the specified number of observations on either side.
# Gini impurity for two classes, on one variable.
# p is the proportion of observations in one of the two classes.
mygini <- function(p) {
  # A pure node (or any p outside the open interval (0, 1)) scores zero
  if (p <= 0 || p >= 1) {
    return(0)
  }

  2 * p * (1 - p)
}

# Gini impurity of splitting one variable at a given value (two classes).
#
# Arguments:
#   x        numeric vector of values of the splitting variable
#   spl      candidate split value: x < spl goes left, x >= spl goes right
#   cl       class labels, one per element of x
#   minsplit smallest number of observations allowed on either side
#
# Returns the Gini impurity of the two sides, averaged with weights equal
# to the proportion of observations on each side. Returns NA when either
# side has fewer than minsplit observations, which is what prevents
# splits being made at the extreme edges of the data.
mysplit <- function(x, spl, cl, minsplit=5) {
  stopifnot(length(x) == length(cl))

  # Count number of observations
  n <- length(x)

  # The first class encountered acts as the reference class ("class 1")
  ref_class <- unique(cl)[1]

  # Split the class labels into two subsets on the given value.
  # Membership is decided by comparison, so x need not be sorted.
  cl_left <- cl[x < spl]
  n_l <- length(cl_left)

  cl_right <- cl[x >= spl]
  n_r <- length(cl_right)

  # Don't calculate if either side is smaller than minsplit
  if (n_l < minsplit || n_r < minsplit) {
    return(NA)
  }

  # Proportion of the reference class on each side
  p_l <- sum(cl_left == ref_class) / n_l
  p_r <- sum(cl_right == ref_class) / n_r

  # An empty side (only possible when minsplit < 1) gives 0/0; it gets
  # weight zero below, so any valid proportion will do
  if (is.na(p_l)) p_l <- 0.5
  if (is.na(p_r)) p_r <- 0.5

  # Weighted Gini value for the split
  (n_l / n) * mygini(p_l) + (n_r / n) * mygini(p_r)
}
  1. Apply the function to compute the value for all possible splits for the body mass (bm), setting minsplit to be 1, so that all possible splits will be evaluated. Make a plot of these values vs the variable.
# Training data for body mass, ordered so that consecutive distinct
# values are neighbours
x <- p_tr |>
  select(species, bm) |>
  arrange(bm)

# Candidate splits: midpoints between consecutive distinct bm values
unique_splits <- unique(x$bm)
nsplits <- length(unique_splits) - 1
splits <- (unique_splits[-1] + unique_splits[-length(unique_splits)]) / 2

# Impurity of every candidate split (minsplit = 1, so none are skipped)
imp <- vapply(
  splits,
  function(s) mysplit(x$bm, s, x$species, minsplit = 1),
  numeric(1)
)
d_impurity <- tibble(splits, imp)

# The best split is the one with the smallest impurity
d_impurity_bm <- d_impurity[which.min(d_impurity$imp), ]

# Impurity against split value, with a rug showing where each species lies
ggplot() + geom_line(data=d_impurity, aes(x=splits, y=imp)) +
  geom_rug(data=x, aes(x=bm, colour=species), alpha=0.3) +
  ylab("Gini impurity") +
  xlab("bm") +
  scale_color_brewer("", palette="Dark2")

  1. Use your function to compute the first two steps of a classification tree model for separating Adelie from Chinstrap penguins, after setting minsplit to be 5. Make a scatterplot of the two variables that would be used in the splits, with points coloured by species, and the splits as line segments.
# bl: this is the only variable needed for the first split
# because it separates the species much better than the others
x <- p_tr |>
  select(species, bl) |>
  arrange(bl)

# Candidate splits: midpoints between consecutive distinct bl values
unique_splits <- unique(x$bl)
nsplits <- length(unique_splits) - 1
splits <- (unique_splits[-1] + unique_splits[-length(unique_splits)]) / 2

# Impurity of each candidate, with minsplit = 5 as the exercise asks:
# splits leaving fewer than 5 observations on a side are NA
imp <- vapply(
  splits,
  function(s) mysplit(x$bl, s, x$species, minsplit = 5),
  numeric(1)
)
d_impurity <- tibble(splits, imp)

# The best split is the one with the smallest impurity
d_impurity_bl <- d_impurity[which.min(d_impurity$imp), ]

ggplot() +
  geom_line(data=d_impurity, aes(x=splits, y=imp)) +
  geom_rug(data=x, aes(x=bl, colour=species), alpha=0.3) +
  ylab("Gini impurity") +
  xlab("bl") +
  scale_color_brewer("", palette="Dark2")

# Divide the training data at the best bl split, using the same rule as
# mysplit(): strictly below goes left, everything else goes right
p_tr_L <- p_tr |>
  filter(bl < d_impurity_bl$splits)

p_tr_R <- p_tr |>
  filter(bl >= d_impurity_bl$splits)

# Find the split of a single variable with the lowest Gini impurity.
#
# Arguments:
#   x        numeric vector of values of the candidate variable (any order)
#   cl       class labels, one per element of x
#   minsplit passed on to mysplit(): minimum observations on each side
#
# Returns a one-row tibble with columns splits (the best split value) and
# imp (its impurity), or zero rows if no candidate satisfies minsplit.
best_split <- function(x, cl, minsplit=5) {
  # Candidate splits are midpoints between consecutive SORTED distinct
  # values. Without the sort the midpoints depend on the row order of x
  # and the true best split can be missed.
  unique_splits <- sort(unique(x))
  splits <- (unique_splits[-1] + unique_splits[-length(unique_splits)]) / 2

  # Impurity of every candidate (NA where minsplit is violated)
  imp <- vapply(
    splits,
    function(s) mysplit(x, s, cl, minsplit),
    numeric(1)
  )

  d_impurity <- tibble(splits, imp)
  d_impurity[which.min(d_impurity$imp), ]
}

# First split: best bl threshold over all the training data.
# Second split: best bm threshold within the right-hand bl subset.
s1 <- best_split(p_tr$bl, p_tr$species, minsplit = 5)
s2 <- best_split(p_tr_R$bm, p_tr_R$species, minsplit = 5)

# Scatterplot with the first split as a full vertical line and the
# second as a horizontal segment spanning only the right-hand region
ggplot(data = p_tr, mapping = aes(x = bl, y = bm, colour = species)) +
  geom_point() +
  geom_vline(xintercept = s1$splits) +
  annotate(
    geom = "segment",
    x = s1$splits, xend = max(p_tr$bl),
    y = s2$splits, yend = s2$splits
  ) +
  scale_colour_brewer(name = "", palette = "Dark2") +
  theme(aspect.ratio = 1)

2. Digging deeper into diagnosing an error

  1. Fit the random forest model to the full penguins data.
# Reproducible 2/3 training, 1/3 test split of the full (three species)
# standardised data, stratified by species
set.seed(923)
p_split2 <- initial_split(p_tidy_std, prop = 2/3, strata = species)
p_tr2 <- training(p_split2)
p_ts2 <- testing(p_split2)

# Random forest classifier: 1000 trees, 2 variables tried at each split,
# fitted with the randomForest engine
rf_spec <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
p_fit_rf <- fit(rf_spec, species ~ ., data = p_tr2)
  1. Report the confusion matrix.
# Printing the fit shows the forest summary, which includes the OOB
# error estimate and the confusion matrix
p_fit_rf
parsnip model object


Call:
 randomForest(x = maybe_data_frame(x), y = y, ntree = ~1000, mtry = min_cols(~2,      x)) 
               Type of random forest: classification
                     Number of trees: 1000
No. of variables tried at each split: 2

        OOB estimate of  error rate: 2.6%
Confusion matrix:
          Adelie Chinstrap Gentoo class.error
Adelie        97         2      1       0.030
Chinstrap      2        43      0       0.044
Gentoo         0         1     81       0.012
  1. Use linked brushing to learn which was the Gentoo penguin that the model was confused about. When we looked at the data in a tour, there was one Gentoo penguin that was an outlier, appearing to be away from the other Gentoos and closer to the Chinstrap group. We would expect this to be the penguin that the forest model is confused about. Is it?

Have a look at the other misclassifications, to understand whether they are ones we’d expect to misclassify, or whether the model is not well constructed.

# Attach the forest's predictions to the training data, and add jittered
# numeric versions of the true and predicted species for plotting.
# NOTE(review): for the randomForest engine, $predicted is presumably the
# out-of-bag prediction for each training observation - confirm in
# ?randomForest
p_cl <- p_tr2 |>
  mutate(pspecies = p_fit_rf$fit$predicted) |>
  dplyr::select(bl:bm, species, pspecies) |>
  mutate(sp_jit = jitter(as.numeric(species)),
         psp_jit = jitter(as.numeric(pspecies)))
# Wrap the data in a SharedData object; the same object is given to both
# widgets below so that they can be linked
p_cl_shared <- SharedData$new(p_cl)

# Grand tour (2D projections) of the variables bl to bm, with points
# coloured by the true species
detour_plot <- detour(p_cl_shared, tour_aes(
  projection = bl:bm,
  colour = species)) |>
  tour_path(grand_tour(2),
            max_bases=50, fps = 60) |>
  show_scatter(alpha = 0.9, axes = FALSE,
               width = "100%", height = "450px")

# Jittered scatterplot of true species (y) against predicted species (x):
# a graphical confusion matrix. Selecting points highlights them, and a
# double click clears the selection.
conf_mat <- plot_ly(p_cl_shared,
                    x = ~psp_jit,
                    y = ~sp_jit,
                    color = ~species,
                    colors = viridis_pal(option = "D")(3),
                    height = 450) |>
  highlight(on = "plotly_selected",
            off = "plotly_doubleclick") |>
  add_trace(type = "scatter",
            mode = "markers")

# Place the tour and the confusion matrix plot side by side
bscols(
  detour_plot, conf_mat,
  widths = c(5, 6)
)

3. Deciding on variables in a large data problem

  1. Fit a random forest to the bushfire data. You can read more about the bushfire data at https://dicook.github.io/mulgar_book/A2-data.html. Examine the votes matrix using a tour. What do you learn about the confusion between fire causes?

This code might help:

data(bushfires)

# Keep the response and the predictors (selected by column position),
# with cause stored as a factor
bushfires_sub <- bushfires[, c(5, 8:45, 48:55, 57:60)] |>
  mutate(cause = factor(cause))

# Reproducible 3/4 training, 1/4 test split, stratified by cause
set.seed(1239)
bf_split <- initial_split(bushfires_sub, prop = 3/4, strata = cause)
bf_tr <- training(bf_split)
bf_ts <- testing(bf_split)

# Random forest via ranger: 1000 trees, 5 variables tried per split,
# class probabilities and permutation importance requested
rf_spec <- rand_forest(mtry = 5, trees = 1000) |>
  set_mode("classification") |>
  set_engine("ranger", probability = TRUE, importance = "permutation")
bf_fit_rf <- fit(rf_spec, cause ~ ., data = bf_tr)

# Create votes matrix data: the forest's predictions for the training
# observations, alongside the true cause
bf_rf_votes <- as_tibble(bf_fit_rf$fit$predictions) |>
  mutate(cause = bf_tr$cause)

# Project 4D into 3D using the Helmert matrix with its first row removed
proj <- t(geozoo::f_helmert(4)[-1, ])
bf_rf_v_p <- as.matrix(bf_rf_votes[, 1:4]) %*% proj
colnames(bf_rf_v_p) <- c("x1", "x2", "x3")
bf_rf_v_p <- as.data.frame(bf_rf_v_p) |>
  mutate(cause = bf_tr$cause)

# Add simplex: its vertices are placed in front of the data, with an
# empty cause so they form their own colour group
simp <- simplex(p = 3)
sp <- data.frame(simp$points)
colnames(sp) <- c("x1", "x2", "x3")
sp$cause <- ""
bf_rf_v_p_s <- bind_rows(sp, bf_rf_v_p) |>
  mutate(cause = factor(cause))

# Label only the four simplex vertices; data points get empty labels
labels <- c(
  "accident", "arson", "burning_off", "lightning",
  rep("", nrow(bf_rf_v_p))
)

# Examine votes matrix with bounding simplex
animate_xy(
  bf_rf_v_p_s[, 1:3],
  col = bf_rf_v_p_s$cause,
  axes = "off",
  half_range = 1.3,
  edges = as.matrix(simp$edges),
  obs_labels = labels
)

The pattern is that points are bunched at the vertex corresponding to lightning, extending along the edge leading to accident. We could also say that the points do extend on the face corresponding to lightning, accident and arson, too. The primary confusion for each of the other classes is with lightning. Few points are predicted to be burning_off because this is typically only occurring outside of fire season.

Part of the reason that the forest predicts predominantly to the lightning class is because it is a highly imbalanced problem. One approach is to change the weights for each class, to give the lightning class a lower priority. This will change the model predictions to be more often the other three classes.

  1. Check the variable importance. Plot the most important variables.

This code might help:

# Permutation importance of each predictor, largest first. Variable
# names are taken from the importance vector itself rather than from
# hard-coded column positions in bf_tr.
bf_fit_rf$fit$variable.importance |>
  enframe(name = "var", value = "imp") |>
  arrange(desc(imp)) |>
  print(n = 50)

# Beeswarm plot of one predictor by fire cause, with an orange bar
# marking the median of each cause
plot_by_cause <- function(data, var) {
  ggplot(data, aes(x = cause, y = {{ var }})) +
    geom_quasirandom(alpha = 0.5) +
    stat_summary(aes(group = cause),
                 fun = median,
                 fun.min = median,
                 fun.max = median,
                 geom = "crossbar",
                 color = "orange",
                 width = 0.7,
                 lwd = 0.5) +
    xlab("") +
    coord_flip()
}

p1 <- plot_by_cause(bf_tr, log_dist_road)
p2 <- plot_by_cause(bf_tr, arf360)
p3 <- plot_by_cause(bf_tr, log_dist_cfa)
p1 + p2 + p3 + plot_layout(ncol = 3)

Each of these variables has some difference in median value between the classes, but none shows any separation between them. If the three most important variables show little separation, it indicates the difficulty in distinguishing between these classes. However, it looks like if the distance from a road, or CFA station is bigger, the chance of the cause being a lightning start is higher. This makes sense, because these would be locations further from human activity, and thus the fire is less likely to be started by people. The arf360 relates to rain from a year ago. It also appears that if the rainfall was higher a year ago, lightning is more likely the cause. This also makes some sense, because with more rain in the previous year, there should be more vegetation. Particularly, if recent months have been dry, then there is likely a lot of dry vegetation which is combustible. Ideally we would create a new variable (feature engineering) that looks at difference in rainfall from the previous year to just before the current year’s fire season, to model these types of conditions.

4. Can boosting better detect bushfire cause?

Fit a boosted tree model using xgboost to the bushfires data. You can use the code below. Compute the confusion tables and the balanced accuracy for the test data for both the forest model and the boosted tree model, to make the comparison.

# Boosted tree classifier with the xgboost engine and default tuning
# parameters; the seed is fixed before fitting
set.seed(121)
bf_spec2 <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")
bf_fit_bt <- fit(bf_spec2, cause ~ ., data = bf_tr)

The results for the random forest are:

# Attach the forest's predicted class to each test observation
bf_ts_rf_pred <- bf_ts |>
  mutate(pcause = predict(bf_fit_rf, new_data = bf_ts)$.pred_class)

# Balanced accuracy on the test set
bal_accuracy(bf_ts_rf_pred, truth = cause, estimate = pcause)
# A tibble: 1 × 3
  .metric      .estimator .estimate
  <chr>        <chr>          <dbl>
1 bal_accuracy macro          0.638
# Confusion table for the forest: one row per true cause, one column per
# predicted cause, plus the per-class accuracy
bf_ts_rf_pred |>
  count(cause, pcause) |>
  group_by(cause) |>
  # sum() makes the numerator 0 when a class has no correct predictions;
  # without it the subset has length 0 and mutate() fails
  mutate(Accuracy = sum(n[cause == pcause]) / sum(n)) |>
  pivot_wider(names_from = "pcause",
              values_from = n, values_fill = 0) |>
  select(cause, accident, arson, burning_off, lightning, Accuracy)
# A tibble: 4 × 6
# Groups:   cause [4]
  cause       accident arson burning_off lightning Accuracy
  <fct>          <int> <int>       <int>     <int>    <dbl>
1 accident          14     0           0        19   0.424 
2 arson              2     1           0        10   0.0769
3 burning_off        0     0           1         3   0.25  
4 lightning          0     0           0       206   1     

and for the boosted tree are:

# Attach the boosted tree's predicted class to each test observation
bf_ts_bt_pred <- bf_ts |>
  mutate(pcause = predict(bf_fit_bt, new_data = bf_ts)$.pred_class)

# Balanced accuracy on the test set
bal_accuracy(bf_ts_bt_pred, truth = cause, estimate = pcause)
# A tibble: 1 × 3
  .metric      .estimator .estimate
  <chr>        <chr>          <dbl>
1 bal_accuracy macro          0.765
# Confusion table for the boosted tree: one row per true cause, one
# column per predicted cause, plus the per-class accuracy
bf_ts_bt_pred |>
  count(cause, pcause) |>
  group_by(cause) |>
  # sum() makes the numerator 0 when a class has no correct predictions;
  # without it the subset has length 0 and mutate() fails
  mutate(Accuracy = sum(n[cause == pcause]) / sum(n)) |>
  pivot_wider(names_from = "pcause",
              values_from = n, values_fill = 0) |>
  select(cause, accident, arson, burning_off, lightning, Accuracy)
# A tibble: 4 × 6
# Groups:   cause [4]
  cause       accident arson burning_off lightning Accuracy
  <fct>          <int> <int>       <int>     <int>    <dbl>
1 accident          19     1           0        13    0.576
2 arson              4     6           0         3    0.462
3 burning_off        0     0           2         2    0.5  
4 lightning          3     1           0       202    0.981

The boosted tree does improve the balanced accuracy.

👋 Finishing up

Make sure you say thanks and good-bye to your tutor. This is a time to also report what you enjoyed and what you found difficult.