Hướng dẫn visualize decision tree python - hình dung cây quyết định python

Hướng dẫn visualize decision tree python - hình dung cây quyết định python
Một cây quyết định là một thuật toán được giám sát được sử dụng trong học máy. Nó đang sử dụng biểu đồ cây nhị phân (mỗi nút có hai con) để gán cho mỗi mẫu dữ liệu một giá trị đích. Các giá trị đích được trình bày trong lá cây. Để tiếp cận với lá, mẫu được truyền qua các nút, bắt đầu từ nút gốc. Trong mỗi nút, một quyết định được đưa ra, mà nút hậu duệ mà nó sẽ đi. Một quyết định được đưa ra dựa trên tính năng mẫu được chọn. Học cây quyết định là một quá trình tìm kiếm các quy tắc tối ưu trong mỗi nút cây bên trong theo số liệu đã chọn.

Các cây quyết định có thể được chia, liên quan đến các giá trị mục tiêu, thành:

  • Các cây phân loại được sử dụng để phân loại các mẫu, gán cho một tập hợp các giá trị giới hạn - các lớp. Trong scikit-learn là
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    1.
  • Các cây hồi quy được sử dụng để gán các mẫu thành các giá trị số trong phạm vi. Trong scikit-learn là
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    2.

Cây quyết định là một công cụ phổ biến trong phân tích quyết định. Họ có thể hỗ trợ các quyết định nhờ đại diện trực quan của từng quyết định.

Dưới đây tôi chỉ ra 4 cách để trực quan hóa cây quyết định trong Python:

  • In biểu diễn văn bản của cây bằng phương pháp
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    3
  • sơ đồ với phương pháp
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    4 (cần thiết matplotlib)
  • sơ đồ với phương thức
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    5 (cần phải có graphviz)
  • Vẽ với gói
    # Fit the classifier with default hyper-parameters
    clf = DecisionTreeClassifier(random_state=1234)
    model = clf.fit(X, y)
    
    6 (DtreEviz và GraphViz cần)

Tôi sẽ chỉ ra cách hình dung cây trên các nhiệm vụ phân loại và hồi quy.

Đào tạo cây quyết định về nhiệm vụ phân loại

Tôi sẽ đào tạo một

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
1 trên bộ dữ liệu
# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
8. Tôi sẽ sử dụng các tham số siêu mặc định cho trình phân loại.

from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)

In đại diện văn bản

Xuất cây quyết định vào biểu diễn văn bản có thể hữu ích khi làm việc trên các ứng dụng giao diện người dùng Whitout hoặc khi chúng tôi muốn đăng nhập thông tin về mô hình vào tệp văn bản. Bạn có thể kiểm tra chi tiết về

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
9 trong tài liệu Sklearn.

text_representation = tree.export_text(clf)
print(text_representation)

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2

Nếu bạn muốn lưu nó vào tệp, nó có thể được thực hiện với mã sau:

with open("decistion_tree.log", "w") as fout:
    fout.write(text_representation)

Cây âm mưu với text_representation = tree.export_text(clf) print(text_representation) 0

Phương pháp

text_representation = tree.export_text(clf)
print(text_representation)
0 đã được thêm vào Sklearn trong phiên bản
text_representation = tree.export_text(clf)
print(text_representation)
2. Nó yêu cầu
text_representation = tree.export_text(clf)
print(text_representation)
3 phải được cài đặt. Nó cho phép chúng ta dễ dàng tạo ra con số của cây (không có xuất khẩu trung gian sang graphviz) càng nhiều thông tin về các đối số
text_representation = tree.export_text(clf)
print(text_representation)
0 có trong tài liệu.

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)

Hướng dẫn visualize decision tree python - hình dung cây quyết định python

.

Để lưu hình vào tệp

text_representation = tree.export_text(clf)
print(text_representation)
7:

fig.savefig("decistion_tree.png")

Xin lưu ý rằng tôi đã sử dụng

text_representation = tree.export_text(clf)
print(text_representation)
8 trong
text_representation = tree.export_text(clf)
print(text_representation)
0. Khi tham số này được đặt thành
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
0, phương thức sử dụng màu để chỉ ra phần lớn lớp. (Sẽ rất tuyệt nếu sẽ có một số huyền thoại với lớp học và màu sắc.)

