模型對自己有多少信心??Part 1: Sklearn解密 predict_proba是甚麼???

以Decision Tree為例

倢愷 Oscar
8 min readMar 17, 2021

什麼是機率性質的預測??為甚麼很重要??

在機器學習與深度學習的領域,最常被大家抨擊的一點就是,model只給出一個output,而我們不知道model具體為什麼給出這個output,這種解釋性的問題很難解,但是更簡單一層,大多時候我們連「model對自己目前的預測多有信心都不知道」,更別談解釋了。

而機率性質的預測某種程度上就可以視為是model對自己預測的信心指數。

舉例來說:
今天A醫生只跟你說手術會成功,而B醫生說手術90%會成功,對你而言就是不同的資訊。
對大部分人而言,多講了成功的機率,更容易讓你做決策。
所以我們才需要機率性質的預測(prediction)

回歸到機器學習,如果我們今天是3分類問題,要把一張圖片分為 {狗/貓/兔子},那我們給任意一張圖片,我們就希望模型可以告訴我們,他分別認為這張圖片屬於三個類別的機率,可能是 {90% / 5% / 5%} 也可能是 {40% / 30% / 30%},這兩種情況在預測時候會告訴你這張圖片是狗,但是卻代表著完全不同的兩種意義。

而在Scikit Learn中大家會發現,大多數的Classifier都有predict_proba這個method,而其所呈現的就是機率性質預測這個行為,但是卻很少人認真了解過每個模型要如何產生機率性質預測,而導致使用錯誤的模型,或是錯誤解讀模型的信心。

這邊文章將會解讀Scikit Learn在目前版本中如何去實踐predict_proba,並且提出一些使用上的建議跟必須要注意的地方,Predict_proba目前預計會有3個part要講,分別為Decision Tree的predict_proba、SVM的predict_proba跟Callibration。

哪些模型本身就有機率性質的預測?

大家在看上面的例子的時候,如果對深度學習夠熟悉的,可能馬上就想到,我們的Softmax function不就是做了這件事嗎???是的,當我們使用深度學習的模型時,只要是分類問題我們大多會在最後一層接上Softmax,而Softmax本身就會給出每個類別的機率預測。

所以一般而言我們在深度學習的模型中,都是可以取得機率預測的,當然準不準是另外一回事,在深度學習這塊我們也有很多針對Uncertainty的研究,之後有時間可以再開一篇來講。

但是在機器學習就不一樣。

很多機器學習演算法本身並沒有機率性質的預測。

舉例而言:Decision Tree(決策樹),Decision Tree本身的預測是藉由「一連串的判斷,把input分配到樹中的某一個leaf node,並依據同個leaf node裡面比例最高類別當成預測」。

簡單來說,如果我們今天要從一個學生的外型來判斷他的性別,用Decision Tree我們做的事情可能就是:
1. 這個學生的頭髮長度沒有超過10cm (第一層決策)
2. 在頭髮長度沒有超過10cm的學生中,這個學生的身高超過170cm (第二層決策)
3. 在前兩項敘述依序達成的情況下,資料裡面剩下50位學生,而其中45位是男生,所以我們預測這位同學是男生。 (依照比例做預測)

這篇可能就可以看出,我們的Decision Tree最後是直接輸出一個「結果」,而不是這個學生有多少機率是男生。

所以Decision Tree實際上是沒有機率性質的預測的,而大家常用的SVM也是同理。

那哪些模型本身就有機率性質的預測呢?

最典型就是Naive Bayes(樸素貝氏模型),因為Naive Bayes本身的演算法就是基於貝氏機率來做分解,中間藉由統計的方法得到各種分布的數值,但是模型本身的預測就是機率分布

而大家常用的Logistic Regression也是,Logistic Regression本身理論上並沒有像Naive Bayes這個嚴謹的機率定義,但是因為他的output都在0~1之間,並且每個類別output的總和為1,所以符合機率最重要的兩個公理(機率要藉於0~1之間、所有可能結果的機率和為1),因此我們也可以當作他有機率性質的預測。

