🔥알림🔥
① 테디노트 유튜브 - 구경하러 가기!
② LangChain 한국어 튜토리얼 바로가기 👀
③ 랭체인 노트 무료 전자책(wikidocs) 바로가기 🙌

14 분 소요

Decision Tree는 Random Forest Ensemble 알고리즘의 기본이 되는 알고리즘이며, Tree 기반 알고리즘입니다. 의사결정나무 혹은 결정트리로 불리우는 이 알고리즘은 머신러닝의 학습 결과에 대하여 시각화를 통한 직관적인 이해가 가능하다는 것이 큰 장점입니다. 더불어, Random Forest Ensemble 알고리즘은 바로 이 Decision Tree 알고리즘의 앙상블 (Ensemble) 알고리즘인데, Random Forest 앙상블 알고리즘이 사용성은 쉬우면서 성능까지 뛰어나 캐글 (Kaggle.com)과 같은 데이터 분석 대회에서 Baseline 알고리즘으로 많이 활용되고 있습니다.

이번 실습에서는 Decision Tree 알고리즘에 대하여 시각화, Entropy와 Gini 계수에 대한 상세한 이해를 돕고자 만들어진 튜토리얼 이며, 실습 후반부에는 자주 사용되는 Hyperparameter 에 대한 소개도 진행합니다.

Random Forest 의 Hyperparameter와 겹치는 부분이 많기 때문에, 본 실습을 통해 Hyperparameter에 대하여 숙지해 두시면 자동으로 Random Forest 알고리즘의 Hyperparameter에 대한 이해까지 할 수 있게되는 1석 2조의 효과를 보실 수 있습니다.

코드

Colab으로 열기 Colab으로 열기

GitHub GitHub에서 소스보기


from IPython.display import Image

결정트리 or 의사결정나무 (Decision Tree)

결정트리를 가장 단수하게 표현하자면, Tree 구조를 가진 알고리즘입니다.

의사결정나무는 데이터를 분석하여 데이터 사이에서 패턴을 예측 가능한 규칙들의 조합으로 나타내며, 이 과정을 시각화 해 본다면 마치 스무고개 놀이와 비슷합니다.

Image(url='https://miro.medium.com/max/2960/1*dc_342kIsHCzuko1TtyEGQ.png', width=500)

결정트리의 기본 아이디어는 sample이 가장 섞이지 않은 상태로 완전히 분류되는 것, 다시 말해서 엔트로피(Entropy)를 낮추도록 만드는 것입니다.

엔트로피 (Entropy)

엔트로피는 쉽게 말해서 무질서한 정도를 정량화(수치화)한 값입니다.

다음은 엔트로피 지수를 방이 어질러있는 정도를 예시로 들어 표현되었습니다.

Image(url='https://image.slidesharecdn.com/entropyandthe2ndlaw-120327062903-phpapp02/95/103-entropy-and-the-2nd-law-3-728.jpg?cb=1335190079', width=500)

엔트로피 수식의 이해

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# 샘플데이터를 생성합니다.
group_1 = np.array([0.3, 0.4, 0.3])
group_2 = np.array([0.7, 0.2, 0.1])
group_3 = np.array([0.01, 0.01, 0.98])
fig, axes = plt.subplots(1, 3)
fig.set_size_inches(12, 4)
axes[0].bar(np.arange(3), group_1, color='blue')
axes[0].set_title('Group 1')
axes[1].bar(np.arange(3), group_2, color='red')
axes[1].set_title('Group 2')
axes[2].bar(np.arange(3), group_3, color='green')
axes[2].set_title('Group 3')
plt.show()
Image(url='https://miro.medium.com/max/1122/0*DkWdyGidNSfdT1Nu.png', width=350)
# entropy를 구현합니다.
def entropy(x):
    return (-x*np.log2(x)).sum()

Entropy 계산 및 시각화

entropy_1 = entropy(group_1)
entropy_2 = entropy(group_2)
entropy_3 = entropy(group_3)

print(f'Group 1: {entropy_1:.3f}\nGroup 2: {entropy_2:.3f}\nGroup 3: {entropy_3:.3f}')
Group 1: 1.571
Group 2: 1.157
Group 3: 0.161
plt.figure(figsize=(5, 5))
plt.bar(['Group 1', 'Group 2', 'Group 3'], [entropy_1, entropy_2, entropy_3])
plt.title('Entropy', fontsize=15)
plt.show()

지니 계수 (Gini Index)

  • 클래쓰들이 공평하게 섞여 있을 수록 지니 계수는 올라갑니다.
  • Decision Tree는 지니 불순도를 낮추는 방향으로 가지치기를 진행합니다.
