Як витягти правила рішення з дерева рішень scikit?


156

Чи можу я витягти основні правила прийняття рішень (або «шляхи прийняття рішень») з навченого дерева в дереві рішень як текстовий список?

Щось на зразок:

if A>0.4 then if B<0.2 then if C>0.8 then class='X'

Спасибі за вашу допомогу.



Ви коли-небудь знаходили відповідь на цю проблему? Мені потрібно експортувати правила дерева рішень у форматі кроку даних SAS, який майже точно такий, як у вас зазначено.
Зелазний7

1
Ви можете використовувати пакет sklearn-porter для експорту та транпіляції дерев рішень (також випадкових лісів та дерев, що підсилюються ) на C, Java, JavaScript та інші.
Дарій

Ви можете перевірити це посилання- kdnuggets.com/2017/05/…
yogesh agrawal

Відповіді:


138

Я вважаю, що ця відповідь є правильнішою, ніж інші відповіді тут:

from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print "def tree({}):".format(", ".join(feature_names))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print "{}if {} <= {}:".format(indent, name, threshold)
            recurse(tree_.children_left[node], depth + 1)
            print "{}else:  # if {} > {}".format(indent, name, threshold)
            recurse(tree_.children_right[node], depth + 1)
        else:
            print "{}return {}".format(indent, tree_.value[node])

    recurse(0, 1)

Це виводить дійсну функцію Python. Ось приклад виводу для дерева, яке намагається повернути свій вхід, число від 0 до 10.

def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]

Ось кілька каменів спотикання, які я бачу в інших відповідях:

  1. Використовувати, tree_.threshold == -2щоб визначити, чи вузол є листочком, не є хорошою ідеєю. Що робити, якщо це справжній вузол рішення з порогом -2? Натомість слід подивитися tree.featureабо tree.children_*.
  2. Рядок features = [feature_names[i] for i in tree_.feature]збігається з моєю версією sklearn, оскільки деякі значення tree.tree_.feature- -2 (спеціально для листкових вузлів).
  3. Немає необхідності мати декілька, якщо твердження в рекурсивній функції лише одне - це добре.

1
Цей код відмінно працює для мене. Однак у мене є 500+ імен функцій, тому вихідний код людині майже неможливо зрозуміти. Чи є спосіб дозволити мені лише вводити у функцію імена, які мені цікаві?
користувач3768495

1
Я згоден з попереднім коментарем. IIUC, print "{}return {}".format(indent, tree_.value[node])слід змінити на print "{}return {}".format(indent, np.argmax(tree_.value[node][0]))функцію повернення індексу класу.
soupault

1
@paulkernfeld Ага так, я бачу, що ви можете перетворити цикл RandomForestClassifier.estimators_, але я не зміг розробити, як поєднати результати оцінювачів.
Натан Ллойд

6
Я не міг змусити це працювати в python 3, біти _tree, схоже, ніколи не працювали, і TREE_UNDEFINED не було визначено. Це посилання мені допомогло. Хоча експортований код не працює безпосередньо в python, його легко перекласти на інші мови: web.archive.org/web/20171005203850/http://www.kdnuggets.com/…
Josiah

1
@Josiah, додай () до операторів друку, щоб він працював у python3. напр. print "bla"=>print("bla")
Nir

48

Я створив власну функцію для вилучення правил із дерев рішень, створених sklearn:

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})

# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)

Ця функція спочатку починається з вузлів (ідентифікованих -1 у дочірніх масивах), а потім рекурсивно знаходить батьків. Я називаю це "родовище" вузла. Попутно я захоплюю значення, які мені потрібно створити, якщо / тоді / інакше логіка SAS:

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]

     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'

          lineage.append((parent, split, threshold[parent], features[parent]))

          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)

     for child in idx:
          for node in recurse(left, right, child):
               print node

Набори нижче кортезів містять усе, що потрібно для створення SAS, якщо / тоді / else. Мені не подобається використовувати doблоки в SAS, тому я створюю логіку, що описує весь шлях вузла. Єдине ціле число після кортежів - це ідентифікатор кінцевого вузла в шляху. Усі попередні кортежі поєднуються, щоб створити цей вузол.

In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6

Вихід GraphViz дерева прикладу


правильний цей тип дерева, тому що col1 знову з'являється, один - col1 <= 0,50000, а один col1 <= 2,5000, якщо так, чи це будь-який тип рекурсії, який використовується в бібліотеці
jayant singh

