機器學習之決策樹原理和sklearn實踐

1. 場景描述

時間:早上八點,地點:婚介所

‘閨女,我有給你找了個合適的對象,今天要不要見一面?’

‘多大?’ ‘26歲’

‘長的帥嗎?’ ‘還可以,不算太帥’

‘工資高嗎?’ ‘略高於平均水平’

‘會寫代碼嗎?’ ‘人家是程序員,代碼寫的棒着呢!’

‘好,把他的聯繫方式發過來吧,我抽空見一面’

上面的場景描述摘抄自 ,是一個典型的決策樹分類問題,通過年齡、長相、工資、是否會編程等特徵屬性對介紹對象進行是否約會進行分類

決策樹是一種自上而下,對樣本數據進行樹形分類的過程,由結點和有向邊組成,每個結點(恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘結點除外)便是一個特徵或屬性,恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘結點表示類別。從頂部根結點開始,所有樣本聚在儀器,經過根結點的劃分,樣本被分到不同的子結點中。再根據子結點的特徵進一步劃分,直至樣本都被分到某一類別(恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘子結點)中

2. 決策樹原理

決策樹作為最基礎、最常見的有監督學習模型,常被用於分類問題和回歸問題,將決策樹應用集成思想可以得到隨機森林、梯度提升決策樹等模型。其主要優點是模型具有可讀性,分類速度快。決策樹的學習通常包括三個步驟:特徵選擇、決策樹的生成和決策樹的修剪,下面對特徵選擇算法進行描述和區別

2.1 ID3—最大信息增益

在信息論與概率統計中,熵(entropy)是表示隨機變量不確定性的度量,設X是一個取有限個值的隨機變量,其概率分佈為:\[P(X=X_i)=P_i (i = 1,2,…,n)\],則隨機變量X的熵定義為:\[H(X) = -\sum_{i=1}^np_i\log{p_i}\]表達式中的對數以2為底或以e為底,這時熵的單位分別稱作bit或nat,從表達式可以看出X的熵與X的取值無關,所以X的熵也記作\(H(p)\),即\[H(p) = -\sum_{x=1}^np_i\log{p_i}\]熵取值越大,隨機變量的不確定性越大

條件熵:

條件熵H(Y|X)表示在已知隨機變量X的條件下,隨機變量Y的不確定性,隨機變量X給定的條件下隨機變量Y的條件熵定義為X給定條件下Y的條件概率分佈的熵對X的數學期望\[H(Y|X) = \sum_{i=1}^nP(X=X_i)H(Y|X=X_i)\]

信息增益:\[g(D,A) = H(D) – H(D|A)\]

import pandas as pd
data = {
        '年齡':['老','年輕','年輕','年輕','年輕'],
        '長相':['帥','一般','丑','一般','一般'],
        '工資':['高','中等','高','高','低'],
        '寫代碼':['不會','會','不會','會','不會'],
        '類別':['不見','見','不見','見','不見']}
frame = pd.DataFrame(data,index=['小A','小B','小C','小D','小L'])
print(frame)
    年齡  長相  工資 寫代碼  類別
小A   老   帥   高  不會  不見
小B  年輕  一般  中等   會   見
小C  年輕   丑   高  不會  不見
小D  年輕  一般   高   會   見
小L  年輕  一般   低  不會  不見
import math
print(math.log(3/5))
print('H(D):',-3/5 *math.log(3/5,2) - 2/5*math.log(2/5,2))
print('H(D|年齡)',1/5*math.log(1,2)+4/5*(-1/2*math.log(1/2,2)-1/2*math.log(1/2,2)))
print('以同樣的方法計算H(D|長相),H(D|工資),H(D|寫代碼)')
print('H(D|長相)',0.551)
print('H(D|工資)',0.551)
print('H(D|寫代碼)',0)
-0.5108256237659907
H(D): 0.9709505944546686
H(D|年齡) 0.8
以同樣的方法計算H(D|長相),H(D|工資),H(D|寫代碼)
H(D|長相) 0.551
H(D|工資) 0.551
H(D|寫代碼) 0