Image(url='http://www.learnbymarketing.com/wp-content/uploads/2016/02/gini-index-formula.png', width=350)
# Gini Index 구현합니다.
def gini(x):
    return 1 - ((x / x.sum())**2).sum()
# 샘플데이터를 생성합니다.
group_1 = np.array([50, 50])
group_2 = np.array([30, 70])
group_3 = np.array([0, 100])
fig, axes = plt.subplots(1, 3)
fig.set_size_inches(12, 4)
axes[0].bar(['Positive', 'Negative'], group_1, color='blue')
axes[0].set_title('Group 1')
axes[1].bar(['Positive', 'Negative'], group_2, color='red')
axes[1].set_title('Group 2')
axes[2].bar(['Positive', 'Negative'], group_3, color='green')
axes[2].set_title('Group 3')
plt.show()
gini_1 = gini(group_1)
gini_2 = gini(group_2)
gini_3 = gini(group_3)

print(f'Group 1: {gini_1:.3f}\nGroup 2: {gini_2:.3f}\nGroup 3: {gini_3:.3f}')
Group 1: 0.500
Group 2: 0.420
Group 3: 0.000
plt.figure(figsize=(5, 5))
plt.bar(['Group 1', 'Group 2', 'Group 3'], [gini_1, gini_2, gini_3])
plt.title('Gini Index', fontsize=15)
plt.show()

Decision Tree 구현

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

SEED = 42
# breast cancer 데이터셋 로드
cancer = load_breast_cancer()
# train, test 데이터 분할
x_train, x_test, y_train, y_test = train_test_split(cancer['data'], cancer['target'], stratify=cancer['target'], random_state=SEED)
# 알고리즘 정의
tree = DecisionTreeClassifier(random_state=0)
# 학습
tree.fit(x_train, y_train)
DecisionTreeClassifier(random_state=0)
# 예측
pred = tree.predict(x_test)
# 정확도 측정
accuracy = accuracy_score(pred, y_test)
print(f'Accuracy Score: {accuracy:.3f}')
Accuracy Score: 0.937

의사결정나무 시각화

from sklearn.tree import export_graphviz
from sklearn.metrics import accuracy_score
import graphviz

def show_trees(tree):
    export_graphviz(tree, out_file="tree.dot", class_names=["악성", "양성"],
                    feature_names=cancer['feature_names'], 
                    precision=3, filled=True)
    with open("tree.dot") as f:
        dot_graph = f.read()
    pred = tree.predict(x_test)
    print('정확도: {:.2f} %'.format(accuracy_score(y_test, pred) * 100))

    display(graphviz.Source(dot_graph))
show_trees(tree)
정확도: 93.71 %
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 worst radius <= 16.795 gini = 0.468 samples = 426 value = [159, 267] class = 양성 1 worst concave points <= 0.136 gini = 0.161 samples = 284 value = [25, 259] class = 양성 0->1 True 28 texture error <= 0.473 gini = 0.106 samples = 142 value = [134, 8] class = 악성 0->28 False 2 radius error <= 1.048 gini = 0.031 samples = 252 value = [4, 248] class = 양성 1->2 17 worst texture <= 25.62 gini = 0.451 samples = 32 value = [21, 11] class = 악성 1->17 3 smoothness error <= 0.003 gini = 0.024 samples = 251 value = [3, 248] class = 양성 2->3 16 gini = 0.0 samples = 1 value = [1, 0] class = 악성 2->16 4 mean texture <= 19.9 gini = 0.375 samples = 4 value = [1, 3] class = 양성 3->4 7 area error <= 48.7 gini = 0.016 samples = 247 value = [2, 245] class = 양성 3->7 5 gini = 0.0 samples = 3 value = [0, 3] class = 양성 4->5 6 gini = 0.0 samples = 1 value = [1, 0] class = 악성 4->6 8 worst texture <= 33.35 gini = 0.008 samples = 243 value = [1, 242] class = 양성 7->8 13 mean concavity <= 0.029 gini = 0.375 samples = 4 value = [1, 3] class = 양성 7->13 9 gini = 0.0 samples = 225 value = [0, 225] class = 양성 8->9 10 worst texture <= 33.8 gini = 0.105 samples = 18 value = [1, 17] class = 양성 8->10 11 gini = 0.0 samples = 1 value = [1, 0] class = 악성 10->11 12 gini = 0.0 samples = 17 value = [0, 17] class = 양성 10->12 14 gini = 0.0 samples = 1 value = [1, 0] class = 악성 13->14 15 gini = 0.0 samples = 3 value = [0, 3] class = 양성 13->15 18 worst area <= 817.1 gini = 0.375 samples = 12 value = [3, 9] class = 양성 17->18 23 worst symmetry <= 0.268 gini = 0.18 samples = 20 value = [18, 2] class = 악성 17->23 19 mean smoothness <= 0.123 gini = 0.18 samples = 10 value = [1, 9] class = 양성 18->19 22 gini = 0.0 samples = 2 value = [2, 0] class = 악성 18->22 20 gini = 0.0 samples = 9 value = [0, 9] class = 양성 19->20 21 gini = 0.0 samples = 1 value = [1, 0] class = 악성 19->21 24 fractal dimension error <= 0.002 gini = 0.444 samples = 3 value = [1, 2] class = 양성 23->24 27 gini = 0.0 samples = 17 value = [17, 0] class = 악성 23->27 25 gini = 0.0 samples = 1 value = [1, 0] class = 악성 24->25 26 gini = 0.0 samples = 2 value = [0, 2] class = 양성 24->26 29 gini = 0.0 samples = 5 value = [0, 5] class = 양성 28->29 30 worst concavity <= 0.191 gini = 0.043 samples = 137 value = [134, 3] class = 악성 28->30 31 worst texture <= 30.975 gini = 0.48 samples = 5 value = [2, 3] class = 양성 30->31 34 gini = 0.0 samples = 132 value = [132, 0] class = 악성 30->34 32 gini = 0.0 samples = 3 value = [0, 3] class = 양성 31->32 33 gini = 0.0 samples = 2 value = [2, 0] class = 악성 31->33