правий відділ матиме записи між ними (0.5, 2.5]. Дерева виконані з рекурсивними перегородками. Ніщо не заважає зміні вибиратися кілька разів.
Зелазний7

гаразд, ви можете пояснити рекурсійній частині того, що трапляється чітко, тому що я використав це у своєму коді, і подібний результат видно
jayant singh

38

Я змінив код, поданий Zelazny7, щоб надрукувати псевдокод:

def get_code(tree, feature_names):
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value

        def recurse(left, right, threshold, features, node):
                if (threshold[node] != -2):
                        print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                        if left[node] != -1:
                                recurse (left, right, threshold, features,left[node])
                        print "} else {"
                        if right[node] != -1:
                                recurse (left, right, threshold, features,right[node])
                        print "}"
                else:
                        print "return " + str(value[node])

        recurse(left, right, threshold, features, 0)

якщо ви подзвоните get_code(dt, df.columns)на той же приклад, ви отримаєте:

if ( col1 <= 0.5 ) {
return [[ 1.  0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0.  1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1.  0.]]
} else {
return [[ 0.  1.]]
}
}
}

1
Чи можете ви сказати, що саме [[1. 0.]] у операторі return повертається до наведеного вище висновку. Я не хлопець Python, але працюю над подібними речами. Тож мені буде добре, якщо ви, будь ласка, докажете якісь деталі, щоб мені було легше.
Субградіп Бозе

1
@ user3156186 Це означає, що в класі '0' є один об’єкт і нульові об’єкти в класі '1'
Даніеле

1
@Daniele, ти знаєш, як упорядковані заняття? Я б здогадався, що буквено-цифровий, але я ніде не знайшов підтвердження.
IanS

Дякую! Для кращого сценарію, де значення порогового значення фактично становить -2, нам може знадобитися перейти (threshold[node] != -2)на ( left[node] != -1)(подібно до наведеного нижче способу отримання ідентифікаторів дочірніх вузлів)
tlingf

@Daniele, будь-яка ідея, як зробити вашу функцію "get_code" "повернути" значення, а не "друкувати" її, тому що мені потрібно відправити її на іншу функцію?
RoyaumeIX

17

Scikit learn представив новий смачний новий метод, який називається export_textу версії 0.21 (травень 2019 року), щоб витягти правила з дерева. Документація тут . Створювати власну функцію більше не потрібно.

Після того, як ви підійдете до своєї моделі, вам знадобляться лише два рядки коду. По-перше, імпортуйте export_text:

from sklearn.tree.export import export_text

По-друге, створіть об’єкт, який буде містити ваші правила. Щоб правила виглядали більш зрозумілими, використовуйте feature_namesаргумент та передайте список імен своїх функцій. Наприклад, якщо ваша модель викликається, modelа ваші функції названі в кадрі даних, який називається X_train, ви можете створити об'єкт під назвою tree_rules:

tree_rules = export_text(model, feature_names=list(X_train))

Потім просто надрукуйте або збережіть tree_rules. Ваш результат буде виглядати приблизно так:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1

14

Існує новий DecisionTreeClassifierметод decision_path, у випуску 0.18.0 . Розробники пропонують широкий (добре задокументований) посібник .

Перший розділ коду в інструкції, що друкує структуру дерева, здається нормальним. Однак я змінив код у другому розділі, щоб допитати один зразок. Мої зміни позначаються с# <--

Редагувати Ці зміни відзначені # <--в коді нижче тих пір був оновлені в покрокової зв'язку після того, як помилки були відзначені в висувних запитах # 8653 і # 10951 . Зараз набагато простіше слідувати.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:

    if leave_id[sample_id] == node_id:  # <-- changed != to ==
        #continue # <-- comment out
        print("leaf node {} reached, no decision here".format(leave_id[sample_id])) # <--

    else: # < -- added else to iterate through decision nodes
        if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"

        print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
              % (node_id,
                 sample_id,
                 feature[node_id],
                 X_test[sample_id, feature[node_id]], # <-- changed i to sample_id
                 threshold_sign,
                 threshold[node_id]))

Rules used to predict sample 0: 
decision id node 0 : (X[0, 3] (= 2.4) > 0.800000011921)
decision id node 2 : (X[0, 2] (= 5.1) > 4.94999980927)
leaf node 4 reached, no decision here

Змініть, sample_idщоб побачити шляхи прийняття рішень для інших зразків. Я не запитував розробників про ці зміни, просто здавався інтуїтивнішим, працюючи на прикладі.