計算信息增益:g(D,寫代碼)=0.971最大,可以先按照寫代碼來拆分決策樹

2.2 C4.5—最大信息增益比

以信息增益作為劃分訓練數據集的特徵,存在偏向於選擇取值較多的問題,使用信息增益比可以對對着問題進行校正,這是特徵選擇的另一標準
信息增益比定義為其信息增益g(D,A)與訓練數據集D關於特徵A的值的熵\(H_A(D)\)之比:\[g_R(D,A) = \frac{g(D,A)}{H_A(D)}\]

\[H_A(D) = -\sum_{i=1}^n\frac{|D_i|}{|D|}\log\frac{|D_i|}{|D|}\]

拿上面ID3的例子說明:
\[H_年齡(D) = -1/5*math.log(1/5,2)-4/5*math.log(4/5,2)\]

\[g_R(D,年齡) = H_{年齡}(D)/g(D,年齡) = 0.171/0.722 = 0.236 \]

2.3 CART—-最大基尼指數(Gini)

Gini描述的是數據的純度,與信息熵含義類似,分類問題中,假設有K個類,樣本點數據第k類的概率為\(P_k\),則概率分佈的基尼指數定義為:
\[Gini(p) = 1- \sum_{k=1}^Kp_k(1-p_k) = 1 – \sum_{k=1}^Kp_{k}^2\]
對於二分類問題,弱樣本點屬於第1個類的概率是p,則概率分佈的基尼指數為\[Gini(p) = 2p(1-p)\],對於給定的樣本幾何D,其基尼指數為\[Gini(D) = 1 – \sum_{k=1}^K[\frac{|C_k|}{|D|}]^2\]注意這裏\(C_k\)是D種屬於第k類的樣本子集,K是類的個數,如果樣本幾個D根據特徵A是否取某一可能指a被分割成D1和D2兩部分,則在特徵A的條件下,集合D的基尼指數定義為\[Gini(D,A) = \frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)\]
\[Gini(D|年齡=老)=1/5*(1-1)+4/5*[1-(1/2*1/2+1/2*1/2)] = 0.4\]

CART在每一次迭代種選擇基尼指數最小的特徵及其對應的切分點進行分類

2.4 ID3、C4.5與Gini的區別

2.4.1 從樣本類型角度

從樣本類型角度,ID3隻能處理離散型變量,而C4.5和CART都處理連續性變量,C4.5處理連續性變量時,通過對數據排序之後找到類別不同的分割線作為切割點,根據切分點把連續型數學轉換為bool型,從而將連續型變量轉換多個取值區間的離散型變量。而對於CART,由於其構建時每次都會對特徵進行二值劃分,因此可以很好地適合連續性變量。

2.4.2 從應用角度

ID3和C4.5隻適用於分類任務,而CART既可以用於分類也可以用於回歸

2.4.3 從實現細節、優化等角度

ID3對樣本特徵缺失值比較敏感,而C4.5和CART可以對缺失值進行不同方式的處理,ID3和C4.5可以在每個結點熵產生出多叉分支,且每個特徵在層級之間不會復用,而CART每個結點只會產生兩個分支,因此會形成一顆二叉樹,且每個特徵可以被重複使用;ID3和C4.5通過剪枝來權衡樹的準確性和泛化能力,而CART直接利用全部數據發現所有可能的樹結構進行對比。

3. 決策樹的剪枝

3.1 為什麼要進行剪枝?

對決策樹進行剪枝是為了防止過擬合

根據決策樹生成算法通過訓練數據集生成了複雜的決策樹,導致對於測試數據集出現了過擬合現象,為了解決過擬合,就必須考慮決策樹的複雜度,對決策樹進行剪枝,剪掉一些枝恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘,提升模型的泛化能力

決策樹的剪枝通常由兩種方法,預剪枝和后剪枝