Trực quan hóa cây quyết định với graphviz

Vui lòng đảm bảo rằng bạn đã cài đặt

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
1 (
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
2). Để vẽ cây trước tiên, chúng ta cần xuất nó sang định dạng DOT bằng phương thức
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
3 (liên kết đến tài liệu). Sau đó, chúng ta có thể vẽ nó trong sổ ghi chép hoặc lưu vào tệp.

import graphviz
# DOT data
dot_data = tree.export_graphviz(clf, out_file=None, 
                                feature_names=iris.feature_names,  
                                class_names=iris.target_names,
                                filled=True)

# Draw graph
graph = graphviz.Source(dot_data, format="png") 
graph

graph.render("decision_tree_graphivz")

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
0

Cây quyết định cốt truyện với gói # Fit the classifier with default hyper-parameters clf = DecisionTreeClassifier(random_state=1234) model = clf.fit(X, y) 6

Gói

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
6 có sẵn trong GitHub. Nó có thể được cài đặt với
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
6. Nó yêu cầu
|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
1 phải được cài đặt (nhưng bạn không cần phải chuyển đổi thủ công giữa các tệp chấm và hình ảnh). Để vẽ cây chỉ chạy:

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
1

Lưu trực quan hóa vào tệp:

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
2

Hình dung cây quyết định trong nhiệm vụ hồi quy

Dưới đây, tôi trình bày tất cả 4 phương pháp cho

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
2 từ gói scikit-learn (tất nhiên là trong Python).

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
3

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
4

Để giữ kích thước của cây nhỏ, tôi đã đặt

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
9.

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
5

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
6

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
7

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
8

Hướng dẫn visualize decision tree python - hình dung cây quyết định python

Xin lưu ý rằng màu sắc của lá là giá trị dự đoán.

# Prepare the data data
iris = datasets.load_iris()
X = iris.data
y = iris.target
9

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
0

Từ các phương pháp trên yêu thích của tôi là trực quan hóa với gói

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
6. Tôi thích nó vì:

  • Nó hiển thị phân phối tính năng quyết định trong mỗi nút (NICE!)
  • nó cho thấy huyền thoại phù hợp màu sắc lớp
  • Nó cho thấy sự phân phối của lớp trong lá trong trường hợp các nhiệm vụ phân loại và giá trị trung bình của phản hồi lá trong trường hợp các nhiệm vụ hồi quy

Thật tuyệt vời khi có trực quan hóa

# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
6 ở chế độ tương tác, vì vậy người dùng có thể tự động thay đổi độ sâu của cây. Tôi đã sử dụng gói
# Fit the classifier with default hyper-parameters
clf = DecisionTreeClassifier(random_state=1234)
model = clf.fit(X, y)
6 trong gói Python học máy tự động học (Automl)
with open("decistion_tree.log", "w") as fout:
    fout.write(text_representation)
3. Bạn có thể kiểm tra các chi tiết của việc triển khai trong kho GitHub. Một điều quan trọng là, trong gói tự động của tôi, tôi không sử dụng cây quyết định với
with open("decistion_tree.log", "w") as fout:
    fout.write(text_representation)
4 lớn hơn
with open("decistion_tree.log", "w") as fout:
    fout.write(text_representation)
5. Tôi thêm giới hạn này để không có những cây quá lớn, theo tôi, điều này làm mất đi khả năng hiểu rõ những gì xảy ra trong mô hình. Dưới đây là ví dụ về báo cáo Markdown cho cây quyết định được tạo bởi
with open("decistion_tree.log", "w") as fout:
    fout.write(text_representation)
3.

Hướng dẫn visualize decision tree python - hình dung cây quyết định python



💌 Tham gia bản tin của chúng tôi 💌

Đăng ký nhận bản tin của chúng tôi để nhận cập nhật sản phẩm


Hướng dẫn visualize decision tree python - hình dung cây quyết định python

Chia sẻ sổ ghi chép Python của bạn với những người khác