ти мій друг - легенда! будь-які ідеї, як побудувати дерево рішення для конкретного зразка? велика допомога вдячна

1
Дякую Віктору, мабуть, найкраще задати це окремим питанням, оскільки побудова вимог може бути специфічною для потреб користувача. Ви, ймовірно, отримаєте хорошу відповідь, якщо надасте уявлення про те, як ви хочете, щоб виглядав результат.
Кевін

агов Кевін, я створив питання stackoverflow.com/questions/48888893 / ...

Ви б були так люб'язно поглянути: stackoverflow.com/questions/52654280/…
Олександр Червов

Чи можете ви пояснити, яку частину називають node_index, не отримуючи її. що воно робить?
Anindya Sankar Dey

12
from StringIO import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
print out.getvalue()

Ви можете побачити диграфське дерево. Тоді, clf.tree_.featureі clf.tree_.valueє масив вузлів розділення функція і масив значень вузлів відповідно. Ви можете посилатися на більш детальну інформацію з цього джерела github .


1
Так, я знаю, як намалювати дерево - але мені потрібна більш текстова версія - правила. щось на кшталт: orange.biolab.si/docs/latest/reference/rst/…
Dror Hilman

4

Просто тому, що всім було настільки корисно, я просто додам модифікацію до прекрасних рішень Zelazny7 та Daniele. Цей варіант призначений для python 2.7, з вкладками, щоб зробити його більш зрозумілим:

def get_code(tree, feature_names, tabdepth=0):
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, tabdepth=0):
            if (threshold[node] != -2):
                    print '\t' * tabdepth,
                    print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "} else {"
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node], tabdepth+1)
                    print '\t' * tabdepth,
                    print "}"
            else:
                    print '\t' * tabdepth,
                    print "return " + str(value[node])

    recurse(left, right, threshold, features, 0)

3

Коди нижче - це мій підхід під анакондою python 2.7 плюс назва пакету "pydot-ng" до створення файлу PDF з правилами рішення. Я сподіваюся, що це корисно.

from sklearn import tree

clf = tree.DecisionTreeClassifier(max_leaf_nodes=n)
clf_ = clf.fit(X, data_y)

feature_names = X.columns
class_name = clf_.classes_.astype(int).astype(str)

def output_pdf(clf_, name):
    from sklearn import tree
    from sklearn.externals.six import StringIO
    import pydot_ng as pydot
    dot_data = StringIO()
    tree.export_graphviz(clf_, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=class_name,
                         filled=True, rounded=True,
                         special_characters=True,
                          node_ids=1,)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("%s.pdf"%name)

output_pdf(clf_, name='filename%s'%n)

тут показано дерево-графію


3

Я переживав це, але мені потрібні були правила, які слід писати в такому форматі

if A>0.4 then if B<0.2 then if C>0.8 then class='X' 

Тож я адаптував відповідь @paulkernfeld (спасибі), яку ви можете налаштувати під свої потреби

def tree_to_code(tree, feature_names, Y):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    pathto=dict()

    global k
    k = 0
    def recurse(node, depth, parent):
        global k
        indent = "  " * depth

        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            s= "{} <= {} ".format( name, threshold, node )
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s

            recurse(tree_.children_left[node], depth + 1, node)
            s="{} > {}".format( name, threshold)
            if node == 0:
                pathto[node]=s
            else:
                pathto[node]=pathto[parent]+' & ' +s
            recurse(tree_.children_right[node], depth + 1, node)
        else:
            k=k+1
            print(k,')',pathto[parent], tree_.value[node])
    recurse(0, 1, 0)

3

Ось спосіб перевести все дерево в єдиний (не обов’язково занадто читабельний для людини) вираз пітона за допомогою бібліотеки SKompiler :

from skompiler import skompile
skompile(dtree.predict).to('python/code')

3