3.2 預剪枝

預剪枝的核心思想是在樹中結點進行擴展之前,先計算當前的劃分是否能帶來模型泛化能力的提升,如果不能,則不再繼續生長子樹。此時可能存在不同類別的樣本同時存於結點中,按照多數投票的原則判斷該結點所屬類別。預剪枝對於何時停止決策樹的生長有以下幾種方法

  • (1)當樹達到一定深度的時候,停止樹的生長
  • (2)當恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘結點數到達某個閾值的時候,停止樹的生長
  • (3)當到達結點的樣本數量少於某個閾值的時候,停止樹的生長
  • (4)計算每次分裂對測試集的準確度提升,當小於某個閾值的時候,不再繼續擴展

預剪枝思想直接,算法簡單,效率高特點,適合解決大規模問題。但如何準確地估計何時停止樹的生長,針對不同問題會有很大差別,需要一定的經驗判斷。且預剪枝存在一定的局限性,有欠擬合的風險

3.3 后剪枝

后剪枝的核心思想是讓算法生成一顆完全生長的決策樹,然後從底層向上計算是否剪枝。剪枝過程將子樹刪除,用一個恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘結點代替,該結點的類別同樣按照多數投票原則進行判斷。同樣地,后剪枝恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘可以通過在測試集上的準確率進行判斷,如果剪枝過後的準確率有所提升,則進行剪枝,后剪枝方法通常可以得到泛化能力更強的決策樹,但時間開銷更大

損失函數

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T|\]

\(其中|T|為恭弘=恭弘=恭弘=叶 恭弘 恭弘 恭弘結點個數,N_t為結點t的樣本個數,H_t(T)為結點t的信息熵,a|T|為懲罰項,a>=0\)

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T| = -\sum_{t=1}^{|T|}\sum_{k=1}^KN_{tk}\log \frac{N_{tk}}{N_t} + a|T|\]

注意:上面的公式中是\(N_{tk}\log \frac{N_{tk}}{N_t}\),而不是\(\frac{N_{tk}}{N_t} \log \frac{N_{tk}}{N_t}\)

令:\[C_a(T) = C(T) + a|T|\]

\(C(T)\)表示模型對訓練數據的預測誤差,即模型與訓練數據的擬合程度,|T|表示模型複雜度,參數a>=0控制兩者的影響力,較大的a促使選擇較簡單的模型,較小的a促使選擇複雜的模型,a=0意味着只考慮模型與訓練數據的擬合程度,不考慮模型的複雜度

4. 使用sklearn庫為衛星數據集訓練並微調一個決策樹

4.1 需求

  • a.使用make_moons(n_samples=10000,noise=0.4)生成一個衛星數據集
  • b.使用train_test_split()拆分訓練集和測試集
  • c.使用交叉驗證的網格搜索為DecisionTreeClassifier找到合適的超參數,提示:嘗試max_leaf_nodes的多種值
  • d.使用超參數對整個訓練集進行訓練,並測量模型測試集上的性能

代碼實現

from sklearn.datasets import make_moons
import numpy as np
import pandas as pd
dataset = make_moons(n_samples=10000,noise=0.4)
print(type(dataset))
print(dataset)
<class 'tuple'>
(array([[ 0.24834453, -0.11160162],
       [-0.34658051, -0.43774172],
       [-0.25009951, -0.80638312],
       ...,
       [ 2.3278198 ,  0.39007769],
       [-0.77964208,  0.68470383],
       [ 0.14500963,  1.35272533]]), array([1, 1, 1, ..., 1, 0, 0], dtype=int64))
dataset_array = np.array(dataset[0])
label_array = np.array(dataset[1])
print(dataset_array.shape,label_array.shape)
(10000, 2) (10000,)
# 拆分數據集
from sklearn.model_selection import train_test_split
x_train,x_test = train_test_split(dataset_array,test_size=0.2,random_state=42)
print(x_train.shape,x_test.shape)
y_train,y_test = train_test_split(label_array,test_size=0.2,random_state=42)
print(y_train.shape,y_test.shape)
(8000, 2) (2000, 2)
(8000,) (2000,)
# 使用交叉驗證的網格搜索為DecisionTreeClassifier找到合適的超參數
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

