Back to Blog

[ML101] Chương 5: Cây Quyết Định (Decision Trees)

Giới thiệu thuật toán Cây Quyết Định và cách hoạt động

Bài viết có tham khảo, sử dụng và sửa đổi tài nguyên từ kho lưu trữ handson-mlp, tuân thủ giấy phép Apache‑2.0. Chúng tôi chân thành cảm ơn tác giả Aurélien Géron (@aureliengeron) vì sự chia sẻ kiến thức tuyệt vời và những đóng góp quý giá cho cộng đồng.

Chương này sẽ hướng dẫn bạn tìm hiểu về Cây Quyết Định (Decision Trees), một thuật toán học máy linh hoạt, mạnh mẽ, có khả năng thực hiện cả nhiệm vụ phân loại (classification) và hồi quy (regression), cũng như xử lý các bài toán có nhiều đầu ra (multioutput tasks). Cây quyết định cũng là thành phần cơ bản của Rừng Ngẫu nhiên (Random Forests), một trong những thuật toán học máy mạnh mẽ nhất hiện nay.

Trong chương này, chúng ta sẽ bắt đầu bằng cách thảo luận về cách huấn luyện, trực quan hóa và đưa ra dự đoán với cây quyết định. Sau đó, chúng ta sẽ xem xét thuật toán huấn luyện CART được sử dụng bởi Scikit-Learn, cũng như các phương pháp điều chuẩn (regularization) và áp dụng cho bài toán hồi quy. Cuối cùng, chúng ta sẽ bàn về một số hạn chế của cây quyết định.

Bạn có thể chạy trực tiếp các đoạn mã code tại: Google Colab.

Thiết lập môi trường (Setup)

Trước khi bắt đầu, vẫn như thường lệ, chúng ta cần đảm bảo môi trường Python đáp ứng các yêu cầu về phiên bản. Dự án này yêu cầu Python 3.10 trở lên và Scikit-Learn phiên bản 1.6.1 trở lên để đảm bảo tính tương thích của các hàm API mới.

import sys

# Kiểm tra phiên bản Python, yêu cầu >= 3.10
assert sys.version_info >= (3, 10)

Tiếp theo, chúng ta kiểm tra phiên bản của thư viện Scikit-Learn:

from packaging.version import Version
import sklearn

# Kiểm tra phiên bản Scikit-Learn, yêu cầu >= 1.6.1
assert Version(sklearn.__version__) >= Version("1.6.1")

Như đã làm trong các chương trước, chúng ta sẽ định nghĩa kích thước phông chữ mặc định cho các biểu đồ được vẽ bởi matplotlib để chúng trông rõ ràng và đẹp mắt hơn.

import matplotlib.pyplot as plt

# Thiết lập kích thước phông chữ cho các thành phần biểu đồ
plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)

Huấn luyện và Trực quan hóa Cây Quyết Định

Để hiểu rõ về Cây Quyết Định, hãy bắt đầu bằng việc xây dựng một cây và xem cách nó đưa ra dự đoán. Chúng ta sẽ sử dụng bộ dữ liệu Iris (hoa diên vĩ) kinh điển. Cụ thể, mô hình sẽ học cách phân loại các loài hoa dựa trên độ dài và độ rộng cánh hoa (petal).

Dưới đây là cách huấn luyện một DecisionTreeClassifier với độ sâu tối đa (max_depth) bằng 2.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Tải bộ dữ liệu Iris
iris = load_iris(as_frame=True)
# Lấy hai đặc trưng: chiều dài và chiều rộng cánh hoa
X_iris = iris.data[["petal length (cm)", "petal width (cm)"]].values
y_iris = iris.target

# Khởi tạo mô hình Cây Quyết Định với độ sâu tối đa là 2
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
# Huấn luyện mô hình với dữ liệu
tree_clf.fit(X_iris, y_iris)
DecisionTreeClassifier(max_depth=2, random_state=42)

Sau khi huấn luyện, chúng ta có thể trực quan hóa cây quyết định bằng hàm export_graphviz. Hàm này xuất ra một tệp định nghĩa đồ thị .dot, sau đó chúng ta có thể sử dụng thư viện graphviz để hiển thị nó.

Biểu đồ dưới đây (tương ứng với Hình 5–1 trong sách) cho thấy cấu trúc logic của cây:

from sklearn.tree import export_graphviz

# Xuất mô hình cây ra file định dạng .dot
export_graphviz(
        tree_clf,
        out_file="my_iris_tree.dot",
        feature_names=["petal length (cm)", "petal width (cm)"],
        class_names=iris.target_names,
        rounded=True,
        filled=True
    )