주요 Hyper Parameter

max_depth

max_depth는 최대 트리의 깊이를 제한 합니다.

기본 값은 None, 제한 없음 입니다.

tree = DecisionTreeClassifier(max_depth=3, random_state=SEED)
tree.fit(x_train, y_train)
show_trees(tree)
정확도: 94.41 %
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 worst radius <= 16.795 gini = 0.468 samples = 426 value = [159, 267] class = 양성 1 worst concave points <= 0.136 gini = 0.161 samples = 284 value = [25, 259] class = 양성 0->1 True 8 texture error <= 0.473 gini = 0.106 samples = 142 value = [134, 8] class = 악성 0->8 False 2 area error <= 91.555 gini = 0.031 samples = 252 value = [4, 248] class = 양성 1->2 5 worst texture <= 25.62 gini = 0.451 samples = 32 value = [21, 11] class = 악성 1->5 3 gini = 0.024 samples = 251 value = [3, 248] class = 양성 2->3 4 gini = 0.0 samples = 1 value = [1, 0] class = 악성 2->4 6 gini = 0.375 samples = 12 value = [3, 9] class = 양성 5->6 7 gini = 0.18 samples = 20 value = [18, 2] class = 악성 5->7 9 gini = 0.0 samples = 5 value = [0, 5] class = 양성 8->9 10 worst concavity <= 0.191 gini = 0.043 samples = 137 value = [134, 3] class = 악성 8->10 11 gini = 0.48 samples = 5 value = [2, 3] class = 양성 10->11 12 gini = 0.0 samples = 132 value = [132, 0] class = 악성 10->12

min_sample_split

min_sample_split은 노드 내에서 분할이 필요한 최소의 샘플 숫자입니다.

기본 값은 2입니다.

