R Random Forest Tutorial with Example

What is Random Forest in R?

Random forests are based on a simple idea: "the wisdom of the crowd". Aggregating the results of multiple predictors often gives a better prediction than the best individual predictor. A group of predictors is called an ensemble, so this technique is called Ensemble Learning.

In the previous tutorial, you learned how to use decision trees to make a binary prediction. To improve on this technique, we can train a group of decision tree classifiers, each on a different random subset of the training set. To make a prediction, we collect the predictions of all the individual trees and predict the class that gets the most votes. This technique is called Random Forest.
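The voting step described above can be sketched in base R. The matrix of tree predictions below is hypothetical (5 trees, 3 passengers); randomForest() performs this aggregation internally, so the helper only illustrates the idea and is not part of any package.

```r
# Hypothetical predictions from 5 trees (columns) for 3 passengers (rows)
tree_preds <- matrix(
  c("Yes", "No",  "Yes", "Yes", "No",
    "No",  "No",  "No",  "Yes", "No",
    "Yes", "Yes", "No",  "Yes", "Yes"),
  nrow = 3, byrow = TRUE
)

# Majority vote: for each row, return the most frequent class
majority_vote <- function(preds) {
  apply(preds, 1, function(row) {
    counts <- table(row)
    names(counts)[which.max(counts)]
  })
}

forest_pred <- majority_vote(tree_preds)
print(forest_pred)  # "Yes" "No" "Yes"
```

With an odd number of trees and two classes, a tie cannot occur; with an even number, which.max() would simply pick the first class in alphabetical order.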

Step 1) Import the data

To make sure you have the same dataset as in the decision tree tutorial, the train set and the test set are stored online. You can import them without making any changes.

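library(dplyr)
# stringsAsFactors = TRUE restores the pre-R 4.0 default: caret's train() and
# confusionMatrix() need the `survived` outcome to be a factor for classification
data_train <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/train.csv",
                       stringsAsFactors = TRUE)
glimpse(data_train)
data_test <- read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/test.csv",
                      stringsAsFactors = TRUE)
glimpse(data_test)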

Step 2) Train the model

One way to evaluate the performance of a model is to train it on several smaller datasets and evaluate it on the remaining, smaller test set. This is called K-fold cross-validation. R has functions that randomly split the data into a number of folds of almost the same size. For example, with k = 10, the model is trained on nine folds and tested on the remaining one. This process is repeated until every fold has served once as the test set. This technique is widely used for model selection, especially when the model has parameters to tune.
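The fold logic can be sketched in a few lines of base R. The sizes below (20 rows, 5 folds) are hypothetical and no model is fitted; the loop only shows which rows each round trains and tests on. In the tutorial, caret does all of this when trainControl(method = "cv") is used, and caret's createFolds() builds similar fold indices.

```r
set.seed(1234)
n <- 20   # hypothetical number of observations
k <- 5    # number of folds

# Assign each observation to one of k folds of (almost) equal size
fold_id <- sample(rep(1:k, length.out = n))

for (i in 1:k) {
  test_idx  <- which(fold_id == i)   # the held-out fold
  train_idx <- which(fold_id != i)   # the remaining k - 1 folds
  cat(sprintf("Fold %d: train on %d rows, test on %d rows\n",
              i, length(train_idx), length(test_idx)))
}
```

Each of the 5 rounds trains on 16 rows and tests on the other 4, and every row is tested exactly once.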

Now that we have a way to evaluate our model, we need to figure out how to choose the parameters that generalize best to the data.

A random forest builds many decision trees, and at each split of each tree it considers only a random subset of the features. The model then aggregates the predictions of all the trees (majority vote for classification, average for regression).
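A minimal base-R sketch of the two sources of randomness, assuming a hypothetical training set of 10 rows and the 7 predictors of the Titanic data: each tree is grown on a bootstrap sample of the rows, and each split only considers mtry randomly chosen predictors. The rows never drawn for a tree are its out-of-bag (OOB) rows, which randomForest() uses to estimate its OOB error.

```r
set.seed(42)
n <- 10                 # hypothetical number of training rows
p <- 7                  # predictors: pclass, sex, age, sibsp, parch, fare, embarked
mtry <- floor(sqrt(p))  # classification default: 2

# One tree's training data: n row indices drawn with replacement
boot_rows <- sample(n, n, replace = TRUE)

# Out-of-bag rows: never drawn for this tree
oob_rows <- setdiff(seq_len(n), boot_rows)

# Candidate predictors for one split: mtry of the p columns, without replacement
split_candidates <- sample(p, mtry)

boot_rows
oob_rows
split_candidates
```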

Random forest has several parameters that can be changed to improve how well the predictions generalize. You will use the randomForest() function to train the model.

The syntax for Random Forest is:

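randomForest(formula, data, ntree = n, mtry = m, maxnodes = NULL, importance = FALSE)
Arguments:
- formula: formula of the model to fit
- data: data frame containing the variables used in the formula
- ntree: number of trees in the forest (500 by default)
- mtry: number of variables randomly sampled as candidates at each split. By default, it is the square root of the number of predictors for classification, and one third of them for regression.
- maxnodes: maximum number of terminal nodes each tree can have (NULL means the trees are grown as large as possible)
- importance: whether the importance of the predictors should be assessed (set importance = TRUE to compute it)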
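The mtry default can be checked with a small helper. default_mtry() is a hypothetical function that only restates the rule from the randomForest documentation; the guarded call at the end fits a forest on R's built-in iris data (4 predictors, so mtry should be 2) and runs only if the randomForest package is installed.

```r
# Hypothetical helper restating randomForest()'s documented mtry default:
# floor(sqrt(p)) for classification, max(floor(p / 3), 1) for regression
default_mtry <- function(p, classification = TRUE) {
  if (classification) floor(sqrt(p)) else max(floor(p / 3), 1)
}

default_mtry(7)                          # 7 predictors, classification: 2
default_mtry(10)                         # 10 predictors, classification: 3
default_mtry(7, classification = FALSE)  # 7 predictors, regression: 2

# Optional check against the package itself, on the built-in iris data
if (requireNamespace("randomForest", quietly = TRUE)) {
  set.seed(1234)
  fit <- randomForest::randomForest(Species ~ ., data = iris,
                                    ntree = 100, importance = TRUE)
  print(fit$mtry)  # 4 predictors, so floor(sqrt(4)) = 2
}
```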

Note: Random forest can be trained with many more parameters. You can refer to the package documentation (?randomForest) to see them.

Tuning a model is very tedious work. There are many possible combinations of parameters, and you do not necessarily have time to try them all. A good alternative is to let the machine find the best combination for you. Two methods are available:

  • Random search
  • Grid search

We will define both methods, but throughout the tutorial we will train the model using grid search.

Grid search definition

The grid search method is simple: the model is evaluated, with cross-validation, on every combination of the parameter values you pass to the function.

For example, suppose you want to try the model with 10, 20 and 30 trees, and each number of trees is tested with an mtry equal to 1, 2, 3, 4 and 5. The machine will then test 15 different models:

    .mtry ntrees
 1      1     10
 2      2     10
 3      3     10
 4      4     10
 5      5     10
 6      1     20
 7      2     20
 8      3     20
 9      4     20
 10     5     20
 11     1     30
 12     2     30
 13     3     30
 14     4     30
 15     5     30	

The algorithm will evaluate:

randomForest(formula, ntree = 10, mtry = 1)
randomForest(formula, ntree = 10, mtry = 2)
randomForest(formula, ntree = 10, mtry = 3)
randomForest(formula, ntree = 10, mtry = 4)
...
randomForest(formula, ntree = 30, mtry = 5)
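The 15 combinations in the table can be generated with base R's expand.grid(), which is also the usual way to build a tuneGrid for caret. The second grid adds a third parameter (11 values of maxnodes, chosen here only for illustration) to show how quickly the count grows. Note that caret's "rf" method only tunes mtry through tuneGrid, which is why this tutorial later loops over maxnodes and ntree by hand.

```r
# Every mtry value (1 to 5) crossed with every number of trees (10, 20, 30).
# expand.grid() varies the first column fastest, matching the table order.
grid <- expand.grid(.mtry = 1:5, ntrees = c(10, 20, 30))
nrow(grid)     # 5 x 3 = 15 combinations
head(grid, 6)  # row 6 is the first one with 20 trees

# A third parameter multiplies the count: 15 x 11 = 165 models to cross-validate
grid3 <- expand.grid(.mtry = 1:5, ntrees = c(10, 20, 30), maxnodes = 5:15)
nrow(grid3)
```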

Each time, the random forest is evaluated with cross-validation. One drawback of grid search is the number of experiments: it can explode very quickly when the number of combinations is large. To overcome this problem, you can use random search.

Random search definition

The big difference between random search and grid search is that random search does not evaluate every combination of hyperparameters in the search space. Instead, it picks a random combination at each iteration. The advantage is a lower computational cost.
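A minimal sketch of random search over the same 15-combination space: instead of evaluating every row, a fixed budget of distinct rows is drawn at random. The budget of 5 is a hypothetical choice. In caret, random search is requested with trainControl(search = "random"), and the number of candidates is set with the tuneLength argument of train().

```r
set.seed(1234)
space <- expand.grid(.mtry = 1:5, ntrees = c(10, 20, 30))

n_iter <- 5  # hypothetical budget: evaluate only 5 of the 15 combinations

# Draw n_iter distinct rows of the search space
picked <- space[sample(nrow(space), n_iter), ]
picked
```

Only the 5 picked combinations would then be cross-validated, a third of the cost of the full grid.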

Set the control parameter

You will proceed as follows to build and evaluate the model:

  • Evaluate the model with the default settings
  • Find the best value of mtry
  • Find the best value of maxnodes
  • Find the best number of trees (ntree)
  • Evaluate the model on the test dataset

Before you start exploring the parameters, you need to install two libraries.

  • caret: R machine learning library. If you installed R with the conda r-essentials bundle, it is already in your library.
  • e1071: R machine learning library.

You can load them together with randomForest:

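# Run once if the packages are not installed yet:
# install.packages(c("randomForest", "caret", "e1071"))
library(randomForest)
library(caret)
library(e1071)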

Default setting

K-fold cross-validation is controlled by the trainControl() function.

trainControl(method = "cv", number = n, search = "grid")
Arguments:
- method = "cv": the method used to resample the dataset (here, k-fold cross-validation)
- number = n: number of folds to create
- search = "grid": use the grid search method. For random search, use search = "random"
Note: You can refer to the caret documentation to see the other arguments of the function.

You can try running the model with the default parameters and look at the accuracy score.

Note: You will use the same control settings throughout the tutorial.

# Define the control
trControl <- trainControl(method = "cv",
    number = 10,
    search = "grid")

You will use the caret library to evaluate your model. The library has a function called train() that can evaluate almost any machine learning algorithm. In other words, you can use this function to train other algorithms as well.

The basic syntax is:

train(formula, df, method = "rf", metric = "Accuracy", trControl = trainControl(), tuneGrid = NULL)
Arguments:
- `formula`: formula of the model
- `df`: data frame used for training
- `method`: which model to train. The appendix at the end of the tutorial lists all the models that can be trained
- `metric = "Accuracy"`: the metric used to select the optimal model
- `trControl = trainControl()`: the control parameters
- `tuneGrid = NULL`: a data frame with the combinations of tuning parameters to evaluate; when NULL, caret builds a default grid

Let's try to build the model with the default values.

set.seed(1234)
# Run the model
rf_default <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    trControl = trControl)
# Print the results
print(rf_default)

Code explanation

  • trainControl(method = "cv", number = 10, search = "grid"): evaluate the model with a grid search and 10-fold cross-validation
  • train(…): train a random forest model. The best model is chosen using accuracy.

Output:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    2    0.7919248  0.5536486
##    6    0.7811245  0.5391611
##   10    0.7572002  0.4939620
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 2.

The algorithm uses 500 trees (the randomForest() default) and tested three different values of mtry: 2, 6 and 10.

The final value used for the model was mtry = 2, with an accuracy of about 0.79. Let's try to get a higher score.
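The three default values are not arbitrary. For "rf", caret appears to space its three default candidates evenly between 2 and the number of predictor columns after the factors are dummy-encoded by the formula, which seems to be 10 here (the 10 terms listed by varImp() at the end of the tutorial) rather than the 7 raw predictors. A one-line check of that assumption:

```r
# Assumption: 10 predictor columns after dummy encoding, 3 default candidates
p_encoded <- 10
default_grid <- floor(seq(2, p_encoded, length.out = 3))
default_grid  # 2 6 10
```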

Step 2) Search for the best mtry

You can test the model with mtry values from 1 to 10:

set.seed(1234)
tuneGrid <- expand.grid(.mtry = c(1: 10))
rf_mtry <- train(survived~.,
    data = data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 300)
print(rf_mtry)

Code explanation

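  • tuneGrid <- expand.grid(.mtry = c(1:10)): build a grid with mtry values from 1 to 10
  • importance = TRUE, nodesize = 14, ntree = 300: extra arguments that train() passes on to randomForest(); nodesize is the minimum size of the terminal nodes, and 300 trees are grown instead of the default 500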

The final value used for the model was mtry = 4.

Output:

## Random Forest 
## 
## 836 samples
##   7 predictor
##   2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 753, 752, 753, 752, 752, 752, ... 
## Resampling results across tuning parameters:
## 
##   mtry  Accuracy   Kappa    
##    1    0.7572576  0.4647368
##    2    0.7979346  0.5662364
##    3    0.8075158  0.5884815
##    4    0.8110729  0.5970664
##    5    0.8074727  0.5900030
##    6    0.8099111  0.5949342
##    7    0.8050918  0.5866415
##    8    0.8050918  0.5855399
##    9    0.8050631  0.5855035
##   10    0.7978916  0.5707336
## 
## Accuracy was used to select the optimal model using  the largest value.
## The final value used for the model was mtry = 4.

The best value of mtry is stored in:

rf_mtry$bestTune$mtry

You can store it and reuse it when you tune the other parameters. The best cross-validated accuracy is:

max(rf_mtry$results$Accuracy)

Output:

## [1] 0.8110729
best_mtry <- rf_mtry$bestTune$mtry 
best_mtry

Output:

## [1] 4

Step 3) Search for the best maxnodes

You need to create a loop to evaluate the different values of maxnodes. In the following code, you will:

  • Create a list
  • Create a variable with the best value of the mtry parameter (mandatory)
  • Create the loop
  • Store the current value of maxnodes
  • Summarize the results
store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(5: 15)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    current_iteration <- toString(maxnodes)
    store_maxnode[[current_iteration]] <- rf_maxnode
}
results_mtry <- resamples(store_maxnode)
summary(results_mtry)

Code explanation:

  • store_maxnode <- list(): the model results will be stored in this list
  • expand.grid(.mtry = best_mtry): use the best value of mtry
  • for (maxnodes in c(5:15)) { … }: compute the model with maxnodes values from 5 to 15
  • set.seed(1234): reset the seed inside the loop so that every model is evaluated on the same 10 folds, which keeps the resamples() comparison fair
  • maxnodes = maxnodes: at each iteration, maxnodes equals the current value of the loop variable, i.e. 5, 6, 7, …
  • current_iteration <- toString(maxnodes): store the value of maxnodes as a string
  • store_maxnode[[current_iteration]] <- rf_maxnode: save the model in the list under that name
  • resamples(store_maxnode): collect the cross-validation results of all the models
  • summary(results_mtry): print the summary of all combinations

Output:

## 
## Call:
## summary.resamples(object = results_mtry)
## 
## Models: 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.6785714 0.7529762 0.7903758 0.7799771 0.8168388 0.8433735    0
## 6  0.6904762 0.7648810 0.7784710 0.7811962 0.8125000 0.8313253    0
## 7  0.6904762 0.7619048 0.7738095 0.7788009 0.8102410 0.8333333    0
## 8  0.6904762 0.7627295 0.7844234 0.7847820 0.8184524 0.8433735    0
## 9  0.7261905 0.7747418 0.8083764 0.7955250 0.8258749 0.8333333    0
## 10 0.6904762 0.7837780 0.7904475 0.7895869 0.8214286 0.8433735    0
## 11 0.7023810 0.7791523 0.8024240 0.7943775 0.8184524 0.8433735    0
## 12 0.7380952 0.7910929 0.8144005 0.8051205 0.8288511 0.8452381    0
## 13 0.7142857 0.8005952 0.8192771 0.8075158 0.8403614 0.8452381    0
## 14 0.7380952 0.7941050 0.8203528 0.8098967 0.8403614 0.8452381    0
## 15 0.7142857 0.8000215 0.8203528 0.8075301 0.8378873 0.8554217    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 5  0.3297872 0.4640436 0.5459706 0.5270773 0.6068751 0.6717371    0
## 6  0.3576471 0.4981484 0.5248805 0.5366310 0.6031287 0.6480921    0
## 7  0.3576471 0.4927448 0.5192771 0.5297159 0.5996437 0.6508314    0
## 8  0.3576471 0.4848320 0.5408159 0.5427127 0.6200253 0.6717371    0
## 9  0.4236277 0.5074421 0.5859472 0.5601687 0.6228626 0.6480921    0
## 10 0.3576471 0.5255698 0.5527057 0.5497490 0.6204819 0.6717371    0
## 11 0.3794326 0.5235007 0.5783191 0.5600467 0.6126720 0.6717371    0
## 12 0.4460432 0.5480930 0.5999072 0.5808134 0.6296780 0.6717371    0
## 13 0.4014252 0.5725752 0.6087279 0.5875305 0.6576219 0.6678832    0
## 14 0.4460432 0.5585005 0.6117973 0.5911995 0.6590982 0.6717371    0
## 15 0.4014252 0.5689401 0.6117973 0.5867010 0.6507194 0.6955990    0

The highest mean accuracy is reached at the upper end of the range (maxnodes = 14, with 15 close behind). You can try larger values to see if you can get a higher score.

store_maxnode <- list()
tuneGrid <- expand.grid(.mtry = best_mtry)
for (maxnodes in c(20: 30)) {
    set.seed(1234)
    rf_maxnode <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = maxnodes,
        ntree = 300)
    key <- toString(maxnodes)
    store_maxnode[[key]] <- rf_maxnode
}
results_node <- resamples(store_maxnode)
summary(results_node)

Output:

## 
## Call:
## summary.resamples(object = results_node)
## 
## Models: 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30 
## Number of resamples: 10 
## 
## Accuracy 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.7142857 0.7821644 0.8144005 0.8075301 0.8447719 0.8571429    0
## 21 0.7142857 0.8000215 0.8144005 0.8075014 0.8403614 0.8571429    0
## 22 0.7023810 0.7941050 0.8263769 0.8099254 0.8328313 0.8690476    0
## 23 0.7023810 0.7941050 0.8263769 0.8111302 0.8447719 0.8571429    0
## 24 0.7142857 0.7946429 0.8313253 0.8135112 0.8417599 0.8690476    0
## 25 0.7142857 0.7916667 0.8313253 0.8099398 0.8408635 0.8690476    0
## 26 0.7142857 0.7941050 0.8203528 0.8123207 0.8528758 0.8571429    0
## 27 0.7023810 0.8060456 0.8313253 0.8135112 0.8333333 0.8690476    0
## 28 0.7261905 0.7941050 0.8203528 0.8111015 0.8328313 0.8690476    0
## 29 0.7142857 0.7910929 0.8313253 0.8087063 0.8333333 0.8571429    0
## 30 0.6785714 0.7910929 0.8263769 0.8063253 0.8403614 0.8690476    0
## 
## Kappa 
##         Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 20 0.3956835 0.5316120 0.5961830 0.5854366 0.6661120 0.6955990    0
## 21 0.3956835 0.5699332 0.5960343 0.5853247 0.6590982 0.6919315    0
## 22 0.3735084 0.5560661 0.6221836 0.5914492 0.6422128 0.7189781    0
## 23 0.3735084 0.5594228 0.6228827 0.5939786 0.6657372 0.6955990    0
## 24 0.3956835 0.5600352 0.6337821 0.5992188 0.6604703 0.7189781    0
## 25 0.3956835 0.5530760 0.6354875 0.5912239 0.6554912 0.7189781    0
## 26 0.3956835 0.5589331 0.6136074 0.5969142 0.6822128 0.6955990    0
## 27 0.3735084 0.5852459 0.6368425 0.5998148 0.6426088 0.7189781    0
## 28 0.4290780 0.5589331 0.6154905 0.5946859 0.6356141 0.7189781    0
## 29 0.4070588 0.5534173 0.6337821 0.5901173 0.6423101 0.6919315    0
## 30 0.3297872 0.5534173 0.6202632 0.5843432 0.6590982 0.7189781    0

The highest mean accuracy (0.8135) is obtained with maxnodes equal to 24, tied with 27. The next step uses maxnodes = 24.
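Reading the winner off a summary table by eye is error-prone. The base-R check below copies the Mean column of the Accuracy table above and lists every maxnodes value that reaches the maximum. With caret loaded, the same matrix should be available as summary(results_node)$statistics$Accuracy, so the hard-coded vector could be replaced by its "Mean" column.

```r
# Mean cross-validated accuracy per maxnodes value, copied from the output above
mean_acc <- setNames(
  c(0.8075301, 0.8075014, 0.8099254, 0.8111302, 0.8135112, 0.8099398,
    0.8123207, 0.8135112, 0.8111015, 0.8087063, 0.8063253),
  20:30
)

names(mean_acc)[mean_acc == max(mean_acc)]  # "24" "27": a tie
names(which.max(mean_acc))                  # which.max() keeps the first: "24"
```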

Step 4) Search for the best ntree

Now that you have the best values for mtry and maxnodes, you can tune the number of trees. The method is exactly the same as for maxnodes.

store_maxtrees <- list()
for (ntree in c(250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000)) {
    set.seed(5678)
    rf_maxtrees <- train(survived~.,
        data = data_train,
        method = "rf",
        metric = "Accuracy",
        tuneGrid = tuneGrid,
        trControl = trControl,
        importance = TRUE,
        nodesize = 14,
        maxnodes = 24,
        ntree = ntree)
    key <- toString(ntree)
    store_maxtrees[[key]] <- rf_maxtrees
}
results_tree <- resamples(store_maxtrees)
summary(results_tree)

Output:

## 
## Call:
## summary.resamples(object = results_tree)
## 
## Models: 250, 300, 350, 400, 450, 500, 550, 600, 800, 1000, 2000 
## Number of resamples: 10 
## 
## Accuracy 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.7380952 0.7976190 0.8083764 0.8087010 0.8292683 0.8674699    0
## 300  0.7500000 0.7886905 0.8024240 0.8027199 0.8203397 0.8452381    0
## 350  0.7500000 0.7886905 0.8024240 0.8027056 0.8277623 0.8452381    0
## 400  0.7500000 0.7886905 0.8083764 0.8051009 0.8292683 0.8452381    0
## 450  0.7500000 0.7886905 0.8024240 0.8039104 0.8292683 0.8452381    0
## 500  0.7619048 0.7886905 0.8024240 0.8062914 0.8292683 0.8571429    0
## 550  0.7619048 0.7886905 0.8083764 0.8099062 0.8323171 0.8571429    0
## 600  0.7619048 0.7886905 0.8083764 0.8099205 0.8323171 0.8674699    0
## 800  0.7619048 0.7976190 0.8083764 0.8110820 0.8292683 0.8674699    0
## 1000 0.7619048 0.7976190 0.8121510 0.8086723 0.8303571 0.8452381    0
## 2000 0.7619048 0.7886905 0.8121510 0.8086723 0.8333333 0.8452381    0
## 
## Kappa 
##           Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## 250  0.4061697 0.5667400 0.5836013 0.5856103 0.6335363 0.7196807    0
## 300  0.4302326 0.5449376 0.5780349 0.5723307 0.6130767 0.6710843    0
## 350  0.4302326 0.5449376 0.5780349 0.5723185 0.6291592 0.6710843    0
## 400  0.4302326 0.5482030 0.5836013 0.5774782 0.6335363 0.6710843    0
## 450  0.4302326 0.5449376 0.5780349 0.5750587 0.6335363 0.6710843    0
## 500  0.4601542 0.5449376 0.5780349 0.5804340 0.6335363 0.6949153    0
## 550  0.4601542 0.5482030 0.5857118 0.5884507 0.6396872 0.6949153    0
## 600  0.4601542 0.5482030 0.5857118 0.5884374 0.6396872 0.7196807    0
## 800  0.4601542 0.5667400 0.5836013 0.5910088 0.6335363 0.7196807    0
## 1000 0.4601542 0.5667400 0.5961590 0.5857446 0.6343666 0.6678832    0
## 2000 0.4601542 0.5482030 0.5961590 0.5862151 0.6440678 0.6656337    0

You now have your final model. The highest mean accuracy (0.8111) is reached with 800 trees, so you can train the random forest with the following parameters:

  • ntree = 800: 800 trees will be grown
  • mtry = 4: 4 candidate features are sampled at each split
  • maxnodes = 24: at most 24 terminal nodes (leaves) per tree
fit_rf <- train(survived~.,
    data_train,
    method = "rf",
    metric = "Accuracy",
    tuneGrid = tuneGrid,
    trControl = trControl,
    importance = TRUE,
    nodesize = 14,
    ntree = 800,
    maxnodes = 24)

Step 5) Evaluate the model

The caret library provides a predict() method to make predictions with a trained model.

predict(model, newdata = df)
Arguments:
- `model`: the model trained before
- `newdata`: the dataset on which to make predictions
prediction <- predict(fit_rf, data_test)

You can use the predictions to compute the confusion matrix and see the accuracy score:

confusionMatrix(prediction, data_test$survived)

Output:

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  32
##        Yes  11  56
##                                          
##                Accuracy : 0.7943         
##                  95% CI : (0.733, 0.8469)
##     No Information Rate : 0.5789         
##     P-Value [Acc > NIR] : 3.959e-11      
##                                          
##                   Kappa : 0.5638         
##  Mcnemar's Test P-Value : 0.002289       
##                                          
##             Sensitivity : 0.9091         
##             Specificity : 0.6364         
##          Pos Pred Value : 0.7746         
##          Neg Pred Value : 0.8358         
##              Prevalence : 0.5789         
##          Detection Rate : 0.5263         
##    Detection Prevalence : 0.6794         
##       Balanced Accuracy : 0.7727         
##                                          
##        'Positive' Class : No             
## 

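The accuracy on the test set is 0.7943 (79.43%), slightly higher than the cross-validated accuracy of the default model (0.7919). Keep in mind that the two figures are computed on different data, so the comparison is only indicative.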
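The main statistics printed by confusionMatrix() can be recomputed by hand from the four counts in the table above, which makes their meaning explicit. The counts are copied from the output; "No" is the positive class, as caret reports.

```r
# Confusion matrix from the output above (rows = predictions, columns = truth)
cm <- matrix(c(110, 11, 32, 56), nrow = 2,
             dimnames = list(Prediction = c("No", "Yes"),
                             Reference  = c("No", "Yes")))

n <- sum(cm)                                        # 209 test passengers
accuracy    <- sum(diag(cm)) / n                    # (110 + 56) / 209
sensitivity <- cm["No", "No"] / sum(cm[, "No"])     # 110 / 121 (positive = "No")
specificity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])  # 56 / 88
balanced    <- (sensitivity + specificity) / 2

# Cohen's kappa: observed agreement corrected for agreement expected by chance
p_expected <- sum(rowSums(cm) * colSums(cm)) / n^2
kappa <- (accuracy - p_expected) / (1 - p_expected)

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, balanced = balanced, kappa = kappa), 4)
# accuracy 0.7943, sensitivity 0.9091, specificity 0.6364,
# balanced 0.7727, kappa 0.5638
```

The rounded values match the caret output, and the gap between sensitivity and specificity shows that the model is much better at recognizing passengers who did not survive.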

Step 6) Visualize the result

Finally, you can look at feature importance with the varImp() function. The most important feature appears to be sex, far ahead of age and passenger class. This is not surprising: important features tend to appear closer to the root of the trees, while less important features often appear close to the leaves.

# varImpPlot() expects a randomForest object, so pass the final model stored by caret
varImpPlot(fit_rf$finalModel)

varImp(fit_rf)

Output:

## rf variable importance
## 
##              Importance
## sexmale         100.000
## age              28.014
## pclassMiddle     27.016
## fare             21.557
## pclassUpper      16.324
## sibsp            11.246
## parch             5.522
## embarkedC         4.908
## embarkedQ         1.420
## embarkedS         0.000		
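The importance column printed by varImp() is rescaled: by default (scale = TRUE), caret shifts and stretches the raw scores so that the top predictor gets 100 and the bottom one 0, which is why sexmale shows 100.000 and embarkedS 0.000. A sketch of that min-max rescaling, using hypothetical raw scores rather than the tutorial's actual values:

```r
# Hypothetical raw importance scores (not taken from fit_rf)
raw <- c(sexmale = 40.2, age = 15.1, fare = 12.3, embarkedS = 3.5)

# Min-max rescaling to a 0-100 range
scaled <- (raw - min(raw)) / (max(raw) - min(raw)) * 100
round(scaled, 3)
```

Because of this rescaling, a score of 0 means "least important among the predictors", not "useless".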

Summary

We can summarize how to train and evaluate a random forest with the table below:

Library        Objective                          Function           Parameters
randomForest   Create a random forest             randomForest()     formula, ntree = n, mtry = m, maxnodes = NULL
caret          Create K-fold cross-validation     trainControl()     method = "cv", number = n, search = "grid"
caret          Train a random forest              train()            formula, df, method = "rf", metric = "Accuracy", trControl = trainControl(), tuneGrid = NULL
caret          Predict out of sample              predict()          model, newdata = df
caret          Confusion matrix and statistics    confusionMatrix()  predictions, y test
caret          Variable importance                varImp()           model

Appendix

List of models available in caret

names(getModelInfo())

Output:

##   [1] "ada"                 "AdaBag"              "AdaBoost.M1"
##   [4] "adaboost"            "amdai"               "ANFIS"
##   [7] "avNNet"              "awnb"                "awtan"
##  [10] "bag"                 "bagEarth"            "bagEarthGCV"
##  [13] "bagFDA"              "bagFDAGCV"           "bam"
##  [16] "bartMachine"         "bayesglm"            "binda"
##  [19] "blackboost"          "blasso"              "blassoAveraged"
##  [22] "bridge"              "brnn"                "BstLm"
##  [25] "bstSm"               "bstTree"             "C5.0"
##  [28] "C5.0Cost"            "C5.0Rules"           "C5.0Tree"
##  [31] "cforest"             "chaid"               "CSimca"
##  [34] "ctree"               "ctree2"              "cubist"
##  [37] "dda"                 "deepboost"           "DENFIS"
##  [40] "dnn"                 "dwdLinear"           "dwdPoly"
##  [43] "dwdRadial"           "earth"               "elm"
##  [46] "enet"                "evtree"              "extraTrees"
##  [49] "fda"                 "FH.GBML"             "FIR.DM"
##  [52] "foba"                "FRBCS.CHI"           "FRBCS.W"
##  [55] "FS.HGD"              "gam"                 "gamboost"
##  [58] "gamLoess"            "gamSpline"           "gaussprLinear"
##  [61] "gaussprPoly"         "gaussprRadial"       "gbm_h3o"
##  [64] "gbm"                 "gcvEarth"            "GFS.FR.MOGUL"
##  [67] "GFS.GCCL"            "GFS.LT.RS"           "GFS.THRIFT"
##  [70] "glm.nb"              "glm"                 "glmboost"
##  [73] "glmnet_h3o"          "glmnet"              "glmStepAIC"
##  [76] "gpls"                "hda"                 "hdda"
##  [79] "hdrda"               "HYFIS"               "icr"
##  [82] "J48"                 "JRip"                "kernelpls"
##  [85] "kknn"                "knn"                 "krlsPoly"
##  [88] "krlsRadial"          "lars"                "lars2"
##  [91] "lasso"               "lda"                 "lda2"
##  [94] "leapBackward"        "leapForward"         "leapSeq"
##  [97] "Linda"               "lm"                  "lmStepAIC"
## [100] "LMT"                 "loclda"              "logicBag"
## [103] "LogitBoost"          "logreg"              "lssvmLinear"
## [106] "lssvmPoly"           "lssvmRadial"         "lvq"
## [109] "M5"                  "M5Rules"             "manb"
## [112] "mda"                 "Mlda"                "mlp"
## [115] "mlpKerasDecay"       "mlpKerasDecayCost"   "mlpKerasDropout"
## [118] "mlpKerasDropoutCost" "mlpML"               "mlpSGD"
## [121] "mlpWeightDecay"      "mlpWeightDecayML"    "monmlp"
## [124] "msaenet"             "multinom"            "mxnet"
## [127] "mxnetAdam"           "naive_bayes"         "nb"
## [130] "nbDiscrete"          "nbSearch"            "neuralnet"
## [133] "nnet"                "nnls"                "nodeHarvest"
## [136] "null"                "OneR"                "ordinalNet"
## [139] "ORFlog"              "ORFpls"              "ORFridge"
## [142] "ORFsvm"              "ownn"                "pam"
## [145] "parRF"               "PART"                "partDSA"
## [148] "pcaNNet"             "pcr"                 "pda"
## [151] "pda2"                "penalized"           "PenalizedLDA"
## [154] "plr"                 "pls"                 "plsRglm"
## [157] "polr"                "ppr"                 "PRIM"
## [160] "protoclass"          "pythonKnnReg"        "qda"
## [163] "QdaCov"              "qrf"                 "qrnn"
## [166] "randomGLM"           "ranger"              "rbf"
## [169] "rbfDDA"              "Rborist"             "rda"
## [172] "regLogistic"         "relaxo"              "rf"
## [175] "rFerns"              "RFlda"               "rfRules"
## [178] "ridge"               "rlda"                "rlm"
## [181] "rmda"                "rocc"                "rotationForest"
## [184] "rotationForestCp"    "rpart"               "rpart1SE"
## [187] "rpart2"              "rpartCost"           "rpartScore"
## [190] "rqlasso"             "rqnc"                "RRF"
## [193] "RRFglobal"           "rrlda"               "RSimca"
## [196] "rvmLinear"           "rvmPoly"             "rvmRadial"
## [199] "SBC"                 "sda"                 "sdwd"
## [202] "simpls"              "SLAVE"               "slda"
## [205] "smda"                "snn"                 "sparseLDA"
## [208] "spikeslab"           "spls"                "stepLDA"
## [211] "stepQDA"             "superpc"             "svmBoundrangeString"
## [214] "svmExpoString"       "svmLinear"           "svmLinear2"
## [217] "svmLinear3"          "svmLinearWeights"    "svmLinearWeights2"
## [220] "svmPoly"             "svmRadial"           "svmRadialCost"
## [223] "svmRadialSigma"      "svmRadialWeights"    "svmSpectrumString"
## [226] "tan"                 "tanSearch"           "treebag"
## [229] "vbmpRadial"          "vglmAdjCat"          "vglmContRatio"
## [232] "vglmCumulative"      "widekernelpls"       "WM"
## [235] "wsrf"                "xgbLinear"           "xgbTree"
## [238] "xyf"