from graphviz import Source
# Hiển thị cây từ file .dot
Source.from_file("my_iris_tree.dot")

svg

Công cụ dòng lệnh dot từ gói Graphviz cũng có thể được sử dụng để chuyển đổi tệp .dot sang các định dạng ảnh như PNG hoặc PDF. Dòng lệnh dưới đây thực hiện việc chuyển đổi sang PNG:

# Mã bổ sung: Chuyển đổi file .dot sang .png bằng command line tool
!dot -Tpng "my_iris_tree.dot" -o "my_iris_tree.png"

Đưa ra Dự đoán (Making Predictions)

Khi quan sát cây quyết định ở trên, bạn có thể thấy cách nó hoạt động như một lưu đồ (flowchart). Bắt đầu từ nút gốc (root node) (độ sâu 0, ở trên cùng):

  1. Cây kiểm tra xem chiều dài cánh hoa có nhỏ hơn hoặc bằng 2.45 cm hay không.
  2. Nếu Đúng, bạn di chuyển xuống nút con bên trái (độ sâu 1). Nút này là một nút lá (leaf node) (nó không có nút con nào nữa), nghĩa là nó không hỏi thêm câu hỏi nào và đơn giản là gán nhãn dự đoán cho lớp đó (cụ thể là Iris setosa).
  3. Nếu Sai, bạn di chuyển xuống nút con bên phải (độ sâu 1). Tại đây, cây tiếp tục hỏi: chiều rộng cánh hoa có nhỏ hơn hoặc bằng 1.75 cm hay không?
    • Nếu đúng, dự đoán là Iris versicolor.
    • Nếu sai, dự đoán là Iris virginica.

Đoạn mã dưới đây minh họa ranh giới quyết định (decision boundaries) của cây. Các đường thẳng đứng và ngang tương ứng với các ngưỡng (thresholds) tại mỗi nút phân chia.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Định nghĩa bảng màu tùy chỉnh
custom_cmap = ListedColormap(['#fafab0', '#9898ff', '#a0faa0'])

plt.figure(figsize=(8, 4))

# Tạo lưới điểm để vẽ ranh giới quyết định
lengths, widths = np.meshgrid(np.linspace(0, 7.2, 100), np.linspace(0, 3, 100))
X_iris_all = np.c_[lengths.ravel(), widths.ravel()]

# Dự đoán cho toàn bộ lưới điểm
y_pred = tree_clf.predict(X_iris_all).reshape(lengths.shape)

# Vẽ vùng quyết định
plt.contourf(lengths, widths, y_pred, alpha=0.3, cmap=custom_cmap)

# Vẽ các điểm dữ liệu thực tế
for idx, (name, style) in enumerate(zip(iris.target_names, ("yo", "bs", "g^"))):
    plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],
             style, label=f"Iris {name}")

# Mã bổ sung: Làm đẹp biểu đồ (tương ứng Hình 5-2)
# Huấn luyện một cây sâu hơn để lấy các ngưỡng (chỉ dùng cho mục đích minh họa)
tree_clf_deeper = DecisionTreeClassifier(max_depth=3, random_state=42)
tree_clf_deeper.fit(X_iris, y_iris)
th0, th1, th2a, th2b = tree_clf_deeper.tree_.threshold[[0, 2, 3, 6]]

# Thiết lập nhãn trục
plt.xlabel("Petal length (cm)")
plt.ylabel("Petal width (cm)")

# Vẽ các đường ranh giới phân chia
plt.plot([th0, th0], [0, 3], "k-", linewidth=2)
plt.plot([th0, 7.2], [th1, th1], "k--", linewidth=2)
plt.plot([th2a, th2a], [0, th1], "k:", linewidth=2)
plt.plot([th2b, th2b], [th1, 3], "k:", linewidth=2)

# Thêm chú thích độ sâu (Depth)
plt.text(th0 - 0.05, 1.0, "Depth=0", horizontalalignment="right", fontsize=15)
plt.text(3.2, th1 + 0.02, "Depth=1", verticalalignment="bottom", fontsize=13)
plt.text(th2a + 0.05, 0.5, "(Depth=2)", fontsize=11)

plt.axis([0, 7.2, 0, 3])
plt.legend()
plt.show()

png

Bạn có thể truy cập trực tiếp vào cấu trúc bên trong của cây thông qua thuộc tính tree_. Đây là một đối tượng chứa các mảng NumPy lưu trữ thông tin về các nút, ngưỡng, và giá trị.

# Truy cập đối tượng cấu trúc cây
tree_clf.tree_
<sklearn.tree._tree.Tree at 0x79499dea5920>