tree = DecisionTreeClassifier(max_depth=6, min_samples_split=20,  random_state=SEED)
tree.fit(x_train, y_train)
show_trees(tree)
정확도: 94.41 %
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 worst radius <= 16.795 gini = 0.468 samples = 426 value = [159, 267] class = 양성 1 worst concave points <= 0.136 gini = 0.161 samples = 284 value = [25, 259] class = 양성 0->1 True 16 texture error <= 0.473 gini = 0.106 samples = 142 value = [134, 8] class = 악성 0->16 False 2 area error <= 91.555 gini = 0.031 samples = 252 value = [4, 248] class = 양성 1->2 11 worst texture <= 25.62 gini = 0.451 samples = 32 value = [21, 11] class = 악성 1->11 3 area error <= 48.7 gini = 0.024 samples = 251 value = [3, 248] class = 양성 2->3 10 gini = 0.0 samples = 1 value = [1, 0] class = 악성 2->10 4 smoothness error <= 0.003 gini = 0.016 samples = 247 value = [2, 245] class = 양성 3->4 9 gini = 0.375 samples = 4 value = [1, 3] class = 양성 3->9 5 gini = 0.375 samples = 4 value = [1, 3] class = 양성 4->5 6 worst texture <= 33.35 gini = 0.008 samples = 243 value = [1, 242] class = 양성 4->6 7 gini = 0.0 samples = 225 value = [0, 225] class = 양성 6->7 8 gini = 0.105 samples = 18 value = [1, 17] class = 양성 6->8 12 gini = 0.375 samples = 12 value = [3, 9] class = 양성 11->12 13 worst symmetry <= 0.268 gini = 0.18 samples = 20 value = [18, 2] class = 악성 11->13 14 gini = 0.444 samples = 3 value = [1, 2] class = 양성 13->14 15 gini = 0.0 samples = 17 value = [17, 0] class = 악성 13->15 17 gini = 0.0 samples = 5 value = [0, 5] class = 양성 16->17 18 worst concavity <= 0.191 gini = 0.043 samples = 137 value = [134, 3] class = 악성 16->18 19 gini = 0.48 samples = 5 value = [2, 3] class = 양성 18->19 20 gini = 0.0 samples = 132 value = [132, 0] class = 악성 18->20

min_samples_leaf

min_samples_leaf는 말단 노드의 최소 샘플의 숫자를 지정합니다.

기본 값은 1 입니다.

DecisionTreeClassifier()
DecisionTreeClassifier()

max_leaf_nodes

max_leaf_nodes는 말단 노드의 최대 갯수 (과대 적합 방지용)

기본 값은 None, 제한 없음 입니다.

tree = DecisionTreeClassifier(max_depth=7, max_leaf_nodes=10, random_state=SEED)
tree.fit(x_train, y_train)
show_trees(tree)
정확도: 94.41 %
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 worst radius <= 16.795 gini = 0.468 samples = 426 value = [159, 267] class = 양성 1 worst concave points <= 0.136 gini = 0.161 samples = 284 value = [25, 259] class = 양성 0->1 True 2 texture error <= 0.473 gini = 0.106 samples = 142 value = [134, 8] class = 악성 0->2 False 3 perimeter error <= 6.597 gini = 0.031 samples = 252 value = [4, 248] class = 양성 1->3 4 worst texture <= 25.62 gini = 0.451 samples = 32 value = [21, 11] class = 악성 1->4 17 gini = 0.024 samples = 251 value = [3, 248] class = 양성 3->17 18 gini = 0.0 samples = 1 value = [1, 0] class = 악성 3->18 7 worst area <= 817.1 gini = 0.375 samples = 12 value = [3, 9] class = 양성 4->7 8 worst symmetry <= 0.268 gini = 0.18 samples = 20 value = [18, 2] class = 악성 4->8 11 gini = 0.18 samples = 10 value = [1, 9] class = 양성 7->11 12 gini = 0.0 samples = 2 value = [2, 0] class = 악성 7->12 15 gini = 0.444 samples = 3 value = [1, 2] class = 양성 8->15 16 gini = 0.0 samples = 17 value = [17, 0] class = 악성 8->16 5 gini = 0.0 samples = 5 value = [0, 5] class = 양성 2->5 6 worst concavity <= 0.191 gini = 0.043 samples = 137 value = [134, 3] class = 악성 2->6 9 worst texture <= 30.975 gini = 0.48 samples = 5 value = [2, 3] class = 양성 6->9 10 gini = 0.0 samples = 132 value = [132, 0] class = 악성 6->10 13 gini = 0.0 samples = 3 value = [0, 3] class = 양성 9->13 14 gini = 0.0 samples = 2 value = [2, 0] class = 악성 9->14

max_features

최적의 분할을 찾기 위해 고려할 피처의 수

0.8 은 80% 의 feature 만 고려하여 분할 알고리즘 적용

기본 값은 None, 모두 사용입니다.