Це ґрунтується на відповіді @paulkernfeld. Якщо у вас є фрейм X з вашими функціями та цільовий кадр даних y з вашими резоонами, і ви хочете отримати уявлення, яке значення y закінчилося в якому вузлі (а також мурашник, щоб його скласти відповідно), ви можете зробити наступне:

    def tree_to_code(tree, feature_names):
        from sklearn.tree import _tree
        codelines = []
        codelines.append('def get_cat(X_tmp):\n')
        codelines.append('   catout = []\n')
        codelines.append('   for codelines in range(0,X_tmp.shape[0]):\n')
        codelines.append('      Xin = X_tmp.iloc[codelines]\n')
        tree_ = tree.tree_
        feature_name = [
            feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
            for i in tree_.feature
        ]
        #print "def tree({}):".format(", ".join(feature_names))

        def recurse(node, depth):
            indent = "      " * depth
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                name = feature_name[node]
                threshold = tree_.threshold[node]
                codelines.append ('{}if Xin["{}"] <= {}:\n'.format(indent, name, threshold))
                recurse(tree_.children_left[node], depth + 1)
                codelines.append( '{}else:  # if Xin["{}"] > {}\n'.format(indent, name, threshold))
                recurse(tree_.children_right[node], depth + 1)
            else:
                codelines.append( '{}mycat = {}\n'.format(indent, node))

        recurse(0, 1)
        codelines.append('      catout.append(mycat)\n')
        codelines.append('   return pd.DataFrame(catout,index=X_tmp.index,columns=["category"])\n')
        codelines.append('node_ids = get_cat(X)\n')
        return codelines
    mycode = tree_to_code(clf,X.columns.values)

    # now execute the function and obtain the dataframe with all nodes
    exec(''.join(mycode))
    node_ids = [int(x[0]) for x in node_ids.values]
    node_ids2 = pd.DataFrame(node_ids)

    print('make plot')
    import matplotlib.cm as cm
    colors = cm.rainbow(np.linspace(0, 1, 1+max( list(set(node_ids)))))
    #plt.figure(figsize=cm2inch(24, 21))
    for i in list(set(node_ids)):
        plt.plot(y[node_ids2.values==i],'o',color=colors[i], label=str(i))  
    mytitle = ['y colored by node']
    plt.title(mytitle ,fontsize=14)
    plt.xlabel('my xlabel')
    plt.ylabel(tagname)
    plt.xticks(rotation=70)       
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.00), shadow=True, ncol=9)
    plt.tight_layout()
    plt.show()
    plt.close 

не найелегантніша версія, але вона справляється з цим ...


1
Це хороший підхід, коли ви хочете повернути кодові рядки, а не просто роздрукувати їх.
Хаджар Хомаюні

3

Це потрібний вам код

Я правильно змінив код, що сподобався, для відступу в зошиті з юпітером python 3 правильно

import numpy as np
from sklearn.tree import _tree

def tree_to_code(tree, feature_names):
    tree_ = tree.tree_
    feature_name = [feature_names[i] 
                    if i != _tree.TREE_UNDEFINED else "undefined!" 
                    for i in tree_.feature]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            print("{}return {}".format(indent, np.argmax(tree_.value[node])))

    recurse(0, 1)

2

Ось функція, правила друку дерева рішень у науковому режимі під python 3 та з зрушеннями для умовних блоків, щоб зробити структуру більш зрозумілою:

def print_decision_tree(tree, feature_names=None, offset_unit='    '):
    '''Plots textual representation of rules of a decision tree
    tree: scikit-learn representation of tree
    feature_names: list of feature names. They are set to f1,f2,f3,... if not specified
    offset_unit: a string of offset of the conditional block'''

    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value
    if feature_names is None:
        features  = ['f%d'%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0):
            offset = offset_unit*depth
            if (threshold[node] != -2):
                    print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                    if left[node] != -1:
                            recurse (left, right, threshold, features,left[node],depth+1)
                    print(offset+"} else {")
                    if right[node] != -1:
                            recurse (left, right, threshold, features,right[node],depth+1)
                    print(offset+"}")
            else:
                    print(offset+"return " + str(value[node]))

    recurse(left, right, threshold, features, 0,0)

2

Ви також можете зробити його більш інформативним, відзначивши його, до якого класу він належить, або навіть вказавши його вихідне значення.

def print_decision_tree(tree, feature_names, offset_unit='    '):    
left      = tree.tree_.children_left
right     = tree.tree_.children_right
threshold = tree.tree_.threshold
value = tree.tree_.value
if feature_names is None:
    features  = ['f%d'%i for i in tree.tree_.feature]
else:
    features  = [feature_names[i] for i in tree.tree_.feature]        

