Week 5: Trees and forests
We will cover classification trees, random forests, and boosted trees, along with their pros and cons.
Define the mean squared error (MSE) as
\[\mbox{MSE} = \frac{1}{n}\sum_{i=1}^{n} (y_i - \widehat{y}_i)^2\]
Split the data at the point where the combined MSE of the left bucket (MSE_L) and the right bucket (MSE_R) gives the biggest reduction from the overall MSE.
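A minimal sketch of this search in R, for a numeric predictor x and a numeric response y (the function name is made up for illustration):

# Candidate cut points are midpoints between consecutive sorted x values.
# For each, compute the weighted MSE of the two buckets and keep the best.
best_numeric_split <- function(x, y) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  wmse <- sapply(cuts, function(ct) {
    yl <- y[x < ct]
    yr <- y[x >= ct]
    (sum((yl - mean(yl))^2) + sum((yr - mean(yr))^2)) / length(y)
  })
  cuts[which.min(wmse)]
}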
x | cl |
---|---|
11 | A |
33 | A |
39 | B |
44 | A |
50 | A |
56 | B |
70 | B |
Note: x is sorted from lowest to highest!
All possible splits shown by vertical lines
What do you think is the best split? 2, 3 or 5?
The left bucket is
x | cl |
---|---|
11 | A |
33 | A |
39 | B |
44 | A |
50 | A |
and the right bucket is
x | cl |
---|---|
56 | B |
70 | B |
Using the Gini index \(G = \sum_{k=1}^K \widehat{p}_{mk}(1 - \widehat{p}_{mk})\):
Left bucket:
\[\widehat{p}_{LA} = 4/5, \widehat{p}_{LB} = 1/5, ~~ p_L = 5/7\]
\[G_L=0.8(1-0.8)+0.2(1-0.2) = 0.32\]
Right bucket:
\[\widehat{p}_{RA} = 0/2, \widehat{p}_{RB} = 2/2, ~~ p_R = 2/7\]
\[G_R=0(1-0)+1(1-1) = 0\]
Combine with a weighted sum to get the impurity for the split:
\[5/7G_L + 2/7G_R=0.23\]
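The same calculation in R (a small sketch; the gini helper name is mine, and the left/right vectors are the buckets shown above):

# Gini impurity of one bucket, from a vector of class labels
gini <- function(cl) {
  p <- table(cl) / length(cl)
  sum(p * (1 - p))
}

left  <- c("A", "A", "B", "A", "A")   # left bucket from the split above
right <- c("B", "B")                  # right bucket
# weighted sum of the bucket impurities
(length(left) * gini(left) + length(right) * gini(right)) /
  (length(left) + length(right))
#> [1] 0.2285714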
Your turn: Compute the impurity for split 2.
Splits on categorical variables
A possible best split would be: if koala, assign to Vic, else assign to WA, because Vic has more koalas and WA has more emus and roos.
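A small sketch of scoring one candidate categorical split, using made-up counts that match the description (koalas mostly in Vic, emus and roos mostly in WA):

# Hypothetical sightings: the predictor is the animal, the class is the state
sightings <- data.frame(
  animal = c("koala", "koala", "koala", "emu", "emu", "roo", "roo"),
  state  = c("Vic", "Vic", "WA", "WA", "WA", "WA", "Vic"))
gini <- function(cl) { p <- table(cl) / length(cl); sum(p * (1 - p)) }

# Candidate split: koala vs not koala
left  <- sightings$state[sightings$animal == "koala"]
right <- sightings$state[sightings$animal != "koala"]
(length(left) * gini(left) + length(right) * gini(right)) / nrow(sightings)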
Dealing with missing values on predictors
x1 | x2 | x3 | x4 | y |
---|---|---|---|---|
19 | -8 | 22 | -24 | A |
NA | -10 | 26 | -26 | A |
15 | NA | 32 | -27 | B |
17 | -6 | 27 | -25 | A |
18 | -5 | NA | -23 | A |
13 | -3 | 37 | NA | B |
12 | -1 | 35 | -30 | B |
11 | -7 | 24 | -31 | B |
50% of cases have missing values. A tree only drops an observation for splits on the variable where its value is missing.
Most other methods drop the whole observation if it has a missing value on any variable; here, that would mean only being able to use half the data.
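To see the cost of complete-case deletion, here is the table above as a small R sketch:

# The example data, with NAs as shown in the table
d <- data.frame(
  x1 = c(19, NA, 15, 17, 18, 13, 12, 11),
  x2 = c(-8, -10, NA, -6, -5, -3, -1, -7),
  x3 = c(22, 26, 32, 27, NA, 37, 35, 24),
  x4 = c(-24, -26, -27, -25, -23, NA, -30, -31),
  y  = c("A", "A", "B", "A", "A", "B", "B", "B")
)
sum(complete.cases(d))   # only 4 of the 8 rows survive complete-case deletion
colSums(is.na(d))        # but each variable is missing at most one value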
library(tidymodels)   # rsample, parsnip, yardstick, ...

set.seed(1156)
# p_sub: the penguins data restricted to two species (Adelie, Chinstrap)
p_split <- initial_split(p_sub, 2/3, strata = species)
p_tr <- training(p_split)
p_ts <- testing(p_split)

# Specify a classification tree using the rpart engine
tree_spec <- decision_tree() |>
  set_mode("classification") |>
  set_engine("rpart")

p_fit_tree <- tree_spec |>
  fit(species ~ ., data = p_tr)
p_fit_tree
parsnip model object
n= 145
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 145 45 Adelie (0.690 0.310)
  2) bl< 43 99 2 Adelie (0.980 0.020) *
  3) bl>=43 46 3 Chinstrap (0.065 0.935) *
Can you draw the tree?
The defaults for rpart are:

rpart.control(minsplit = 20,
  minbucket = round(minsplit/3),
  cp = 0.01,
  maxcompete = 4,
  maxsurrogate = 5,
  usesurrogate = 2,
  xval = 10,
  surrogatestyle = 0,
  maxdepth = 30,
  ...)
# Lower minsplit so smaller nodes can still be split;
# model = TRUE keeps the model frame with the fitted object
tree_spec <- decision_tree() |>
  set_mode("classification") |>
  set_engine("rpart",
             control = rpart.control(minsplit = 10),
             model = TRUE)

p_fit_tree <- tree_spec |>
  fit(species ~ ., data = p_tr)
p_fit_tree
parsnip model object
n= 145
node), split, n, loss, yval, (yprob)
* denotes terminal node
 1) root 145 45 Adelie (0.690 0.310)
   2) bl< 43 99 2 Adelie (0.980 0.020)
     4) bl< 41 75 0 Adelie (1.000 0.000) *
     5) bl>=41 24 2 Adelie (0.917 0.083)
      10) bm>=3.4e+03 21 0 Adelie (1.000 0.000) *
      11) bm< 3.4e+03 3 1 Chinstrap (0.333 0.667) *
   3) bl>=43 46 3 Chinstrap (0.065 0.935)
     6) bl< 46 10 3 Chinstrap (0.300 0.700)
      12) bm>=3.8e+03 3 0 Adelie (1.000 0.000) *
      13) bm< 3.8e+03 7 0 Chinstrap (0.000 1.000) *
     7) bl>=46 36 0 Chinstrap (0.000 1.000) *
Model fit summary
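These summaries can be computed on the test set with yardstick, roughly as follows (a sketch; the lecture's exact code is not shown):

# Predict the test set, then compute accuracy, per-class accuracy,
# and balanced accuracy
p_ts_pred <- p_ts |>
  mutate(pspecies = predict(p_fit_tree, p_ts)$.pred_class)
accuracy(p_ts_pred, species, pspecies)
p_ts_pred |>
  count(species, pspecies) |>
  group_by(species) |>
  mutate(Accuracy = n[species == pspecies] / sum(n)) |>
  pivot_wider(names_from = pspecies, values_from = n, values_fill = 0)
bal_accuracy(p_ts_pred, species, pspecies)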
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.946
# A tibble: 2 × 4
# Groups: species [2]
species Adelie Chinstrap Accuracy
<fct> <int> <int> <dbl>
1 Adelie 50 1 0.980
2 Chinstrap 3 20 0.870
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 bal_accuracy binary 0.925
Can you see the misclassified test cases?
Model-in-the-data-space
A random forest is an ensemble classifier, built by fitting many trees to bootstrap samples of the training data, with a random subset of predictors considered at each split.
Fit
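The fitted object printed below is consistent with a tidymodels specification along these lines (a sketch; p_fit_rf is the name used later for this object):

# Random forest: 1000 trees, 2 candidate variables tried at each split
rf_spec <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest")
p_fit_rf <- rf_spec |>
  fit(species ~ ., data = p_tr)
p_fit_rf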
parsnip model object
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~1000, mtry = min_cols(~2, x))
Type of random forest: classification
Number of trees: 1000
No. of variables tried at each split: 2
OOB estimate of error rate: 4.8%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 96 4 0.040
Chinstrap 3 42 0.067
Predicted values
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.973
# A tibble: 2 × 4
# Groups: species [2]
species Adelie Chinstrap Accuracy
<fct> <int> <int> <dbl>
1 Adelie 51 0 1
2 Chinstrap 2 21 0.913
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 bal_accuracy binary 0.957
Warning: don't use predict() on the training set; you'll always get close to 0 error. The object p_fit_rf$fit$predicted holds the out-of-bag (OOB) fitted values.
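The matrix printed below looks like the out-of-bag vote proportions, which for the randomForest engine can be extracted directly:

# Proportion of OOB trees voting for each class, for each training observation
p_fit_rf$fit$votes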
Adelie Chinstrap
1 1.0000 0.0000
2 1.0000 0.0000
3 0.9807 0.0193
4 1.0000 0.0000
5 1.0000 0.0000
6 1.0000 0.0000
7 1.0000 0.0000
8 0.3982 0.6018
9 1.0000 0.0000
10 1.0000 0.0000
11 1.0000 0.0000
12 0.8274 0.1726
13 0.3425 0.6575
14 1.0000 0.0000
15 1.0000 0.0000
16 0.7931 0.2069
17 1.0000 0.0000
18 1.0000 0.0000
19 0.9973 0.0027
20 1.0000 0.0000
21 0.7622 0.2378
22 1.0000 0.0000
23 0.9459 0.0541
24 1.0000 0.0000
25 1.0000 0.0000
26 0.8568 0.1432
27 1.0000 0.0000
28 1.0000 0.0000
29 1.0000 0.0000
30 1.0000 0.0000
31 1.0000 0.0000
32 1.0000 0.0000
33 1.0000 0.0000
34 1.0000 0.0000
35 1.0000 0.0000
36 1.0000 0.0000
37 1.0000 0.0000
38 1.0000 0.0000
39 1.0000 0.0000
40 1.0000 0.0000
41 1.0000 0.0000
42 1.0000 0.0000
43 1.0000 0.0000
44 1.0000 0.0000
45 0.2773 0.7227
46 1.0000 0.0000
47 0.9821 0.0179
48 1.0000 0.0000
49 0.9973 0.0027
50 1.0000 0.0000
51 1.0000 0.0000
52 1.0000 0.0000
53 1.0000 0.0000
54 1.0000 0.0000
55 1.0000 0.0000
56 1.0000 0.0000
57 1.0000 0.0000
58 1.0000 0.0000
59 1.0000 0.0000
60 1.0000 0.0000
61 0.9833 0.0167
62 1.0000 0.0000
63 0.9113 0.0887
64 1.0000 0.0000
65 1.0000 0.0000
66 1.0000 0.0000
67 1.0000 0.0000
68 1.0000 0.0000
69 0.9912 0.0088
70 1.0000 0.0000
71 0.9535 0.0465
72 0.9914 0.0086
73 1.0000 0.0000
74 0.9676 0.0324
75 1.0000 0.0000
76 1.0000 0.0000
77 1.0000 0.0000
78 1.0000 0.0000
79 1.0000 0.0000
80 1.0000 0.0000
81 1.0000 0.0000
82 0.9973 0.0027
83 1.0000 0.0000
84 1.0000 0.0000
85 1.0000 0.0000
86 0.4624 0.5376
87 0.6160 0.3840
88 1.0000 0.0000
89 1.0000 0.0000
90 1.0000 0.0000
91 1.0000 0.0000
92 0.9948 0.0052
93 1.0000 0.0000
94 0.9972 0.0028
95 1.0000 0.0000
96 1.0000 0.0000
97 1.0000 0.0000
98 1.0000 0.0000
99 1.0000 0.0000
100 1.0000 0.0000
101 0.0000 1.0000
102 0.0000 1.0000
103 0.0055 0.9945
104 0.0653 0.9347
105 0.0000 1.0000
106 0.0000 1.0000
107 0.0000 1.0000
108 0.0000 1.0000
109 0.0000 1.0000
110 0.1935 0.8065
111 0.0000 1.0000
112 0.0000 1.0000
113 0.0159 0.9841
114 0.0000 1.0000
115 0.0000 1.0000
116 0.2074 0.7926
117 0.0000 1.0000
118 0.0000 1.0000
119 0.0117 0.9883
120 1.0000 0.0000
121 0.0529 0.9471
122 0.9536 0.0464
123 0.0027 0.9973
124 0.0000 1.0000
125 0.0000 1.0000
126 0.0163 0.9837
127 0.0000 1.0000
128 0.0000 1.0000
129 0.0111 0.9889
130 0.0000 1.0000
131 0.0694 0.9306
132 0.0000 1.0000
133 0.0137 0.9863
134 0.0000 1.0000
135 0.0000 1.0000
136 0.0052 0.9948
137 0.0000 1.0000
138 0.0000 1.0000
139 0.0135 0.9865
140 0.0000 1.0000
141 0.0140 0.9860
142 0.0000 1.0000
143 0.6325 0.3675
144 0.0000 1.0000
145 0.0000 1.0000
attr(,"class")
[1] "matrix" "array" "votes"
Where are the Adelie penguins in the training set that are misclassified?
parsnip model object
Call:
randomForest(x = maybe_data_frame(x), y = y, ntree = ~1000, mtry = min_cols(~2, x))
Type of random forest: classification
Number of trees: 1000
No. of variables tried at each split: 2
OOB estimate of error rate: 4.8%
Confusion matrix:
Adelie Chinstrap class.error
Adelie 96 4 0.040
Chinstrap 3 42 0.067
Join the data containing the true class, the predicted class and the predicted probabilities, to diagnose the misclassified cases.
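A sketch of that join (the object names are assumptions, guided by the columns shown below):

# Combine the truth, the OOB predicted class and the vote proportions,
# then keep only the misclassified training cases
p_tr_wrong <- p_tr |>
  mutate(pspecies = p_fit_rf$fit$predicted) |>
  bind_cols(as_tibble(p_fit_rf$fit$votes)) |>
  filter(species != pspecies)
p_tr_wrong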
# A tibble: 7 × 6
species bl bm pspecies Adelie Chinstrap
<fct> <dbl> <int> <fct> <dbl> <dbl>
1 Adelie 41.1 3200 Chinstrap 0.398 0.602
2 Adelie 46 4200 Chinstrap 0.342 0.658
3 Adelie 45.8 4150 Chinstrap 0.277 0.723
4 Adelie 44.1 4000 Chinstrap 0.462 0.538
5 Chinstrap 40.9 3200 Adelie 1 0
6 Chinstrap 42.5 3350 Adelie 0.954 0.0464
7 Chinstrap 43.5 3400 Adelie 0.632 0.368
1. Grow each tree on a bootstrap sample, and set aside its out-of-bag (oob) cases.
2. Permute the values of the variable of interest in the oob cases, and predict them again.
3. Difference the votes for the correct class in the variable-permuted oob cases and the real oob cases. Average this number over all trees in the forest. If the value is large, then the variable is very important.
Alternatively, Gini importance adds up the difference in impurity between each parent node and its descendant nodes, over all splits on the variable. It is quick to compute.
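For the randomForest engine, importance scores can be read from the fitted object. Gini importance is always available; permutation importance requires the forest to have been fit with importance = TRUE, which is an assumption about the model setup:

# e.g. set_engine("randomForest", importance = TRUE) at fit time
randomForest::importance(p_fit_rf$fit)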
Read a fun explanation by Harriet Mason
Counting how often each pair of observations lands in the same terminal node (the proximity) creates a similarity matrix between all pairs of observations.
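A sketch of extracting it, assuming proximities were requested when the forest was fit:

# e.g. set_engine("randomForest", proximity = TRUE) at fit time;
# the result is an n x n matrix of pairwise similarities
p_fit_rf$fit$proximity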
The votes matrix gives more information than the confusion matrix: it shows how much confidence the model has in the prediction for each observation in the training set.
It is a \(K\)-D object, but lives in \((K-1)\)-D because the rows add to 1.
Let's re-fit the random forest model to all three penguin species, as sketched below.
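A sketch of that re-fit, and of one way to lay the votes out in 2-D so they can be linked to the tour in the demo (the data object p_all and the projection are assumptions; p_tr2 and p_rf_v_p are the names used in the demo code):

set.seed(1156)
# p_all: hypothetical name for the penguins data with all three species
p_split2 <- initial_split(p_all, 2/3, strata = species)
p_tr2 <- training(p_split2)

p_fit_rf2 <- rand_forest(mtry = 2, trees = 1000) |>
  set_mode("classification") |>
  set_engine("randomForest") |>
  fit(species ~ ., data = p_tr2)

# The 3-D vote rows sum to 1, so they live in a 2-D simplex;
# project onto two orthogonal directions within that simplex.
proj <- cbind(c(1, -1, 0) / sqrt(2),
              c(1, 1, -2) / sqrt(6))
p_rf_v_p <- p_fit_rf2$fit$votes %*% proj
colnames(p_rf_v_p) <- c("x1", "x2")
p_rf_v_p <- as_tibble(p_rf_v_p)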
DEMO: Use interactivity to investigate the uncertainty in the predictions.
library(detourr)
library(crosstalk)
library(plotly)
library(viridis)

# Standardise the numeric predictors, and attach the first two
# coordinates of the projected vote matrix
p_tr2_std <- p_tr2 |>
  mutate_if(is.numeric, function(x) (x - mean(x)) / sd(x))
p_tr2_v <- bind_cols(p_tr2_std, p_rf_v_p[, 1:2])

# Shared data object so the two displays can be linked by brushing
p_tr2_v_shared <- SharedData$new(p_tr2_v)

# Tour of the standardised predictors
detour_plot <- detour(p_tr2_v_shared, tour_aes(
                      projection = bl:bm,
                      colour = species)) |>
  tour_path(grand_tour(2),
            max_bases = 50, fps = 60) |>
  show_scatter(alpha = 0.9, axes = FALSE,
               width = "100%",
               height = "450px",
               palette = hcl.colors(3,
                 palette = "Zissou 1"))

# Scatterplot of the projected votes
vot_mat <- plot_ly(p_tr2_v_shared,
                   x = ~x1,
                   y = ~x2,
                   color = ~species,
                   colors = hcl.colors(3,
                     palette = "Zissou 1"),
                   height = 450) |>
  highlight(on = "plotly_selected",
            off = "plotly_doubleclick") |>
  add_trace(type = "scatter",
            mode = "markers")

# Place the two linked displays side by side
bscols(
  detour_plot, vot_mat,
  widths = c(5, 6)
)
Random forests build an ensemble of independent trees, while boosting builds an ensemble of shallow trees in sequence, each tree learning from and improving on the previous one.
Boosting iteratively fits trees, at each step re-weighting the observations so that those predicted inaccurately get more importance.
This StatQuest video by Josh Starmer is the best explanation!
And this is a fun explanation of boosting by Harriet Mason.
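The accuracies below appear to come from a boosted tree model; with tidymodels a fit would look roughly like this (a sketch; the xgboost engine and its default parameters are assumptions):

# Gradient boosted trees on the three-species training data
bt_spec <- boost_tree() |>
  set_mode("classification") |>
  set_engine("xgboost")
p_fit_bt <- bt_spec |>
  fit(species ~ ., data = p_tr2)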
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy multiclass 0.991
# A tibble: 3 × 4
# Groups: species [3]
species Adelie Chinstrap Accuracy
<fct> <int> <int> <dbl>
1 Adelie 50 1 0.980
2 Chinstrap 0 23 1
3 Gentoo 0 0 1
ETC3250/5250 Lecture 5 | iml.numbat.space