Multiclass logistic regression

In a previous post we covered how to fit a multiclass regression model in R, which yields a coefficient and a corresponding OR for each factor; but the interpretation is considerably more cumbersome than for a binary logistic model.

Today we build on that foundation and use machine-learning methods to interpret the multiclass problem. In the end it comes back to the essence of this kind of classification task: given a set of predictors x, decide which class y belongs to.


1. Data example

We use the HR dataset bundled with the DALEX package, which records each employee's job status together with age, gender, hours worked, evaluation, and salary. From its 7,847 records we want to answer questions such as: if an employee is male, 68 years old, with salary and evaluation at level 3, what status is that employee most likely to end up in?

library(DALEX)          # HR data
library(iBreakDown)     # local_attributions()
library(randomForest)   # randomForest()
library(caret)          # confusionMatrix()
library(vip)            # vip()
library(pdp)            # partial()
library(edarf)          # partial_dependence(), plot_pd()
library(car)
library(questionr)

data(HR, package = "DALEX")

# split into training (90%) and test (10%) sets
set.seed(543)
ind = sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData = HR[ind == 1, ]
testData  = HR[ind == 2, ]

# random forest
m_rf = randomForest(status ~ ., data = trainData)

2. Random forest model

Using the data above, we split into a training set and a test set (Train and Test); the test set is used to estimate how well the random forest generalises.
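The sample()-based 90/10 split used above can be seen in miniature (base R only; the proportions are approximate, not exact):

```r
set.seed(543)
# draw the group label 1 or 2 for each of 10000 "rows",
# with probability 0.9 for group 1 (train) and 0.1 for group 2 (test)
ind <- sample(2, 10000, replace = TRUE, prob = c(0.9, 0.1))

prop <- mean(ind == 1)  # share of rows landing in the training set
prop                    # close to, but not exactly, 0.9
```

Because membership is drawn independently per row, the realised split fluctuates around 90/10; use caret::createDataPartition() if an exact split is needed.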

2.1 Model evaluation

After building the rf model on the Train data, we first check the fit on the training data itself: Accuracy : 0.9354 looks very good, with a Kappa agreement of 0.90. Using the same fit to predict the Test data gives Accuracy : 0.7166 and Kappa : 0.57, which is far less impressive; the training figures are inflated because the model is evaluated on the same data it was fit to.
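Where do these numbers come from? Accuracy and Kappa can be recomputed by hand from the training confusion matrix in the output below (base R only; the matrix entries are copied from the confusionMatrix() printout):

```r
# training confusion matrix (rows = prediction, cols = reference)
cm <- matrix(c(2478,  194,   49,
                 43, 1738,   80,
                 25,   64, 2375),
             nrow = 3, byrow = TRUE,
             dimnames = list(pred = c("fired", "ok", "promoted"),
                             ref  = c("fired", "ok", "promoted")))

n  <- sum(cm)
po <- sum(diag(cm)) / n                     # observed agreement = accuracy
pe <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
kappa <- (po - pe) / (1 - pe)               # Cohen's kappa

round(po, 4)     # 0.9354
round(kappa, 4)  # 0.9024
```

Kappa corrects accuracy for chance agreement, which is why it drops faster than accuracy on the weaker test-set fit.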

# prediction and confusion matrix - training data
pred1 <- predict(m_rf, trainData)
head(pred1)
confusionMatrix(pred1, trainData$status)

# prediction and confusion matrix - test data
pred2 <- predict(m_rf, testData)
head(pred2)
confusionMatrix(pred2, testData$status)

> confusionMatrix(pred1, trainData$status)
Confusion Matrix and Statistics

          Reference
Prediction fired   ok promoted
  fired     2478  194       49
  ok          43 1738       80
  promoted    25   64     2375

Overall Statistics

               Accuracy : 0.9354
                 95% CI : (0.9294, 0.9411)
    No Information Rate : 0.3613
    P-Value [Acc > NIR] : < 2.2e-16

                  Kappa : 0.9024

 Mcnemar's Test P-Value : < 2.2e-16

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.9733    0.8707          0.9485
Specificity                0.9460    0.9756          0.9804
Pos Pred Value             0.9107    0.9339          0.9639
Neg Pred Value             0.9843    0.9502          0.9718
Prevalence                 0.3613    0.2833          0.3554
Detection Rate             0.3517    0.2467          0.3371
Detection Prevalence       0.3862    0.2641          0.3497
Balanced Accuracy          0.9596    0.9232          0.9644

> pred2 <- predict(m_rf, testData)
> head(pred2)
    1    20    36    42    49    56
fired fired fired fired fired    ok
Levels: fired ok promoted

> confusionMatrix(pred2, testData$status)
Confusion Matrix and Statistics

          Reference
Prediction fired  ok promoted
  fired      246  62       19
  ok          37 117       37
  promoted    26  46      211

Overall Statistics

               Accuracy : 0.7166
                 95% CI : (0.684, 0.7476)
    No Information Rate : 0.3858
    P-Value [Acc > NIR] : < 2e-16

                  Kappa : 0.5692

 Mcnemar's Test P-Value : 0.03881

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.7961    0.5200          0.7903
Specificity                0.8354    0.8715          0.8652
Pos Pred Value             0.7523    0.6126          0.7456
Neg Pred Value             0.8671    0.8230          0.8919
Prevalence                 0.3858    0.2809          0.3333
Detection Rate             0.3071    0.1461          0.2634
Detection Prevalence       0.4082    0.2385          0.3533
Balanced Accuracy          0.8157    0.6958          0.8277

2.2 Variable importance

Ranking the predictors by importance tells us which factors carry the most weight in predicting y, playing a role loosely analogous to a P value in a regression table. Importance can be measured in several ways; here we use randomForest's default, MeanDecreaseGini.

# variable importance plot (vip package)
vip(m_rf)

# importance scores (default: MeanDecreaseGini)
var = randomForest::importance(m_rf)
var

(Figure: variable importance plot)

2.3 Marginal effects

We now know that hours and age matter most, but how exactly? For instance, at what values does firing or promotion become likely? When weekly hours are below about 45, the probability of being fired is high; once hours exceed 60, a promotion (and the raise that comes with it) becomes very likely. We can also draw 2D marginal-effect plots, i.e. partial dependence plots for two interacting predictors.

# partial plot for a single variable
partialPlot(m_rf, trainData, age)
head(partial(m_rf, pred.var = "age"))  # pdp: returns a data frame

# for all variables
nm = rownames(var)

# get partial dependence values for the top predictors (edarf)
pd_df <- partial_dependence(fit = m_rf,
                            vars = nm,
                            data = trainData,
                            n = c(100, 200))

# plot partial dependence using edarf
plot_pd(pd_df)
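Under the hood, a partial dependence value at, say, hours = h is just the model's average prediction after setting every row's hours to h. A minimal base-R sketch of that idea (toy_predict and the tiny data frame are made up for illustration, not part of the HR example):

```r
# toy "model": promotion probability rises with weekly hours
toy_predict <- function(df) plogis((df$hours - 50) / 5)

d <- data.frame(hours = c(35, 40, 55, 62), age = c(30, 45, 50, 28))

# partial dependence of the prediction on hours over a small grid
pd <- sapply(c(40, 60), function(h) {
  d2 <- d
  d2$hours <- h           # fix hours for every row
  mean(toy_predict(d2))   # average prediction over the data
})
round(pd, 4)  # 0.1192 0.8808 - low at 40 hours, high at 60
```

Packages like pdp and edarf do exactly this loop over a grid of values, just against the fitted forest instead of a toy function.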

(Figures: partial dependence plots for the predictors)

2.4 Individual prediction

Now suppose we have an employee with the following profile:

      gender      age    hours evaluation salary   status
10000 female 57.96254 54.78624          4      4 promoted

We predict this employee's final status: the model gives her a 97% probability of being promoted, and her actual status is indeed promoted.

new_observation = tail(HR, 1)

# predict class probabilities rather than labels
p_fun <- function(object, newdata) {
  predict(object, newdata = newdata, type = "prob")
}

# break-down attribution of the prediction (iBreakDown)
bd_rf <- local_attributions(m_rf,
                            data = HR,
                            new_observation = new_observation,
                            predict_function = p_fun)
bd_rf
plot(bd_rf)
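The 97% figure can also be read off directly from the forest's class probabilities via predict() with type = "prob". A sketch; for self-containment the forest here is refit on all of HR, so the exact probabilities will differ slightly from the trainData fit above:

```r
library(DALEX)          # provides the HR data
library(randomForest)

set.seed(543)
m_rf <- randomForest(status ~ ., data = HR)
new_observation <- tail(HR, 1)

# class probabilities for the single new observation
# (one column per status level: fired / ok / promoted)
pr <- predict(m_rf, new_observation, type = "prob")
pr
```

The break-down plot then explains how each predictor pushes the "promoted" probability up or down from the average.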

(Figure: break-down plot for the individual prediction)

> sessionInfo()

R version 3.6.2 (2019-12-12)

Platform: x86_64-apple-darwin15.6.0 (64-bit)

Running under: macOS Mojave 10.14

Matrix products: default

BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib

LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:

[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:

[1] stats graphics utils datasets grDevices methods base

other attached packages:

[1] edarf_1.1.1 ranger_0.12.1 questionr_0.7.0 car_3.0-7

[5] carData_3.0-3 nnet_7.3-14 DALEX_1.2.1 vip_0.2.2

[9] ggpubr_0.3.0 rstatix_0.5.0 caret_6.0-86 lattice_0.20-41

[13] pdp_0.7.0 randomForest_4.6-14 iBreakDown_1.2.0 hrbrthemes_0.8.0

[17] reshape2_1.4.4 RColorBrewer_1.1-2 forcats_0.5.0 stringr_1.4.0

[21] dplyr_0.8.5 purrr_0.3.4 readr_1.3.1 tidyr_1.0.3

[25] tibble_3.0.1 ggplot2_3.3.0 tidyverse_1.3.0

References

iBreakDown plots for classification models
prediction: outputting prediction results as probabilities
pdp: marginal effects / Partial dependence (PD) plots for random forests
Explaining Black-Box Machine Learning Models
Interpretable Machine Learning
