决策树简介


决策树是一种基本的分类与回归方法。决策树模型呈树形结构,在分类问题中,表示基于特征对实例进行分类的过程。它可以认为是if-then规则的集合,也可以认为是定义在特征空间与类空间上的条件概率分布。其主要优点是模型具有可读性,分类速度快。学习时,利用训练数据,根据损失函数最小化的原则建立决策树模型。预测时,对新的数据利用决策树模型进行分类。决策树学习通常包括三个步骤:特征选择、决策树的生成和决策树的修剪。

用决策树分类,从根结点开始,对实例的某一特征进行测试,根据测试结果,将实例分配到其子结点;这时,每一个子结点对应着该特征的一个取值。如此递归地对实例进行测试并分配,直至达到叶结点,最后将实例分到叶结点的类中。

简单地说,决策树就是一个类似流程图的树形结构,采用自顶向下的递归方式,从树的根节点开始,在它的内部节点上进行属性值的测试比较,然后按照给定实例的属性值确定对应的分支,最后在决策树的叶子节点得到结论。这个过程在以新的节点为根的子树上重复。直到所有新节点给出的结果一致或足以判断分类。

假设给定训练数据集:
D{\rm{ = }}\left\{ {(x_1 ,y_1 ),(x_2 ,y_2 ),...,(x_N ,y_N )} \right\}

学习的目标是根据给定的训练数据集构建一个决策树模型,使它能够对实例进行正确的分类。决策树学习的算法通常是一个递归地选择最优特征,并根据该特征队训练数据进行分割,使得对各个子集有一个最好的分类的过程。这一过程对应着特征空间的划分和决策树的构建。

生成的决策树可能对训练数据有很好的分类能力,但对未知的测试数据却未必有很好的分类能力,即可能发生过拟合现象。需要对已生成的树自上而下进行剪枝,将树变得更简单,从而使它具有更好的泛化能力。具体地,就是去掉过于细分的叶结点,使其回退到父结点,甚至更高的结点,然后将父结点或更高的结点改为新的叶结点。

汽车特征评估质量
本次使用的是下载的一个包含汽车多个细节的数据集,包括车门数量、后备箱大小、维修成本、安全性能、载人数量等等,来确定一辆汽车的质量。分类的目的是把车辆的质量分为4种类型:不达标、达标、良好、优秀。

数据集的形式如图1所示,其中的每个值都可以看作成字符串。本次考虑数据集中的6个属性,其取值范围是这样的:

购买价位:取值范围是vhigh、 high、 med、 low,分别代表很高、高、中等、低;
维修成本:取值范围是vhigh、 high、 med、 low,分别代表很高、高、中等、低;
车门数量:取值范围是2、 3、 4、 5、5more等;
载客数量:取值范围是2、 4、more等;
动力性能:取值范围是small、 med、 big,分别代表小、中、大;
安全性能:取值范围是low、 med、 high,分别代表低、中、高
分类的结果,即汽车的质量取值范围是unacc、acc、good、vgood,分别代表不达标、达标、良好、优秀。

考虑到每一行都具有字符串属性,需要假设所有的特征均是字符串,并在次基础上建立分类器。

首先将数据集当中的所有字符串变为数字,方便后面的分类,由于下载的数据集为.data格式,matlab无法直接读取,已经转化为.xlsx格式,并且将vhigh、 high、 med、 low分别替换为4、3、2、1,将small、 med、 big替换为1、2、3,将low、 med、 high替换为1、2、3,将unacc、acc、good、vgood替换为1、2、3、4。
数据中共有1728组,随机从中取出1500组作为训练集,剩下的228组作为测试集。使用训练集建立决策树,然后使用模型进行预测。分别根据决策树的结果计算出决策树中对车辆各种情况预测的正确率以及全部测试集预测的准确率。然后对决策树进行修剪,对比起重采样误差以及交叉验证误差。

clear all;
clc;
close all;
 
%% 导入数据
load car;
a = randperm(1728);
%训练集
Train_Data = data(a(1:1500),1:6);
Train_Label = data(a(1:1500),7);
%测试集
Test_Data = data(a(1501:1728),1:6);
Test_Label = data(a(1501:1728),7);
 
%% 创建决策树分类器
Tree = ClassificationTree.fit(Train_Data,Train_Label);
 
%% 查看决策树视图
view(Tree);
view(Tree,'mode','graph');
 
%% 预测分类
Tree_pre = predict(Tree,Test_Data);
 
%% 结果分析
count_train_1 = length(find(Train_Label == 1));  %训练集中车辆质量不达标个数
count_train_2 = length(find(Train_Label == 2));  %训练集中车辆质量达标个数
count_train_3 = length(find(Train_Label == 3));  %训练集中车辆质量良好个数
count_train_4 = length(find(Train_Label == 4));  %训练集中车辆质量优秀个数
 