Để tìm hiểu chi tiết hơn về cấu trúc này, bạn có thể tham khảo tài liệu hướng dẫn của lớp sklearn.tree._tree.Tree hoặc xem phần “Tài liệu bổ sung” ở cuối chương này.

# help(sklearn.tree._tree.Tree)

Ước lượng Xác suất Lớp (Estimating Class Probabilities)

Cây Quyết Định cũng có thể ước lượng xác suất mà một mẫu dữ liệu thuộc về một lớp cụ thể k. Nó thực hiện điều này bằng cách duyệt qua cây để tìm nút lá chứa mẫu dữ liệu đó, sau đó trả về tỷ lệ các mẫu huấn luyện của lớp k trong nút lá này.

Ví dụ, nếu chúng ta tìm xác suất cho một bông hoa có cánh dài 5cm và rộng 1.5cm:

# Dự đoán xác suất cho mẫu [5, 1.5]
tree_clf.predict_proba([[5, 1.5]]).round(3)
array([[0.   , 0.907, 0.093]])

Và để dự đoán lớp cụ thể (lớp có xác suất cao nhất):

# Dự đoán nhãn lớp cho mẫu [5, 1.5]
tree_clf.predict([[5, 1.5]])
array([1])

Các Siêu tham số Điều chuẩn (Regularization Hyperparameters)

Cây Quyết Định đưa ra rất ít giả định về dữ liệu huấn luyện (không giống như hồi quy tuyến tính giả định dữ liệu là tuyến tính). Nếu không bị hạn chế, cấu trúc cây sẽ thích ứng quá mức với dữ liệu huấn luyện, dẫn đến hiện tượng quá khớp (overfitting).

Để tránh điều này, chúng ta cần hạn chế sự tự do của cây trong quá trình huấn luyện thông qua điều chuẩn (regularization). Các siêu tham số phổ biến bao gồm:

  • max_depth: Độ sâu tối đa của cây.
  • min_samples_leaf: Số lượng mẫu tối thiểu cần thiết tại một nút lá.
  • min_samples_split: Số lượng mẫu tối thiểu cần thiết để phân chia một nút.

Dưới đây, chúng ta sẽ so sánh hai cây được huấn luyện trên bộ dữ liệu moons: một cây không có ràng buộc (mặc định) và một cây có min_samples_leaf=5.

from sklearn.datasets import make_moons

# Tạo dữ liệu moons (dữ liệu hình trăng khuyết xen kẽ)
X_moons, y_moons = make_moons(n_samples=150, noise=0.2, random_state=42)

# Cây 1: Không có ràng buộc (mặc định)
tree_clf1 = DecisionTreeClassifier(random_state=42)

# Cây 2: Có ràng buộc min_samples_leaf=5
tree_clf2 = DecisionTreeClassifier(min_samples_leaf=5, random_state=42)

# Huấn luyện cả hai cây
tree_clf1.fit(X_moons, y_moons)
tree_clf2.fit(X_moons, y_moons)
DecisionTreeClassifier(min_samples_leaf=5, random_state=42)

Đoạn mã dưới đây vẽ ranh giới quyết định cho cả hai mô hình (Hình 5-3). Bạn có thể thấy mô hình bên trái (không giới hạn) bị quá khớp rõ rệt, tạo ra các ranh giới phức tạp bao quanh các điểm nhiễu. Mô hình bên phải (được điều chuẩn) tổng quát hóa tốt hơn.

# Hàm hỗ trợ vẽ ranh giới quyết định
def plot_decision_boundary(clf, X, y, axes, cmap):
    x1, x2 = np.meshgrid(np.linspace(axes[0], axes[1], 100),
                         np.linspace(axes[2], axes[3], 100))
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)

    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=cmap)
    plt.contour(x1, x2, y_pred, cmap="Greys", alpha=0.8)
    colors = {"Wistia": ["#78785c", "#c47b27"], "Pastel1": ["red", "blue"]}
    markers = ("o", "^")
    for idx in (0, 1):
        plt.plot(X[:, 0][y == idx], X[:, 1][y == idx],
                 color=colors[cmap][idx], marker=markers[idx], linestyle="none")
    plt.axis(axes)
    plt.xlabel(r"$x_1$")
    plt.ylabel(r"$x_2$", rotation=0)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)

# Vẽ cây không giới hạn
plt.sca(axes[0])
plot_decision_boundary(tree_clf1, X_moons, y_moons,
                       axes=[-1.5, 2.4, -1, 1.5], cmap="Wistia")
plt.title("No restrictions")

# Vẽ cây có giới hạn min_samples_leaf
plt.sca(axes[1])
plot_decision_boundary(tree_clf2, X_moons, y_moons,
                       axes=[-1.5, 2.4, -1, 1.5], cmap="Wistia")