def recurse(left, right, threshold, features, node, depth=0):
        offset = offset_unit*depth
        if (threshold[node] != -2):
                print(offset+"if ( " + features[node] + " <= " + str(threshold[node]) + " ) {")
                if left[node] != -1:
                        recurse (left, right, threshold, features,left[node],depth+1)
                print(offset+"} else {")
                if right[node] != -1:
                        recurse (left, right, threshold, features,right[node],depth+1)
                print(offset+"}")
        else:
                #print(offset,value[node]) 

                #To remove values from node
                temp=str(value[node])
                mid=len(temp)//2
                tempx=[]
                tempy=[]
                cnt=0
                for i in temp:
                    if cnt<=mid:
                        tempx.append(i)
                        cnt+=1
                    else:
                        tempy.append(i)
                        cnt+=1
                val_yes=[]
                val_no=[]
                res=[]
                for j in tempx:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_no.append(j)
                for j in tempy:
                    if j=="[" or j=="]" or j=="." or j==" ":
                        res.append(j)
                    else:
                        val_yes.append(j)
                val_yes = int("".join(map(str, val_yes)))
                val_no = int("".join(map(str, val_no)))

                if val_yes>val_no:
                    print(offset,'\033[1m',"YES")
                    print('\033[0m')
                elif val_no>val_yes:
                    print(offset,'\033[1m',"NO")
                    print('\033[0m')
                else:
                    print(offset,'\033[1m',"Tie")
                    print('\033[0m')

recurse(left, right, threshold, features, 0,0)

введіть тут опис зображення


2

Ось мій підхід до витягування правил рішення у формі, яку можна використовувати безпосередньо в sql, тому дані можна згрупувати по вузлу. (На основі підходів попередніх плакатів.)

Результатом стануть наступні CASEпропозиції, які можна скопіювати в оператор sql, напр.

SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>


import numpy as np

import pickle
feature_names=.............
features  = [feature_names[i] for i in range(len(feature_names))]
clf= pickle.loads(trained_model)
impurity=clf.tree_.impurity
importances = clf.feature_importances_
SqlOut=""

#global Conts
global ContsNode
global Path
#Conts=[]#
ContsNode=[]
Path=[]
global Results
Results=[]

def print_decision_tree(tree, feature_names, offset_unit=''    ''):    
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    value = tree.tree_.value

    if feature_names is None:
        features  = [''f%d''%i for i in tree.tree_.feature]
    else:
        features  = [feature_names[i] for i in tree.tree_.feature]        

    def recurse(left, right, threshold, features, node, depth=0,ParentNode=0,IsElse=0):
        global Conts
        global ContsNode
        global Path
        global Results
        global LeftParents
        LeftParents=[]
        global RightParents
        RightParents=[]
        for i in range(len(left)): # This is just to tell you how to create a list.
            LeftParents.append(-1)
            RightParents.append(-1)
            ContsNode.append("")
            Path.append("")


        for i in range(len(left)): # i is node
            if (left[i]==-1 and right[i]==-1):      
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not " +ContsNode[RightParents[i]]                     
                Results.append(" case when  " +Path[i]+"  then ''" +"{:4d}".format(i)+ " "+"{:2.2f}".format(impurity[i])+" "+Path[i][0:180]+"''")

            else:       
                if LeftParents[i]>=0:
                    if Path[LeftParents[i]]>" ":
                        Path[i]=Path[LeftParents[i]]+" AND " +ContsNode[LeftParents[i]]                                 
                    else:
                        Path[i]=ContsNode[LeftParents[i]]                                   
                if RightParents[i]>=0:
                    if Path[RightParents[i]]>" ":
                        Path[i]=Path[RightParents[i]]+" AND not " +ContsNode[RightParents[i]]                                   
                    else:
                        Path[i]=" not "+ContsNode[RightParents[i]]                      
                if (left[i]!=-1):
                    LeftParents[left[i]]=i
                if (right[i]!=-1):
                    RightParents[right[i]]=i
                ContsNode[i]=   "( "+ features[i] + " <= " + str(threshold[i])   + " ) "

    recurse(left, right, threshold, features, 0,0,0,0)
print_decision_tree(clf,features)
SqlOut=""
for i in range(len(Results)): 
    SqlOut=SqlOut+Results[i]+ " end,"+chr(13)+chr(10)

1

Тепер ви можете використовувати export_text.

from sklearn.tree import export_text

r = export_text(loan_tree, feature_names=(list(X_train.columns)))
print(r)

Повний приклад з [sklearn] [1]

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

0

Змінено код Zelazny7 для отримання SQL з дерева рішень.

# SQL from decision tree