rate_train_1 = count_train_1 / 1500;             %训练集中车辆质量不达标占的比例
rate_train_2 = count_train_2 / 1500;             %训练集中车辆质量达标占的比例
rate_train_3 = count_train_3 / 1500;             %训练集中车辆质量优良占的比例
rate_train_4 = count_train_4 / 1500;             %训练集中车辆质量优秀占的比例
 
total_1 = length(find(data(:,7) == 1));  %总数据中车辆质量不达标个数
total_2 = length(find(data(:,7) == 2));  %总数据中车辆质量达标个数
total_3 = length(find(data(:,7) == 3));  %总数据中车辆质量优良个数
total_4 = length(find(data(:,7) == 4));  %总数据中车辆质量优秀个数
 
count_test_1 = length(find(Test_Label == 1));  %测试集中车辆质量不达标个数
count_test_2 = length(find(Test_Label == 2));  %测试集中车辆质量达标个数
count_test_3 = length(find(Test_Label == 3));  %测试集中车辆质量良好个数
count_test_4 = length(find(Test_Label == 4));  %测试集中车辆质量优秀个数
 
count_right_1 = length(find(Tree_pre == 1 & Test_Label == 1));  %测试集中预测车辆质量不达标正确的个数
count_right_2 = length(find(Tree_pre == 2 & Test_Label == 2));  %测试集中预测车辆质量达标正确的个数
count_right_3 = length(find(Tree_pre == 3 & Test_Label == 3));  %测试集中预测车辆质量优良正确的个数
count_right_4 = length(find(Tree_pre == 4 & Test_Label == 4));  %测试集中预测车辆质量优秀正确的个数
 
rate_right = (count_right_1+count_right_2+count_right_3+count_right_4)/228;
 
%% 显示部分结果
disp(['车辆总数:1728'...
      '  不达标:' num2str(total_1)...
      '  达标:' num2str(total_2)...
      '  优良:' num2str(total_3)...
      '  优秀:' num2str(total_4)]);
disp(['训练集车辆数:1500'...
      '  不达标:' num2str(count_train_1)...
      '  达标:' num2str(count_train_2)...
      '  优良:' num2str(count_train_3)...
      '  优秀:' num2str(count_train_4)]);
disp(['测试集车辆数:228'...
      '  不达标:' num2str(count_test_1)...
      '  达标:' num2str(count_test_2)...
      '  优良:' num2str(count_test_3)...
      '  优秀:' num2str(count_test_4)]);
disp(['决策树判断结果:'...
      '  不达标正确率:' sprintf('%2.2f%%', count_right_1/count_test_1*100)...
      '  达标正确率:' sprintf('%2.2f%%', count_right_2/count_test_2*100)...
      '  优良正确率:' sprintf('%2.2f%%', count_right_3/count_test_3*100)...
      '  优秀正确率:' sprintf('%2.2f%%', count_right_4/count_test_4*100)]);
  
disp(['总正确率:'... 
    sprintf('%2.2f%%', rate_right*100)]);
  
 %% 优化前决策树的重采样误差和交叉验证误差
resubDefault = resubLoss(Tree);
lossDefault = kfoldLoss(crossval(Tree));
disp(['剪枝前决策树的重采样误差:'... 
    num2str(resubDefault)]);
 disp(['剪枝前决策树的交叉验证误差:'... 
    num2str(lossDefault)]);
 
%% 剪枝
[~,~,~,bestlevel] = cvLoss(Tree,'subtrees','all','treesize','min');
cptree = prune(Tree,'Level',bestlevel);
view(cptree,'mode','graph')
 
%% 剪枝后决策树的重采样误差和交叉验证误差
resubPrune = resubLoss(cptree);
lossPrune = kfoldLoss(crossval(cptree));
disp(['剪枝后决策树的重采样误差:'... 
    num2str(resubPrune)]);
 disp(['剪枝后决策树的交叉验证误差:'... 
    num2str(resubPrune)]);

结果
车辆总数:1728  不达标:1210  达标:384  优良:69  优秀:65

训练集车辆数:1500  不达标:1046  达标:338  优良:56  优秀:60

测试集车辆数:228  不达标:164  达标:46  优良:13  优秀:5

决策树判断结果:  不达标正确率:97.56%  达标正确率:95.65%  优良正确率:84.62%  优秀正确率:100.00%

总正确率:96.49%

剪枝前决策树的重采样误差:0.026

剪枝前决策树的交叉验证误差:0.048667

剪枝后决策树的重采样误差:0.026667

剪枝后决策树的交叉验证误差:0.026667