plt.title(f"min_samples_leaf = {tree_clf2.min_samples_leaf}")
plt.ylabel("")
plt.show()

png

Bây giờ, hãy kiểm chứng khả năng tổng quát hóa bằng tập dữ liệu kiểm tra (test set). Mô hình được điều chuẩn thường sẽ có độ chính xác cao hơn trên dữ liệu mới.

# Tạo tập test lớn hơn để đánh giá
X_moons_test, y_moons_test = make_moons(n_samples=1000, noise=0.2,
                                        random_state=43)

# Điểm số của cây không giới hạn
print(tree_clf1.score(X_moons_test, y_moons_test))

# Điểm số của cây có điều chuẩn
print(tree_clf2.score(X_moons_test, y_moons_test))
0.898
0.92

Hồi quy (Regression)

Cây Quyết Định cũng có thể thực hiện bài toán hồi quy (dự đoán một giá trị liên tục thay vì một lớp). Hãy thử huấn luyện một cây hồi quy DecisionTreeRegressor trên tập dữ liệu bậc hai có nhiễu.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Tạo dữ liệu bậc 2 ngẫu nhiên
rng = np.random.default_rng(seed=42)
X_quad = rng.random((200, 1)) - 0.5  # Một đặc trưng đầu vào ngẫu nhiên
y_quad = X_quad ** 2 + 0.025 * rng.standard_normal((200, 1)) # y = x^2 + nhiễu

# Huấn luyện cây hồi quy với độ sâu 2
tree_reg = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg.fit(X_quad, y_quad)
DecisionTreeRegressor(max_depth=2, random_state=42)

Cấu trúc của cây hồi quy trông rất giống cây phân loại, nhưng thay vì dự đoán một lớp, nó dự đoán một giá trị (value). Giá trị này chính là trung bình của các mẫu huấn luyện rơi vào nút lá đó.

# Xuất cây hồi quy ra file .dot
export_graphviz(
    tree_reg,
    out_file="my_regression_tree.dot",
    feature_names=["x1"],
    rounded=True,
    filled=True)

# Hiển thị cây
Source.from_file("my_regression_tree.dot")

svg

Chúng ta cũng sẽ huấn luyện một cây thứ hai với độ sâu lớn hơn (max_depth=3) để so sánh.

tree_reg2 = DecisionTreeRegressor(max_depth=3, random_state=42)
tree_reg2.fit(X_quad, y_quad)
DecisionTreeRegressor(max_depth=3, random_state=42)

Bạn có thể xem các ngưỡng (thresholds) mà cây sử dụng để phân chia dữ liệu:

# Ngưỡng của cây độ sâu 2
tree_reg.tree_.threshold
array([ 0.34304063, -0.30182856, -2.        , -2.        ,  0.43140428,
       -2.        , -2.        ])
# Ngưỡng của cây độ sâu 3
tree_reg2.tree_.threshold
array([ 0.34304063, -0.30182856, -0.41395289, -2.        , -2.        ,
        0.21817657, -2.        , -2.        ,  0.43140428,  0.39480372,
       -2.        , -2.        ,  0.46470371, -2.        , -2.        ])

Hình 5-5 dưới đây minh họa sự khác biệt giữa hai mô hình. Các dự đoán của cây hồi quy tạo thành một đường bậc thang (step function). Cây càng sâu, đường dự đoán càng chi tiết nhưng cũng dễ bị quá khớp.

# Hàm vẽ dự đoán của cây hồi quy
def plot_regression_predictions(tree_reg, X, y, axes=[-0.5, 0.5, -0.05, 0.25]):
    x1 = np.linspace(axes[0], axes[1], 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel("$x_1$")
    plt.plot(X, y, "b.")
    plt.plot(x1, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)

# Biểu đồ cho cây depth=2
plt.sca(axes[0])
plot_regression_predictions(tree_reg, X_quad, y_quad)
th0, th1a, th1b = tree_reg.tree_.threshold[[0, 1, 4]]
for split, style in ((th0, "k-"), (th1a, "k--"), (th1b, "k--")):
    plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)
plt.text(th0, 0.16, "Depth=0", fontsize=15)
plt.text(th1a + 0.01, -0.01, "Depth=1", horizontalalignment="center", fontsize=13)
plt.text(th1b + 0.01, -0.01, "Depth=1", fontsize=13)
plt.ylabel("$y$", rotation=0)
plt.legend(loc="upper center", fontsize=16)
plt.title("max_depth=2")

