The goal for this week is to learn to fit, diagnose, assess assumptions, and predict from classification tree and random forest models.
🔧 Preparation
Make sure you have all the necessary libraries installed. There are a few new ones this week!
Exercises:
Open your project for this unit called iml.Rproj. For all of the exercises we will use the penguins data. Start by splitting it into a training and test set, as follows.
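If the splitting code is not already in your project, here is a minimal sketch, assuming the palmerpenguins data and the abbreviated variable names (bl, bd, fl, bm) used in this unit; the seed and split proportion are arbitrary choices.

library(tidyverse)
library(tidymodels)
library(palmerpenguins)

# Rename to the abbreviated variable names used in the exercises
p_std <- penguins |>
  na.omit() |>
  rename(bl = bill_length_mm,
         bd = bill_depth_mm,
         fl = flipper_length_mm,
         bm = body_mass_g) |>
  select(species, bl, bd, fl, bm)

# Stratified split so the species proportions match in both sets
set.seed(1148)
p_split <- initial_split(p_std, prop = 2/3, strata = species)
p_tr <- training(p_split)
p_ts <- testing(p_split)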
1. Becoming a car mechanic - looking under the hood at the tree algorithm
Write down the equation for the Gini measure of impurity, for two groups, and the parameter \(p\) which is the proportion of observations in class 1. Specify the domain of the function, and determine the value of \(p\) which gives the maximum value, and report what that maximum function value is.
Solution
\(G = p(1-p)\), where \(p\) is the proportion of class 1 in the subset of data. The domain is \([0, 1]\). Setting the derivative \(1-2p\) to zero gives \(p=0.5\), where the function takes its maximum value of \(0.25\). (The code below uses the scaled form \(2p(1-p)\), which peaks at \(0.5\); the scaling does not change which split is best.)
For two groups, how would the impurity of a split be measured? Give the equation.
Solution
\[p_L(p_{L1}(1-p_{L1})+p_{L2}(1-p_{L2})) + p_R(p_{R1}(1-p_{R1})+p_{R2}(1-p_{R2}))\] where \(p_L\) and \(p_R\) are the proportions of observations to the left and right of the split, \(p_{L1}\) and \(p_{L2}\) are the proportions of class 1 and class 2 among the observations to the left of the split, and \(p_{R1}\) and \(p_{R2}\) are the equivalent quantities for observations to the right of the split.
Below is an R function to compute the Gini impurity for a particular split on a single variable. Work through the code of the function, and document what each step does. Make sure to include a note on the minsplit parameter, which prevents splits near the edges of the variable's range that would leave fewer than the specified number of observations in either subset.
# This works for two classes, and one variable
mygini <- function(p) {
  g <- 0
  if (p > 0 && p < 1) {
    g <- 2 * p * (1 - p)
  }
  return(g)
}

mysplit <- function(x, spl, cl, minsplit = 5) {
  # Assumes x is sorted
  # Count number of observations
  n <- length(x)
  # Check number of classes
  cl_unique <- unique(cl)
  # Split into two subsets on the given value
  left <- x[x < spl]
  cl_left <- cl[x < spl]
  n_l <- length(left)
  right <- x[x >= spl]
  cl_right <- cl[x >= spl]
  n_r <- length(right)
  # Don't calculate if either subset has fewer than minsplit observations
  if ((n_l < minsplit) | (n_r < minsplit))
    impurity <- NA
  else {
    # Compute the Gini value for the split
    p_l <- length(cl_left[cl_left == cl_unique[1]]) / n_l
    p_r <- length(cl_right[cl_right == cl_unique[1]]) / n_r
    if (is.na(p_l)) p_l <- 0.5
    if (is.na(p_r)) p_r <- 0.5
    impurity <- (n_l / n) * mygini(p_l) + (n_r / n) * mygini(p_r)
  }
  return(impurity)
}
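As a quick sanity check, the function can be applied to a small made-up vector (the values here are hypothetical); a split that perfectly separates the two classes should return an impurity of 0.

x_toy <- c(1, 2, 3, 4, 5, 6)               # already sorted
cl_toy <- c("A", "A", "A", "B", "B", "B")
mysplit(x_toy, 3.5, cl_toy, minsplit = 1)  # perfect split: returns 0
mysplit(x_toy, 2.5, cl_toy, minsplit = 1)  # impure right side: returns 0.25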
Apply the function to compute the value for all possible splits for the body mass (bm), setting minsplit to be 1, so that all possible splits will be evaluated. Make a plot of these values vs the variable.
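A sketch of one way to do this, assuming p_tr contains just the Adelie and Chinstrap penguins from the training split (mysplit assumes two classes):

x <- p_tr |> select(species, bm) |> arrange(bm)
unique_bm <- unique(x$bm)
n_bm <- length(unique_bm) - 1
# Candidate splits are midpoints between consecutive unique values
splits_bm <- (unique_bm[1:n_bm] + unique_bm[2:(n_bm + 1)]) / 2
imp_bm <- sapply(splits_bm, function(s)
  mysplit(x$bm, s, x$species, minsplit = 1))
ggplot(tibble(splits = splits_bm, imp = imp_bm),
       aes(x = splits, y = imp)) +
  geom_line() +
  xlab("bm") + ylab("Gini impurity")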
Use your function to compute the first two steps of a classification tree model for separating Adelie from Chinstrap penguins, after setting minsplit to be 5. Make a scatterplot of the two variables that would be used in the splits, with points coloured by species, and the splits as line segments.
Solution
# bl: this is the only variable needed for the first split,
# because it is so much better separated than any other
x <- p_tr |>
  select(species, bl) |>
  arrange(bl)
unique_splits <- unique(x$bl)
nsplits <- length(unique_splits) - 1
splits <- (unique_splits[1:nsplits] + unique_splits[2:(nsplits + 1)]) / 2
imp <- NULL
for (i in 1:length(splits)) {
  s <- splits[i]
  a <- mysplit(x$bl, s, x$species, minsplit = 1)
  imp <- c(imp, a)
}
d_impurity <- tibble(splits, imp)
d_impurity_bl <- d_impurity[which.min(d_impurity$imp), ]
ggplot() +
  geom_line(data = d_impurity, aes(x = splits, y = imp)) +
  geom_rug(data = x, aes(x = bl, colour = species), alpha = 0.3) +
  ylab("Gini impurity") +
  xlab("bl") +
  scale_color_brewer("", palette = "Dark2")
p_tr_L <- p_tr |> filter(bl < d_impurity_bl$splits)
p_tr_R <- p_tr |> filter(bl > d_impurity_bl$splits)

# Make a function to make calculations easier
best_split <- function(x, cl, minsplit = 5) {
  unique_splits <- unique(x)
  nsplits <- length(unique_splits) - 1
  splits <- (unique_splits[1:nsplits] + unique_splits[2:(nsplits + 1)]) / 2
  imp <- NULL
  for (i in 1:length(splits)) {
    s <- splits[i]
    a <- mysplit(x, s, cl, minsplit)
    imp <- c(imp, a)
  }
  d_impurity <- tibble(splits, imp)
  d_impurity_best <- d_impurity[which.min(d_impurity$imp), ]
  return(d_impurity_best)
}
s1 <- best_split(p_tr$bl, p_tr$species, minsplit = 5)
s2 <- best_split(p_tr_R$bm, p_tr_R$species, minsplit = 5)
ggplot(p_tr, aes(x = bl, y = bm, colour = species)) +
  geom_point() +
  geom_vline(xintercept = s1$splits) +
  annotate("segment", x = s1$splits, xend = max(p_tr$bl),
           y = s2$splits, yend = s2$splits) +
  scale_colour_brewer("", palette = "Dark2") +
  theme(aspect.ratio = 1)
2. Digging deeper into diagnosing an error
Fit the random forest model to the full penguins data.
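The printout below comes from a fit along these lines; a sketch, assuming the full standardized data is in p_std (the model name p_fit_rf is a placeholder):

set.seed(923)
rf_spec <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
p_fit_rf <- rf_spec |>
  fit(species ~ ., data = p_std)
p_fit_rf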
parsnip model object
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~1000, mtry = min_cols(~2, x))
Type of random forest: classification
Number of trees: 1000
No. of variables tried at each split: 2
OOB estimate of error rate: 2.6%
Confusion matrix:
          Adelie Chinstrap Gentoo class.error
Adelie        97         2      1       0.030
Chinstrap      2        43      0       0.044
Gentoo         0         1     81       0.012
Use linked brushing to learn which was the Gentoo penguin that the model was confused about. When we looked at the data in a tour, there was one Gentoo penguin that was an outlier, appearing to be away from the other Gentoos and closer to the Chinstrap group. We would expect this to be the penguin that the forest model is confused about. Is it?
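Since the brushing itself can't be reproduced on the page, here is a hedged programmatic check that gets at the same answer: pull the underlying randomForest fit out of the parsnip object and find the Gentoo with an out-of-bag prediction that disagrees with its label (p_fit_rf is the placeholder name from above).

p_rf <- extract_fit_engine(p_fit_rf)
p_check <- p_std |>
  mutate(pspecies = p_rf$predicted,              # OOB predictions
         v_chinstrap = p_rf$votes[, "Chinstrap"])
# The Gentoo the forest is confused about, and how strongly
p_check |> filter(species == "Gentoo", pspecies != "Gentoo")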
Have a look at the other misclassifications, to understand whether they are ones we’d expect to misclassify, or whether the model is not well constructed.
Fit a random forest to the bushfire data. You can read more about the bushfire data at https://dicook.github.io/mulgar_book/A2-data.html. Examine the votes matrix using a tour. What do you learn about the confusion between fire causes?
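The objects used in the tour code below (bf_rf_v_p_s, simp, labels) are not constructed in this section, so here is one hedged way to set them up: fit the forest, project the four vote proportions (which sum to 1) into 3D, and bind on the projected simplex vertices so they can be drawn as edges. All names, and the assumption that bf_tr holds the training data with the class variable cause, are assumptions.

library(tidymodels)
library(tourr)

set.seed(1239)
bf_rf_spec <- rand_forest(trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
bf_fit_rf <- bf_rf_spec |> fit(cause ~ ., data = bf_tr)
bf_rf <- extract_fit_engine(bf_fit_rf)  # underlying randomForest object

# Orthonormal basis for the 3D subspace orthogonal to (1,1,1,1);
# each row of the votes matrix sums to 1, so this drops the
# redundant dimension
h <- cbind(c(1, -1,  0,  0) / sqrt(2),
           c(1,  1, -2,  0) / sqrt(6),
           c(1,  1,  1, -3) / sqrt(12))
bf_rf_v_p <- as.matrix(bf_rf$votes) %*% h

# Simplex vertices are the projections of the four unit vectors;
# they go first so the edge indices 1-4 point at them
simp_pts <- diag(4) %*% h
simp <- list(points = simp_pts,
             edges = t(combn(4, 2)))  # join every pair of vertices

proj <- rbind(simp_pts, bf_rf_v_p)
colnames(proj) <- c("x1", "x2", "x3")
bf_rf_v_p_s <- as_tibble(proj) |>
  mutate(cause = factor(c(levels(bf_tr$cause),
                          as.character(bf_tr$cause))))
labels <- c(levels(bf_tr$cause), rep("", nrow(bf_rf_v_p)))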
# Examine votes matrix with bounding simplex
animate_xy(bf_rf_v_p_s[, 1:3],
           col = bf_rf_v_p_s$cause,
           axes = "off",
           half_range = 1.3,
           edges = as.matrix(simp$edges),
           obs_labels = labels)
Solution
The pattern is that points are bunched at the vertex corresponding to lightning, extending along the edge leading to accident, and they also spread across the face spanned by lightning, accident and arson. The primary confusion for each of the other classes is with lightning. Few points are predicted to be burning_off, because burning off typically occurs only outside of fire season.
Part of the reason that the forest predominantly predicts the lightning class is that the problem is highly imbalanced. One approach is to change the weights for each class, giving the lightning class a lower priority. This shifts the model's predictions more often towards the other three classes.
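With the randomForest engine this can be done through its classwt argument; a sketch, with the weight values picked arbitrarily, and assuming the factor levels are ordered accident, arson, burning_off, lightning:

# Down-weight lightning relative to the rarer causes
# (weights given in the order of the class levels, assumed here)
bf_rf_wt <- randomForest::randomForest(cause ~ ., data = bf_tr,
                                       ntree = 1000,
                                       classwt = c(2, 2, 2, 0.5))
bf_rf_wt$confusion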
Check the variable importance. Plot the most important variables.
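A sketch of both steps, using the importance scores stored on the randomForest fit from above; the three variables used in the discussion below (dist_cfa, dist_road, arf360) are assumed to be the ones that come out on top.

# Rank variables by mean decrease in Gini impurity
bf_rf$importance |>
  as_tibble(rownames = "var") |>
  arrange(desc(MeanDecreaseGini)) |>
  slice_head(n = 3)

# Side-by-side boxplots of the top variables against cause
bf_tr |>
  pivot_longer(c(dist_cfa, dist_road, arf360),
               names_to = "var", values_to = "value") |>
  ggplot(aes(x = cause, y = value)) +
  geom_boxplot() +
  facet_wrap(~var, scales = "free_y")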
Each of these variables shows some difference in median value between the classes, but none shows clear separation. When even the most important variables separate the classes this weakly, it indicates how difficult it is to distinguish between them. However, it looks like the larger the distance from a road or a CFA station, the higher the chance that the cause is a lightning start. This makes sense, because these locations are further from human activity, and thus a fire there is less likely to be started by people. The arf360 variable relates to rain from a year ago: it appears that if rainfall was higher a year ago, lightning is more likely to be the cause. This also makes some sense, because more rain in the previous year should mean more vegetation, and if recent months have been dry, there is likely a lot of dry, combustible vegetation. Ideally we would create a new variable (feature engineering) measuring the difference between rainfall in the previous year and rainfall just before the current year's fire season, to capture these conditions.
4. Can boosting better detect bushfire cause?
Fit a boosted tree model using xgboost to the bushfires data. You can use the code below. Compute the confusion tables and the balanced accuracy for the test data for both the forest model and the boosted tree model, to make the comparison.
set.seed(121)
bf_spec2 <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")
bf_fit_bt <- bf_spec2 |>
  fit(cause ~ ., data = bf_tr)
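A sketch of the comparison, assuming bf_ts is the test split and bf_fit_rf is the forest fit from the earlier question (both names assumed):

bf_ts_pred <- bf_ts |>
  mutate(pcause_rf = predict(bf_fit_rf, bf_ts)$.pred_class,
         pcause_bt = predict(bf_fit_bt, bf_ts)$.pred_class)

# Confusion tables for each model
bf_ts_pred |> conf_mat(cause, pcause_rf)
bf_ts_pred |> conf_mat(cause, pcause_bt)

# Balanced accuracy averages recall over the classes, so it is
# not dominated by the large lightning class
bf_ts_pred |> bal_accuracy(cause, pcause_rf)
bf_ts_pred |> bal_accuracy(cause, pcause_bt)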