def get_lineage(tree, feature_names):
     left      = tree.tree_.children_left
     right     = tree.tree_.children_right
     threshold = tree.tree_.threshold
     features  = [feature_names[i] for i in tree.tree_.feature]
     le='<='               
     g ='>'
     # get ids of child nodes
     idx = np.argwhere(left == -1)[:,0]     

     def recurse(left, right, child, lineage=None):          
          if lineage is None:
               lineage = [child]
          if child in left:
               parent = np.where(left == child)[0].item()
               split = 'l'
          else:
               parent = np.where(right == child)[0].item()
               split = 'r'
          lineage.append((parent, split, threshold[parent], features[parent]))
          if parent == 0:
               lineage.reverse()
               return lineage
          else:
               return recurse(left, right, parent, lineage)
     print 'case '
     for j,child in enumerate(idx):
        clause=' when '
        for node in recurse(left, right, child):
            if len(str(node))<3:
                continue
            i=node
            if i[1]=='l':  sign=le 
            else: sign=g
            clause=clause+i[3]+sign+str(i[2])+' and '
        clause=clause[:-4]+' then '+str(j)
        print clause
     print 'else 99 end as clusters'

0

Мабуть, давно хтось уже вирішив спробувати додати наступну функцію до функцій експорту офіційного дерева дерев Scikit (яка в основному підтримує лише export_graphviz)

def export_dict(tree, feature_names=None, max_depth=None) :
    """Export a decision tree in dict format.

Ось його повне зобов'язання:

https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py

Не зовсім впевнений, що сталося з цим коментарем. Але ви також можете спробувати використовувати цю функцію.

Я думаю, що це вимагає серйозного запиту на документацію для хороших людей scikit-навчитися правильно документувати sklearn.tree.TreeAPI, який є базовою структурою дерева, яка DecisionTreeClassifierвиставляється як його атрибут tree_.


0

Просто використовуйте функцію від sklearn.tree, як це

from sklearn.tree import export_graphviz
    export_graphviz(tree,
                out_file = "tree.dot",
                feature_names = tree.columns) //or just ["petal length", "petal width"]

А потім загляньте у папку проекту на файл tree.dot , скопіюйте ВСІ вміст і вставте сюди http://www.webgraphviz.com/ та сгенеруйте свій графік :)


0

Дякую за чудове рішення @paulkerfeld. На вершині свого рішення, для всіх тих , хто хоче мати впорядковану версію дерев, просто використовувати tree.threshold, tree.children_left, tree.children_right, tree.featureі tree.value. Оскільки листя не мають розколи і , отже , не мають імен та дітей, їх заповнювач в tree.featureі tree.children_***в _tree.TREE_UNDEFINEDі _tree.TREE_LEAF. Кожному розколу присвоюється унікальний індекс по depth first search.
Зауважте, що tree.valueце форма[n, 1, 1]


0

Ось функція, яка генерує код Python з дерева рішень шляхом перетворення результату export_text:

import string
from sklearn.tree import export_text

def export_py_code(tree, feature_names, max_depth=100, spacing=4):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    # Second: generate Python code from the text
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree({}):\n'.format(', '.join(features))
    for line in repr(tree).split('\n'):
        code += skip + "# " + line + '\n'
    for line in res.split('\n'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '\n'

    return code

Використання зразка:

res = export_py_code(tree, feature_names=names, spacing=4)
print (res)

Вибірка зразка:

def decision_tree(f1, f2, f3):
    # DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
    #                        max_features=None, max_leaf_nodes=None,
    #                        min_impurity_decrease=0.0, min_impurity_split=None,
    #                        min_samples_leaf=1, min_samples_split=2,
    #                        min_weight_fraction_leaf=0.0, presort=False,
    #                        random_state=42, splitter='best')
    if f1 <= 12.5:
        if f2 <= 17.5:
            if f1 <= 10.5:
                return 2
            if f1 > 10.5:
                return 3
        if f2 > 17.5:
            if f2 <= 22.5:
                return 1
            if f2 > 22.5:
                return 1
    if f1 > 12.5:
        if f1 <= 17.5:
            if f3 <= 23.5:
                return 2
            if f3 > 23.5:
                return 3
        if f1 > 17.5:
            if f1 <= 25:
                return 1
            if f1 > 25:
                return 2

Наведений вище приклад генерується за допомогою names = ['f'+str(j+1) for j in range(NUM_FEATURES)] .

Одна зручна особливість полягає в тому, що він може генерувати менший розмір файлу зі зменшеним інтервалом. Просто встановити spacing=2.

Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.