# Biểu đồ cho cây depth=3
plt.sca(axes[1])
th2s = tree_reg2.tree_.threshold[[2, 5, 9, 12]]
plot_regression_predictions(tree_reg2, X_quad, y_quad)
for split, style in ((th0, "k-"), (th1a, "k--"), (th1b, "k--")):
    plt.plot([split, split], [-0.05, 0.25], style, linewidth=2)
for split in th2s:
    plt.plot([split, split], [-0.05, 0.25], "k:", linewidth=1)
plt.text(th2s[2] + 0.01, 0.15, "Depth=2", fontsize=13)
plt.title("max_depth=3")
plt.show()

png

Cũng giống như bài toán phân loại, cây hồi quy rất dễ bị quá khớp nếu không có điều chuẩn. Hình 5-6 so sánh một cây không có ràng buộc (trái) và một cây có min_samples_leaf=10 (phải).

# Tạo 2 cây hồi quy để so sánh điều chuẩn
tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42, min_samples_leaf=10)

tree_reg1.fit(X_quad, y_quad)
tree_reg2.fit(X_quad, y_quad)

x1 = np.linspace(-0.5, 0.5, 500).reshape(-1, 1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)

# Cây không giới hạn: Quá khớp nghiêm trọng
plt.sca(axes[0])
plt.plot(X_quad, y_quad, "b.")
plt.plot(x1, y_pred1, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([-0.5, 0.5, -0.05, 0.25])
plt.xlabel("$x_1$")
plt.ylabel("$y$", rotation=0)
plt.legend(loc="upper center")
plt.title("No restrictions")

# Cây có điều chuẩn: Mượt mà hơn
plt.sca(axes[1])
plt.plot(X_quad, y_quad, "b.")
plt.plot(x1, y_pred2, "r.-", linewidth=2, label=r"$\hat{y}$")
plt.axis([-0.5, 0.5, -0.05, 0.25])
plt.xlabel("$x_1$")
plt.title(f"min_samples_leaf={tree_reg2.min_samples_leaf}")
plt.show()

png

Độ nhạy cảm với hướng trục (Sensitivity to Axis Orientation)

Một hạn chế của Cây Quyết Định là chúng thích các ranh giới quyết định vuông góc (orthogonal), nghĩa là tất cả các phân chia đều vuông góc với một trục tọa độ. Điều này khiến chúng nhạy cảm với việc xoay dữ liệu.

Ví dụ dưới đây cho thấy một tập dữ liệu phân tách tuyến tính đơn giản. Ở bên trái, ranh giới quyết định rất đơn giản. Nhưng khi xoay dữ liệu đi 45 độ (bên phải), ranh giới trở nên hình bậc thang không cần thiết.

# Tạo dữ liệu hình vuông ngẫu nhiên
rng = np.random.default_rng(seed=42)
X_square = rng.random((100, 2)) - 0.5
y_square = (X_square[:, 0] > 0).astype(np.int64)

# Xoay dữ liệu 45 độ
angle = np.pi / 4  # 45 degrees
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle), np.cos(angle)]])
X_rotated_square = X_square.dot(rotation_matrix)

# Huấn luyện trên dữ liệu gốc
tree_clf_square = DecisionTreeClassifier(random_state=42)
tree_clf_square.fit(X_square, y_square)

# Huấn luyện trên dữ liệu đã xoay
tree_clf_rotated_square = DecisionTreeClassifier(random_state=42)
tree_clf_rotated_square.fit(X_rotated_square, y_square)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4), sharey=True)

# Vẽ ranh giới dữ liệu gốc
plt.sca(axes[0])
plot_decision_boundary(tree_clf_square, X_square, y_square,
                       axes=[-0.7, 0.7, -0.7, 0.7], cmap="Pastel1")

# Vẽ ranh giới dữ liệu xoay
plt.sca(axes[1])
plot_decision_boundary(tree_clf_rotated_square, X_rotated_square, y_square,
                       axes=[-0.7, 0.7, -0.7, 0.7], cmap="Pastel1")
plt.ylabel("")
plt.show()

png

Một cách để giảm thiểu vấn đề này là sử dụng Phân tích Thành phần Chính (PCA) để xoay dữ liệu theo hướng tốt hơn trước khi đưa vào mô hình cây.

from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Tạo pipeline: Chuẩn hóa -> PCA -> Cây quyết định
pca_pipeline = make_pipeline(StandardScaler(), PCA())
X_iris_rotated = pca_pipeline.fit_transform(X_iris)
tree_clf_pca = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf_pca.fit(X_iris_rotated, y_iris)
DecisionTreeClassifier(max_depth=2, random_state=42)

Sau khi áp dụng PCA, ranh giới quyết định trên tập dữ liệu Iris đã xoay trở nên đơn giản hơn nhiều (Hình 5-8).

