데이터 분석/R

[R] Naive Bayes Classification (나이브 베이즈 분류)

eunki 2021. 7. 2. 18:35
728x90

데이터 불러오기

rawdata <- read.csv("wine.csv", header = TRUE)
rawdata$Class <- as.factor(rawdata$Class)
str(rawdata)

 

 


트레이닝-테스트 셋 분리 (7:3)

analdata <- rawdata

set.seed(2020)
datatotal <- sort(sample(nrow(analdata), nrow(analdata)*.7))
train <- rawdata[datatotal,]
test <- rawdata[-datatotal,]

train_x <- train[,1:13]
train_y <- train[,14]

test_x <- test[,1:13]
test_y <- test[,14]

 



학습

ctrl <- trainControl(method = "repeatedcv", repeats = 5) 
nbFit <- train(Class~., 
               data = train, 
               method = "naive_bayes", 
               trControl = ctrl, 
               preProcess = c("center", "scale"), 
               metric = "Accuracy") 

nbFit

 

→ kernel을 사용하지 않을 때, 더 높은 정확도를 가진다.

 

 

 

plot(nbFit)

 



예측

pred_test <- predict(nbFit, newdata = test) 
confusionMatrix(pred_test, test$Class)

 

→ Accuracy : 0.944, Kappa : 0.913

 

 

변수중요도
ROC 커브의 면적이 넓을수록 중요도는 상승한다.

importance_nb <- varImp(nbFit, scale = FALSE) 
importance_nb

 

 

 

plot(importance_nb) 

728x90