decisionTree = DecisionTreeClassifier(criterion='gini')
param_grid = {'max_leaf_nodes': [i for i in range(2,10)]}
gridSearchCV = GridSearchCV(decisionTree,param_grid=param_grid,cv=3,verbose=2)
gridSearchCV.fit(x_train,y_train)
Fitting 3 folds for each of 8 candidates, totalling 24 fits
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  24 out of  24 | elapsed:    0.0s finished

GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid={'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=2)
print(gridSearchCV.best_params_)
decision_tree = gridSearchCV.best_estimator_
{'max_leaf_nodes': 4}
# 使用測試集對模型進行評估
from sklearn.metrics import accuracy_score
y_prab = gridSearchCV.predict(x_test)
print('accuracy_score:',accuracy_score(y_test,y_prab))
accuracy_score: 0.8455
# 可視化模型
from sklearn.tree import export_graphviz

export_graphviz(decision_tree,
               out_file='./tree.dot',
               rounded = True,
               filled = True)

生成tree.dot文件,然後使用dot命令\[dot -Tpng tree.dot -o decisontree_moons.png\]

5. 附錄

5.1 sklearn.tree.DecisionTreeClassifier類說明

5.1.1 DecsisionTreeClassifier類參數說明

  • criterion: 特徵選擇方式,string,(‘gini’ or ‘entropy’),default=’gini’
  • splitter: 每個結點的拆分策略,(‘best’ or ‘random’),string,default=’best’
  • max_depth: int,default=None
  • min_samples_split: int,float,default=2,分割前所需的最小樣本數
  • min_samples_leaf:
  • min_weight_fraction_leaf:
  • max_features:
  • random_state:
  • max_leaf_nodes:
  • min_impurity_decrease:
  • min_impurity_split:
  • class_weight:
  • presort: bool,default=False,對於小型數據集(幾千個以內)設置presort=True通過對數據預處理來加快訓練,但對於較大訓練集而言,可能會減慢訓練速度

5.1.2 DecisionTreeClassifier屬性說明

  • classes_:
  • feature_importances_:
  • max_features_:
  • n_classes_:
  • n_features_:
  • n_outputs_:
  • tree_:

5.2 GridSearchCV類說明

5.2.1 GridSearchCV參數說明

  • estimator: 估算器,繼承於BaseEstimator
  • param_grid: dict,鍵為參數名,值為該參數需要測試值選項
  • scoring: default=None
  • fit_params:
  • n_jobs: 設置要并行運行的作業數,取值為None或1,None表示1 job,1表示all processors,default=None
  • cv: 交叉驗證的策略數,None或integer,None表示默認3-fold, integer指定“(分層)KFold”中的摺疊數
  • verbose: 輸出日誌類型

5.2.2 GridSearchCV屬性說明

  • cv_results_: dict of numpy(masked) ndarray
  • best_estimator_:
  • best_score_: Mean cross-validated score of the best_estimator
  • best_params_:
  • best_index_: int,The index (of the “cv_results_“ arrays) which corresponds to the best candidate parameter setting
  • scorer_:
  • n_splits_: The number of cross-validation splits (folds/iterations)
  • refit_time: float

參考資料:

  • (1)
  • (2)
  • (3)李航

【精選推薦文章】

自行創業 缺乏曝光? 下一步"網站設計"幫您第一時間規劃公司的門面形象

網頁設計一頭霧水??該從何著手呢? 找到專業技術的網頁設計公司,幫您輕鬆架站!

評比前十大台北網頁設計台北網站設計公司知名案例作品心得分享

台北網頁設計公司這麼多,該如何挑選?? 網頁設計報價省錢懶人包"嚨底家"

您可能也會喜歡…