plt.figure(figsize=(8, 4))
axes = [-2.2, 2.4, -0.6, 0.7]

# Tạo lưới điểm trên không gian đã biến đổi PCA
z0s, z1s = np.meshgrid(np.linspace(axes[0], axes[1], 100),
                       np.linspace(axes[2], axes[3], 100))
X_iris_pca_all = np.c_[z0s.ravel(), z1s.ravel()]
y_pred = tree_clf_pca.predict(X_iris_pca_all).reshape(z0s.shape)

plt.contourf(z0s, z1s, y_pred, alpha=0.3, cmap=custom_cmap)

# Vẽ dữ liệu đã xoay
for idx, (name, style) in enumerate(zip(iris.target_names, ("yo", "bs", "g^"))):
    plt.plot(X_iris_rotated[:, 0][y_iris == idx],
             X_iris_rotated[:, 1][y_iris == idx],
             style, label=f"Iris {name}")

plt.xlabel("$z_1$")
plt.ylabel("$z_2$", rotation=0)

# Vẽ đường phân chia của cây
th1, th2 = tree_clf_pca.tree_.threshold[[0, 2]]
plt.plot([th1, th1], axes[2:], "k-", linewidth=2)
plt.plot([th2, th2], axes[2:], "k--", linewidth=2)
plt.text(th1 - 0.01, axes[2] + 0.05, "Depth=0",
         horizontalalignment="right", fontsize=15)
plt.text(th2 - 0.01, axes[2] + 0.05, "Depth=1",
         horizontalalignment="right", fontsize=13)

plt.axis(axes)
plt.legend(loc=(0.32, 0.67))
plt.show()

png

Cây Quyết Định có Phương sai Cao (High Variance)

Chúng ta đã thấy rằng những thay đổi nhỏ trong dữ liệu (như phép xoay) có thể tạo ra cây quyết định rất khác biệt. Bây giờ, chúng ta sẽ thấy rằng ngay cả khi huấn luyện trên cùng một dữ liệu, mô hình cũng có thể khác nhau nếu tham số ngẫu nhiên thay đổi. Điều này là do thuật toán CART sử dụng yếu tố ngẫu nhiên (stochastic).

Dưới đây, chúng ta thay đổi random_state thành 40 và nhận được một mô hình hoàn toàn khác so với ban đầu.

# Huấn luyện lại với random_state khác
tree_clf_tweaked = DecisionTreeClassifier(max_depth=2, random_state=40)
tree_clf_tweaked.fit(X_iris, y_iris)
DecisionTreeClassifier(max_depth=2, random_state=40)
# Vẽ lại cây với random_state mới để thấy sự khác biệt (Hình 5-9)
plt.figure(figsize=(8, 4))
y_pred = tree_clf_tweaked.predict(X_iris_all).reshape(lengths.shape)
plt.contourf(lengths, widths, y_pred, alpha=0.3, cmap=custom_cmap)

for idx, (name, style) in enumerate(zip(iris.target_names, ("yo", "bs", "g^"))):
    plt.plot(X_iris[:, 0][y_iris == idx], X_iris[:, 1][y_iris == idx],
             style, label=f"Iris {name}")

th0, th1 = tree_clf_tweaked.tree_.threshold[[0, 2]]
plt.plot([0, 7.2], [th0, th0], "k-", linewidth=2)
plt.plot([0, 7.2], [th1, th1], "k--", linewidth=2)

plt.text(1.8, th0 + 0.05, "Depth=0", verticalalignment="bottom", fontsize=15)
plt.text(2.3, th1 + 0.05, "Depth=1", verticalalignment="bottom", fontsize=13)

plt.xlabel("Petal length (cm)")
plt.ylabel("Petal width (cm)")
plt.axis([0, 7.2, 0, 3])
plt.legend()
plt.show()

png

Tài liệu bổ sung: Cấu trúc cây (tree structure)

Phần này dành cho những ai muốn hiểu sâu hơn về cách Scikit-Learn lưu trữ cây quyết định. Thuộc tính tree_ của một mô hình đã huấn luyện chứa toàn bộ thông tin cấu trúc.

tree = tree_clf.tree_
tree
<sklearn.tree._tree.Tree at 0x79499dea5920>

Bạn có thể lấy tổng số nút trong cây:

tree.node_count
5

Và các thuộc tính tự giải thích khác:

tree.max_depth
2
tree.max_n_classes
3
tree.n_features
2
tree.n_outputs
1
tree.n_leaves
np.int64(3)

Tất cả thông tin về các nút được lưu trữ trong các mảng NumPy. Ví dụ, độ vẩn đục (impurity) của mỗi nút:

tree.impurity
array([0.66666667, 0.        , 0.5       , 0.16803841, 0.04253308])