那Sklearn的Decision Tree是怎麼做到機率性質預測的???

這就是很神奇的地方了,在我早期接觸機器學習的時候,我每次用到predict_proba這個methods都會嘖嘖稱奇XDD

實際上sklearn的所有source code都是可以看到的,所以如果對source code有興趣完全可以自行研究。以decision tree的predcit_proba而言,可以參考這個link

我這邊做一個直觀的講解,假設我們今天的Decision Tree長的如下圖。我們一樣是想基於學生的外在數據來分析性別。

例子一:甲學生:頭髮超過10cm,而體重也超過50,則甲學生會被這個Decision Tree分到D類別,而D類別會輸出男生,因為D類別男生比例比較高。

而Decision Tree這邊製造出機率預測的方法就非常簡單,看D類別裡面的比率,D類別裡面4/7是男生,所以我們的機率預測就是甲學生0.571的機率是男生。

例子二:乙學生:頭髮超過10cm,而體重沒超過50,則乙學生會被這個Decision Tree分到C類別,而C類別會輸出女生,因為C類別女生比例比較高。而輸出機率為0.909的機率是女生。

所以Decision Tree的機率預測方法就是:用最終在的leaf node之中的比率來當成預測的機率。

上面看起來都很美好,但是我們來假設一個新的情況。

甲學生跟丙學生都被分到了D組,
甲學生 頭髮 15cm 體重55kg
丙學生 頭髮 15cm 體重95kg
而兩者都得到了0.571的機率是男生的輸出

上面這個例子可能很多人就會看出問題了,以Decision Tree的這個機率預測的方法,所有在同一個leaf node裡面的data,都會有一模一樣的機率,而不在乎這些data本身的性質,像是我們可能會預期女生要到95kg的難度比男生更難,所以直觀上我們可能會覺得丙學生更有可能是男生,但是在這個情況就做不到。

這是Decision Tree的Predict_proba第一個很核心的問題:同leaf node裡面的差異被忽略了。

而第二個更嚴重的問題,我們來看看當我們實際把Decision Tree運行在iris data上會有什麼效果。(code參考如下)

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)
prob = clf.predict_proba(X_test)print(prob)

當我把prob最終全部print出來會長如此。

一連串非常「極端」的機率!!!

這不是我們想要的東西,因為這樣的機率預測跟沒有機率一樣。

那為甚麼會出現這個問題,因為Decision Tree在你沒有設定任何pruning的時候(max_depth, min_sample_split …),會盡力切到完美,每一個leaf node都只蘊含一種class的data。

這邊可以看到每一個leaf nodes都只有一種class的data,所以最後各個leaf node都會輸出100%的機率,不論是預測哪個class。

Decision Tree的predict_proba在沒有pruning的時候高機率都會是100%。

很多人看到這邊可能就覺得,這很好解決,我只要加上pruning就好。

實際上沒錯,加上pruning我們可以讓機率預測不要都是100%,但是我實務上的經驗是,即便加上很強的pruning,我們都還是會得到90% 95%…相對極端的數值。

Decision Tree即便有pruning,都會傾向over estimate每一個sample的probability,也就是說會對自己過於有自信。

而這跟Decision Tree本身的訓練方法就有關。

結論,我推薦大家,不要使用Decision Tree的Predict_proba,或是不要依賴它,因為數值容易過於高估機率。

如果喜歡這篇文章可以幫我多拍手幾次XD,或是對於哪個類型文章有興趣都可以在留言區跟我講~ 後續會以中難度的ML/DS/AI知識為主,以及AI/HCI研究相關知識

--

--

倢愷 Oscar

我是倢愷,CTO at TeraThinker an AI Adaptive Learning System Company。AI/HCI研究者,超過100場的ML、DL演講、workshop經驗。主要學習如何將AI落地於業界。 有家教、演講合作,可以email跟我聯絡:axk51013@gmail.com