tree = DecisionTreeClassifier(max_depth=7, max_features=0.8, random_state=SEED)
tree.fit(x_train, y_train)
show_trees(tree)
정확도: 90.91 %
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> Tree 0 worst radius <= 16.795 gini = 0.468 samples = 426 value = [159, 267] class = 양성 1 worst concave points <= 0.136 gini = 0.161 samples = 284 value = [25, 259] class = 양성 0->1 True 28 texture error <= 0.473 gini = 0.106 samples = 142 value = [134, 8] class = 악성 0->28 False 2 worst fractal dimension <= 0.055 gini = 0.031 samples = 252 value = [4, 248] class = 양성 1->2 17 worst texture <= 25.62 gini = 0.451 samples = 32 value = [21, 11] class = 악성 1->17 3 gini = 0.0 samples = 1 value = [1, 0] class = 악성 2->3 4 smoothness error <= 0.003 gini = 0.024 samples = 251 value = [3, 248] class = 양성 2->4 5 mean smoothness <= 0.081 gini = 0.375 samples = 4 value = [1, 3] class = 양성 4->5 8 area error <= 48.7 gini = 0.016 samples = 247 value = [2, 245] class = 양성 4->8 6 gini = 0.0 samples = 3 value = [0, 3] class = 양성 5->6 7 gini = 0.0 samples = 1 value = [1, 0] class = 악성 5->7 9 worst texture <= 33.35 gini = 0.008 samples = 243 value = [1, 242] class = 양성 8->9 14 texture error <= 1.938 gini = 0.375 samples = 4 value = [1, 3] class = 양성 8->14 10 gini = 0.0 samples = 225 value = [0, 225] class = 양성 9->10 11 worst texture <= 33.8 gini = 0.105 samples = 18 value = [1, 17] class = 양성 9->11 12 gini = 0.0 samples = 1 value = [1, 0] class = 악성 11->12 13 gini = 0.0 samples = 17 value = [0, 17] class = 양성 11->13 15 gini = 0.0 samples = 3 value = [0, 3] class = 양성 14->15 16 gini = 0.0 samples = 1 value = [1, 0] class = 악성 14->16 18 mean concave points <= 0.08 gini = 0.375 samples = 12 value = [3, 9] class = 양성 17->18 23 worst symmetry <= 0.268 gini = 0.18 samples = 20 value = [18, 2] class = 악성 17->23 19 worst concave points <= 0.138 gini = 0.18 samples = 10 value = [1, 9] class = 양성 18->19 22 gini = 0.0 samples = 2 value = [2, 0] class = 악성 18->22 20 gini = 0.0 samples = 1 value = [1, 0] class = 악성 19->20 21 gini = 0.0 samples = 9 value = [0, 9] class = 양성 19->21 24 worst smoothness <= 0.132 gini = 0.444 samples = 3 value = [1, 2] class = 양성 23->24 27 gini = 0.0 samples = 17 value = [17, 0] class = 악성 23->27 25 gini = 0.0 samples = 1 value = [1, 0] class = 악성 24->25 26 gini = 0.0 samples = 2 value = [0, 2] class = 양성 24->26 29 gini = 0.0 samples = 5 value = [0, 5] class = 양성 28->29 30 worst concavity <= 0.191 gini = 0.043 samples = 137 value = [134, 3] class = 악성 28->30 31 worst symmetry <= 0.284 gini = 0.48 samples = 5 value = [2, 3] class = 양성 30->31 34 gini = 0.0 samples = 132 value = [132, 0] class = 악성 30->34 32 gini = 0.0 samples = 3 value = [0, 3] class = 양성 31->32 33 gini = 0.0 samples = 2 value = [2, 0] class = 악성 31->33

feature의 중요도 파악

feature_importances_ 변수를 통해서 tree 알고리즘이 학습시 고려한 feature 별 중요도를 확인할 수 있습니다.

tree.feature_importances_
array([0.        , 0.        , 0.        , 0.        , 0.00752597,
       0.        , 0.        , 0.01354675, 0.        , 0.        ,
       0.        , 0.05383566, 0.        , 0.00238745, 0.00231135,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.69546322, 0.04179055, 0.        , 0.        , 0.00668975,
       0.        , 0.01740312, 0.12587473, 0.02341413, 0.00975731])

DataFrame으로 만들면 중요도(feature importances) 순서로 정렬할 수 있습니다.

df = pd.DataFrame(list(zip(cancer['feature_names'], tree.feature_importances_)), columns=['feature', 'importance']).sort_values('importance', ascending=False)
df = df.reset_index(drop=True)
df.head(15)
feature importance
0 worst radius 0.695463
1 worst concave points 0.125875
2 texture error 0.053836
3 worst texture 0.041791
4 worst symmetry 0.023414
5 worst concavity 0.017403
6 mean concave points 0.013547
7 worst fractal dimension 0.009757
8 mean smoothness 0.007526
9 worst smoothness 0.006690
10 area error 0.002387
11 smoothness error 0.002311
12 worst perimeter 0.000000
13 worst area 0.000000
14 symmetry error 0.000000
plt.figure(figsize=(10, 10))
sns.barplot(y='feature', x='importance', data=df)
plt.show()

댓글남기기