Nút gốc có chỉ số là 0. Các nút con trái và phải của nút i được lưu tại tree.children_left[i]tree.children_right[i]. Ví dụ, con của nút gốc là:

tree.children_left[0], tree.children_right[0]
(np.int64(1), np.int64(2))

Khi nút con trái và phải bằng nhau (thường là -1), điều đó có nghĩa đây là nút lá:

tree.children_left[3], tree.children_right[3]
(np.int64(-1), np.int64(-1))

Vậy nên bạn có thể lấy danh sách ID của các nút lá như sau:

is_leaf = (tree.children_left == tree.children_right)
np.arange(tree.node_count)[is_leaf]
array([1, 3, 4])

Các nút không phải lá được gọi là nút phân chia (split nodes). Đặc trưng dùng để phân chia được lưu trong mảng feature (giá trị của nút lá nên bị bỏ qua):

tree.feature
array([ 0, -2,  1, -2, -2], dtype=int64)

Và các ngưỡng tương ứng là:

tree.threshold
array([ 2.44999999, -2.        ,  1.75      , -2.        , -2.        ])

Số lượng mẫu của mỗi lớp đi qua mỗi nút cũng có sẵn:

tree.value
array([[[0.33333333, 0.33333333, 0.33333333]],

       [[1.        , 0.        , 0.        ]],

       [[0.        , 0.5       , 0.5       ]],

       [[0.        , 0.90740741, 0.09259259]],

       [[0.        , 0.02173913, 0.97826087]]])
tree.n_node_samples
array([150,  50, 100,  54,  46], dtype=int64)
# Kiểm tra tính nhất quán: tổng samples theo class phải bằng n_node_samples
np.all(tree.value.sum(axis=(1, 2)) == tree.n_node_samples)
np.False_

Đoạn mã dưới đây minh họa cách tính độ sâu của mỗi nút bằng cách duyệt cây:

def compute_depth(tree_clf):
    tree = tree_clf.tree_
    depth = np.zeros(tree.node_count)
    stack = [(0, 0)]
    while stack:
        node, node_depth = stack.pop()
        depth[node] = node_depth
        if tree.children_left[node] != tree.children_right[node]:
            stack.append((tree.children_left[node], node_depth + 1))
            stack.append((tree.children_right[node], node_depth + 1))
    return depth

depth = compute_depth(tree_clf)
depth
array([0., 1., 1., 2., 2.])

Ví dụ: Lấy ngưỡng của tất cả các nút phân chia ở độ sâu 1:

tree_clf.tree_.feature[(depth == 1) & (~is_leaf)]
array([1], dtype=int64)
tree_clf.tree_.threshold[(depth == 1) & (~is_leaf)]
array([1.75])

Ôn tập

1. Độ sâu xấp xỉ của cây quyết định: Độ sâu của một cây nhị phân cân bằng có mm lá xấp xỉ bằng log2(m)\log_2(m). Nếu tập huấn luyện có 1 triệu mẫu (m=106m=10^6) và cây không bị giới hạn (mỗi lá 1 mẫu), độ sâu sẽ vào khoảng log2(106)20\log_2(10^6) \approx 20.

2. Gini impurity của nút con so với nút cha: Thông thường, Gini impurity của nút con thấp hơn nút cha vì thuật toán CART cố gắng giảm thiểu tạp chất. Tuy nhiên, vẫn có trường hợp Gini impurity của một nút con cao hơn nút cha, miễn là mức giảm tạp chất ở nút con còn lại đủ lớn để bù đắp, làm cho tổng tạp chất có trọng số giảm xuống.

3. Ngăn chặn Overfitting: Nếu cây đang bị quá khớp (overfitting), bạn nên giảm max_depth. Điều này hạn chế khả năng phát triển phức tạp của cây.

4. Scaling dữ liệu: Cây quyết định không quan tâm đến việc dữ liệu có được chuẩn hóa (scaled) hay không. Vì vậy, việc chuẩn hóa dữ liệu khi mô hình đang underfitting sẽ không giúp ích gì.

5. Độ phức tạp tính toán: Độ phức tạp là O(n×mlog2(m))O(n \times m \log_2(m)). Nếu tăng kích thước tập huấn luyện lên 10 lần, thời gian huấn luyện sẽ tăng khoảng 11.711.7 lần (chứ không phải 10 lần).

6. Ảnh hưởng của số lượng đặc trưng: Nếu số lượng đặc trưng tăng gấp đôi, thời gian huấn luyện cũng sẽ tăng gấp đôi.

Bài tập thực hành 1: Tinh chỉnh Hyperparameters cho dataset Moons

