---
title: "8.3.1 Lab: Fitting Classification Trees"
author: "Group 14"
date: "2/24/2021"
output: pdf_document
---
```{r setup, include = FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
# Lab 8.3.1: Fitting Classification Trees
This lab involves fitting classification trees, which are used to predict qualitative responses.
```{r}
remove(list = ls())
```
## Load and View Data
If the tree library is not already installed, run install.packages("tree") first. This library allows for the construction of classification and regression trees. We will use the Carseats dataset, which is in the ISLR library.
```{r}
library(tree)
library(ISLR)
attach(Carseats)
if (interactive()) View(Carseats)  # the data viewer is only available in interactive sessions, not when knitting
```
## Create Dataframe
In the Carseats dataset, Sales is a continuous variable, so we use ifelse() to recode it as a binary variable, High, which is "No" if Sales is less than or equal to 8 and "Yes" if Sales is greater than 8. We then add High to the Carseats data frame.
```{r}
High = ifelse(Sales <= 8, "No", "Yes")
Carseats = data.frame(Carseats ,High)
```
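As a minimal sketch of how the threshold behaves, the chunk below applies the same ifelse() rule to a few made-up sales values. The names `toy.sales` and `toy.high` are illustrative only and are not part of the lab. Note that a value of exactly 8 is labelled "No".

```{r}
# Hypothetical sales values chosen to sit around the cutoff of 8
toy.sales <- c(4.5, 8, 8.01, 12.3)

# Same rule as above: "No" when sales <= 8, "Yes" otherwise
toy.high <- ifelse(toy.sales <= 8, "No", "Yes")
toy.high

# Count how many observations fall in each class
table(toy.high)
```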
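Note a slight difference from the textbook code: the textbook does not call as.factor() on the High column. In R 4.0.0 and later, data.frame() keeps strings as character vectors by default, so we convert High to a factor explicitly so that tree() treats it as a qualitative response. The class() calls below show the column is character before the conversion and factor after it.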
```{r}
class(Carseats$High)
Carseats$High = as.factor(Carseats$High)
class(Carseats$High)
```
## Fit the Model
The goal is to predict High using every variable except Sales. The tree() function fits a classification tree for this prediction. The training error rate is 9% (shown in the last line of the summary output).
```{r}
tree.carseats = tree(High ~ . - Sales, Carseats)
summary(tree.carseats)
```
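The training error rate can also be computed by hand, as a check on the summary output. The sketch below rebuilds the data and the tree under separate names (`demo.data`, `demo.fit`, `demo.fitted`, and `demo.train.err` are illustrative, not part of the lab) so that nothing defined above is overwritten. It then compares the fitted classes with the true labels.

```{r}
library(tree)
library(ISLR)

# Rebuild the data under illustrative names, starting from the package copy
demo.data <- ISLR::Carseats
demo.data$High <- factor(ifelse(demo.data$Sales <= 8, "No", "Yes"))

# Fit the same model: High on every predictor except Sales
demo.fit <- tree(High ~ . - Sales, data = demo.data)

# Predict the class for the training observations themselves
demo.fitted <- predict(demo.fit, demo.data, type = "class")

# Training error rate = share of training observations that are misclassified
demo.train.err <- mean(demo.fitted != demo.data$High)
demo.train.err
```

This value should agree with the misclassification error rate reported by summary().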
## Display Classification Tree
Use the plot() function to plot the tree, and the text() function to display the node labels. The argument pretty = 0 shows category names instead of a single letter for each category.
```{r}
plot(tree.carseats)
text(tree.carseats, pretty = 0)
```
## Display Detailed Results
Typing the name of the tree object prints, for each branch, the split criterion, the number of observations, the deviance, the overall prediction (Yes or No), and the fraction of observations in that branch that take on each value of High. An asterisk marks a terminal node.
```{r}
tree.carseats
```
## Create Train and Test sets
Now we split the data into training and test sets. We use 200 of the 400 observations (50%) for training and the remaining 200 for testing.
```{r}
set.seed(2)
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train,]
High.test = High[-train]
```
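The negative index `-train` keeps every row whose number is not in `train`. The sketch below shows the same idea on plain row numbers, assuming the 400 rows of Carseats. The names `n.obs`, `demo.train.idx`, and `demo.test.idx` are illustrative only.

```{r}
n.obs <- 400  # number of rows in Carseats

set.seed(2)
demo.train.idx <- sample(1:n.obs, 200)             # 200 row numbers drawn without replacement
demo.test.idx <- setdiff(1:n.obs, demo.train.idx)  # every row number not drawn for training

length(demo.train.idx)
length(demo.test.idx)

# No row appears in both sets
length(intersect(demo.train.idx, demo.test.idx))
```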
## Create Model using Train data & Make Predictions
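Next, refit the tree using only the training data, predict the class labels for the test set with predict() (the argument type = "class" returns the predicted class), and build a confusion matrix. The tree predicts the correct outcome for 77% of the test observations.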
```{r}
tree.carseats = tree(High~.-Sales, Carseats, subset=train)
tree.pred = predict(tree.carseats, Carseats.test, type="class")
table(tree.pred, High.test)
(104 + 50)/200
```
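The accuracy is the sum of the diagonal of the confusion matrix divided by the total number of test observations. The sketch below computes it without typing the counts by hand. It uses made-up labels (`toy.pred` and `toy.truth` are hypothetical and are not the Carseats results), and it assumes both vectors contain the same set of classes so that the table is square.

```{r}
# Hypothetical predicted and true labels, for illustration only
toy.pred  <- c("No", "No", "Yes", "Yes", "No", "Yes")
toy.truth <- c("No", "Yes", "Yes", "Yes", "No", "No")

toy.tab <- table(toy.pred, toy.truth)
toy.tab

# Correct predictions sit on the diagonal of the confusion matrix
toy.accuracy <- sum(diag(toy.tab)) / sum(toy.tab)
toy.accuracy

# Equivalent shortcut: the share of positions where prediction and truth agree
mean(toy.pred == toy.truth)
```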
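## Perform Cross-Validation to Find Optimal Tree Size
It is important to perform cross-validation to determine the optimal level of tree complexity. The argument FUN = prune.misclass tells cv.tree() to use the classification error rate, rather than the default deviance, to guide the cross-validation and pruning process. In the output, size is the number of terminal nodes of each tree considered, dev is the corresponding number of cross-validation errors, and k is the cost-complexity parameter.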
```{r}
set.seed(3)
cv.carseats = cv.tree(tree.carseats, FUN=prune.misclass)
names(cv.carseats)
cv.carseats
```
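The best size can also be read off programmatically rather than by eye. The sketch below repeats the fit and the cross-validation under separate names (`demo.cv.data`, `demo.cv.train`, `demo.cv.fit`, `demo.cv`, and `best.idx` are illustrative, not part of the lab), reusing the seeds from above. It then picks the size with the fewest cross-validation errors. which.min() returns the first minimum, so ties are resolved in favour of whichever size is listed first.

```{r}
library(tree)
library(ISLR)

# Rebuild the data and the training split under illustrative names
demo.cv.data <- ISLR::Carseats
demo.cv.data$High <- factor(ifelse(demo.cv.data$Sales <= 8, "No", "Yes"))

set.seed(2)
demo.cv.train <- sample(1:nrow(demo.cv.data), 200)
demo.cv.fit <- tree(High ~ . - Sales, demo.cv.data, subset = demo.cv.train)

# Cross-validate, using misclassification error to guide pruning
set.seed(3)
demo.cv <- cv.tree(demo.cv.fit, FUN = prune.misclass)

# Position of the smallest cross-validation error count
best.idx <- which.min(demo.cv$dev)
demo.cv$size[best.idx]  # number of terminal nodes at that position
demo.cv$dev[best.idx]   # number of cross-validation errors at that position
```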
## Plot the CV Results
```{r}
par(mfrow = c(1,2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
```
## Make model using Best Pruned Tree
The tree with 21 terminal nodes results in the lowest cross-validation error, with 74 errors. We prune the tree to that size with prune.misclass() and plot the result.
```{r}
prune.carseats = prune.misclass(tree.carseats, best = 21)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
```
## Find New Test Error Rate
We then test how well the pruned tree performs on the test data set, using the predict() function again. The pruned tree correctly classifies 77.5% of the test observations, slightly higher than the 77% from the unpruned tree.
```{r}
tree.pred = predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
(104 + 51)/200
```
## Test Against Different Node Values
Changing the value of best to a different size, such as 14, gives a slightly lower accuracy of 77%.
```{r}
prune.carseats = prune.misclass(tree.carseats, best = 14)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
tree.pred = predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
(102 + 52)/200
```