a. Tạo dữ liệu: Sử dụng make_moons với n_samples=10000noise=0.4.

from sklearn.datasets import make_moons

# Tạo dữ liệu moons lớn
X_moons, y_moons = make_moons(n_samples=10000, noise=0.4, random_state=42)

b. Chia tập dữ liệu: Sử dụng train_test_split để chia thành tập huấn luyện và kiểm tra.

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X_moons, y_moons,
                                                    test_size=0.2,
                                                    random_state=42)

c. Tìm kiếm lưới (Grid Search): Sử dụng GridSearchCV để tìm các siêu tham số tốt nhất (ví dụ: max_leaf_nodes).

from sklearn.model_selection import GridSearchCV

params = {
    'max_leaf_nodes': list(range(2, 100)),
    'max_depth': [1, 2, 3, 4, 5, 6],
    'min_samples_split': [2, 3, 4]}

grid_search_cv = GridSearchCV(DecisionTreeClassifier(random_state=42),
                              params,
                              cv=3)

grid_search_cv.fit(X_train, y_train)
GridSearchCV(cv=3, estimator=DecisionTreeClassifier(random_state=42),
             param_grid={'max_depth': [1, 2, 3, 4, 5, 6],
                         'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                                            13, 14, 15, 16, 17, 18, 19, 20, 21,
                                            22, 23, 24, 25, 26, 27, 28, 29, 30,
                                            31, ...],
                         'min_samples_split': [2, 3, 4]})
# Hiển thị bộ tham số tốt nhất tìm được
grid_search_cv.best_estimator_
DecisionTreeClassifier(max_depth=6, max_leaf_nodes=17, random_state=42)

d. Đánh giá mô hình: Mặc định GridSearchCV sẽ huấn luyện lại mô hình tốt nhất trên toàn bộ tập train. Chúng ta chỉ cần đánh giá trên tập test.

from sklearn.metrics import accuracy_score

y_pred = grid_search_cv.predict(X_test)
accuracy_score(y_test, y_pred)
0.8595

Bài tập thực hành 2: Trồng một Rừng cây (Grow a Forest)

Bài tập này hướng dẫn bạn tạo ra một mô hình Random Forest thủ công.

a. Tạo các tập con: Tạo 1,000 tập con từ tập huấn luyện, mỗi tập chứa 100 mẫu ngẫu nhiên.

from sklearn.model_selection import ShuffleSplit

n_trees = 1000
n_instances = 100
mini_sets = []

# Sử dụng ShuffleSplit để lấy ngẫu nhiên các tập con
rs = ShuffleSplit(n_splits=n_trees, test_size=len(X_train) - n_instances,
                  random_state=42)

for mini_train_index, mini_test_index in rs.split(X_train):
    X_mini_train = X_train[mini_train_index]
    y_mini_train = y_train[mini_train_index]
    mini_sets.append((X_mini_train, y_mini_train))

b. Huấn luyện các cây: Huấn luyện một cây quyết định cho mỗi tập con, sử dụng siêu tham số tốt nhất đã tìm được ở bài trước.

from sklearn.base import clone

# Tạo danh sách 1000 cây (bản sao của mô hình tốt nhất)
forest = [clone(grid_search_cv.best_estimator_) for _ in range(n_trees)]
accuracy_scores = []

for tree, (X_mini_train, y_mini_train) in zip(forest, mini_sets):
    tree.fit(X_mini_train, y_mini_train)

    y_pred = tree.predict(X_test)
    accuracy_scores.append(accuracy_score(y_test, y_pred))

# Độ chính xác trung bình của các cây đơn lẻ (khoảng 80%)
np.mean(accuracy_scores)
np.float64(0.8056605)

c. Tổng hợp dự đoán (Majority Vote): Đối với mỗi mẫu trong tập kiểm tra, chúng ta lấy dự đoán từ 1,000 cây và chọn lớp được dự đoán nhiều nhất (biểu quyết số đông).

Y_pred = np.empty([n_trees, len(X_test)], dtype=np.uint8)

# Thu thập dự đoán từ tất cả các cây
for tree_index, tree in enumerate(forest):
    Y_pred[tree_index] = tree.predict(X_test)

from scipy.stats import mode
# Tìm mode (giá trị xuất hiện nhiều nhất) cho mỗi mẫu
y_pred_majority_votes, n_votes = mode(Y_pred, axis=0)

d. Đánh giá kết quả cuối cùng: Kết quả của việc biểu quyết số đông thường cao hơn so với cây đơn lẻ. Chúc mừng, bạn vừa tạo ra một bộ phân loại Random Forest!

accuracy_score(y_test, y_pred_majority_votes.reshape([-1]))
0.873