[ML101] Chương 3: Bài toán Phân loại (Classification)
Tìm hiểu về bài toán phân loại trong Machine Learning với các thuật toán phổ biến
Bài viết có tham khảo, sử dụng và sửa đổi tài nguyên từ kho lưu trữ handson-mlp, tuân thủ giấy phép Apache‑2.0. Chúng tôi chân thành cảm ơn tác giả Aurélien Géron (@aureliengeron) vì sự chia sẻ kiến thức tuyệt vời và những đóng góp quý giá cho cộng đồng.
Chào mừng bạn đến với Chương 3. Trong chương này, chúng ta sẽ đi sâu vào bài toán Phân loại (Classification), một trong những nhiệm vụ phổ biến nhất trong Học máy (Machine Learning). Khác với bài toán Hồi quy (Regression) ở chương trước là dự đoán một giá trị số liên tục, bài toán Phân loại tập trung vào việc gán nhãn cho dữ liệu đầu vào (ví dụ: email này là spam hay không, bức ảnh này là số mấy).
Bạn có thể chạy trực tiếp các đoạn mã code tại: Google Colab.
Cài đặt môi trường
Đầu tiên, như thường lệ, chúng ta cần đảm bảo môi trường Python và các thư viện cần thiết đã được cài đặt đúng phiên bản. Dự án này yêu cầu Python 3.10 trở lên và Scikit-Learn phiên bản 1.0.1 trở lên.
import sys
# Kiểm tra phiên bản Python phải từ 3.10 trở lên
assert sys.version_info >= (3, 10)
Tiếp theo, chúng ta kiểm tra phiên bản của thư viện Scikit-Learn.
from packaging.version import Version
import sklearn
# Kiểm tra phiên bản Scikit-Learn phải từ 1.6.1 trở lên
assert Version(sklearn.__version__) >= Version("1.6.1")
Giống như chương trước, chúng ta sẽ thiết lập kích thước phông chữ mặc định cho thư viện vẽ biểu đồ matplotlib để các hình ảnh trực quan hóa trông đẹp và rõ ràng hơn.
import matplotlib.pyplot as plt
# Thiết lập kích thước phông chữ cho các thành phần biểu đồ
plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
Bộ dữ liệu MNIST
Trong chương này, chúng ta sẽ sử dụng bộ dữ liệu MNIST. Đây là bộ dữ liệu kinh điển trong giới Học máy, được ví như chương trình “Hello World” của các thuật toán nhận dạng hình ảnh. Bộ dữ liệu này bao gồm 70,000 hình ảnh nhỏ các chữ số viết tay (từ 0 đến 9) của học sinh trung học và nhân viên Cục điều tra dân số Mỹ.
Scikit-Learn cung cấp sẵn hàm fetch_openml để tải các bộ dữ liệu phổ biến.
from sklearn.datasets import fetch_openml
# Tải bộ dữ liệu MNIST từ OpenML
# as_frame=False để trả về dạng mảng NumPy thay vì DataFrame của Pandas
mnist = fetch_openml('mnist_784', as_frame=False)
Sau khi tải xong, chúng ta có thể xem mô tả chi tiết của bộ dữ liệu này. Đoạn mã dưới đây in ra phần mô tả (Description) của dữ liệu, bao gồm nguồn gốc, tác giả và các thông tin thống kê cơ bản.
# extra code – đoạn này chỉ để in ra mô tả, hơi dài một chút
print(mnist.DESCR)
**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges
**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown
**Please cite**:
The MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples
It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
With some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.
The MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.
Downloaded from openml.org.
Đối tượng mnist trả về là một từ điển (dictionary) chứa dữ liệu và siêu dữ liệu (metadata). Chúng ta có thể xem các khóa (keys) của nó.
mnist.keys() # extra code – chúng ta chỉ quan tâm đến 'data' và 'target' trong notebook này
dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])
Chúng ta sẽ tách dữ liệu thành hai phần:
X: Ma trận đặc trưng (features), chứa các giá trị pixel của hình ảnh.y: Mảng nhãn (target), chứa số thực tế mà hình ảnh biểu diễn.
X, y = mnist.data, mnist.target
X
array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]])
Hãy kiểm tra kích thước của tập dữ liệu. Chúng ta mong đợi 70,000 mẫu, và mỗi mẫu có 784 đặc trưng (vì mỗi ảnh có kích thước 28x28 pixel, và ).
X.shape
(70000, 784)
Tương tự, hãy xem mảng nhãn y. Chú ý rằng kiểu dữ liệu hiện tại đang là chuỗi ký tự (object/string).
y
array(['5', '0', '4', ..., '4', '5', '6'], dtype=object)
y.shape
(70000,)
28 * 28
784
Để hình dung dữ liệu rõ hơn, chúng ta sẽ viết một hàm plot_digit để vẽ lại hình ảnh từ mảng 784 giá trị pixel. Chúng ta cần biến đổi (reshape) mảng 1 chiều thành mảng 2 chiều 28x28.
import matplotlib.pyplot as plt
def plot_digit(image_data):
image = image_data.reshape(28, 28) # Chuyển mảng 1 chiều 784 thành ma trận 28x28
plt.imshow(image, cmap="binary") # Vẽ ảnh đen trắng (binary colormap)
plt.axis("off") # Tắt trục tọa độ
some_digit = X[0] # Lấy mẫu đầu tiên
plot_digit(some_digit)
plt.show()

Nhìn vào hình trên, có vẻ như đó là số 5. Hãy kiểm tra nhãn thực tế.
y[0]
'5'
Để có cái nhìn tổng quan hơn về độ đa dạng của chữ viết tay, hãy hiển thị 100 hình ảnh đầu tiên trong tập dữ liệu.
# extra code – đoạn mã này tạo ra Hình 3–2
plt.figure(figsize=(9, 9))
for idx, image_data in enumerate(X[:100]):
plt.subplot(10, 10, idx + 1) # Tạo lưới 10x10
plot_digit(image_data)
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()

Trước khi đi vào huấn luyện, chúng ta bắt buộc phải chia dữ liệu thành tập huấn luyện (training set) và tập kiểm tra (test set). Bộ dữ liệu MNIST đã được chia sẵn: 60,000 ảnh đầu tiên là tập huấn luyện và 10,000 ảnh sau là tập kiểm tra.
# Chia tập train (60k mẫu đầu) và test (10k mẫu sau)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
Huấn luyện Bộ phân loại Nhị phân (Binary Classifier)
Để đơn giản hóa bài toán, chúng ta hãy bắt đầu với việc chỉ nhận diện một chữ số, ví dụ là số 5. Đây là bài toán Phân loại Nhị phân: phân biệt giữa “là số 5” (Dương - Positive) và “không phải số 5” (Âm - Negative).
Chúng ta cần tạo các vector nhãn mục tiêu mới cho bài toán này:
y_train_5 = (y_train == '5') # True cho tất cả các số 5, False cho các số khác
y_test_5 = (y_test == '5')
Bây giờ, hãy chọn một mô hình phân loại và huấn luyện nó. Chúng ta sẽ bắt đầu với SGDClassifier (Stochastic Gradient Descent), một bộ phân loại tuyến tính có khả năng xử lý dữ liệu lớn hiệu quả.
from sklearn.linear_model import SGDClassifier
# Khởi tạo bộ phân loại SGD với random_state cố định để kết quả có thể tái lập
sgd_clf = SGDClassifier(random_state=42)
# Huấn luyện mô hình trên tập train
sgd_clf.fit(X_train, y_train_5)
SGDClassifier(random_state=42)
SGDClassifier(random_state=42)
Sau khi mô hình đã được huấn luyện, hãy thử dùng nó để dự đoán mẫu some_digit (vốn là số 5) mà chúng ta đã lấy ra từ đầu.
sgd_clf.predict([some_digit])
array([ True])
Kết quả là True, nghĩa là mô hình đã nhận diện đúng đây là số 5.
Các thước đo Hiệu năng (Performance Measures)
Đánh giá một bộ phân loại thường khó hơn nhiều so với bộ hồi quy. Chúng ta có nhiều công cụ khác nhau để thực hiện việc này.
Đo độ chính xác bằng Kiểm định chéo (Cross-Validation)
Kiểm định chéo là một phương pháp mạnh mẽ để đánh giá mô hình. Hàm cross_val_score sẽ chia tập huấn luyện thành cv phần (folds), sau đó lần lượt huấn luyện trên cv-1 phần và kiểm tra trên phần còn lại.
from sklearn.model_selection import cross_val_score
# Đánh giá mô hình bằng accuracy (độ chính xác) với 3-fold cross-validation
cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.95035, 0.96035, 0.9604 ])
Kết quả trên cho thấy độ chính xác trên 95%. Tuy nhiên, để hiểu rõ cơ chế hoạt động bên dưới, đôi khi chúng ta cần tự cài đặt quy trình kiểm định chéo. Đoạn mã sau minh họa cách sử dụng StratifiedKFold để đảm bảo tỷ lệ các lớp trong mỗi fold là tương đồng với toàn bộ tập dữ liệu.
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
skfolds = StratifiedKFold(n_splits=3) # Thêm shuffle=True nếu dữ liệu chưa được xáo trộn
# Duyệt qua từng fold
for train_index, test_index in skfolds.split(X_train, y_train_5):
clone_clf = clone(sgd_clf) # Sao chép mô hình gốc
# Tách dữ liệu theo index
X_train_folds = X_train[train_index]
y_train_folds = y_train_5[train_index]
X_test_fold = X_train[test_index]
y_test_fold = y_train_5[test_index]
# Huấn luyện và dự đoán
clone_clf.fit(X_train_folds, y_train_folds)
y_pred = clone_clf.predict(X_test_fold)
# Tính tỷ lệ dự đoán đúng
n_correct = sum(y_pred == y_test_fold)
print(n_correct / len(y_pred))
0.95035
0.96035
0.9604
Độ chính xác rất cao (trên 95%). Nhưng hãy cẩn thận! Trong tập dữ liệu này, số 5 chỉ chiếm khoảng 10%. Điều đó có nghĩa là nếu bạn có một mô hình “ngớ ngẩn” luôn đoán “Không phải số 5”, nó vẫn sẽ đúng 90% thời gian.
Hãy thử chứng minh điều này bằng DummyClassifier.
from sklearn.dummy import DummyClassifier
dummy_clf = DummyClassifier()
dummy_clf.fit(X_train, y_train_5)
print(any(dummy_clf.predict(X_train))) # Kiểm tra xem có dự đoán nào là True không
False
# Đánh giá Dummy Classifier
cross_val_score(dummy_clf, X_train, y_train_5, cv=3, scoring="accuracy")
array([0.90965, 0.90965, 0.90965])
Đúng như dự đoán, độ chính xác đạt khoảng 90%. Điều này cho thấy Accuracy (Độ chính xác) không phải là thước đo tốt cho các bài toán mà dữ liệu bị lệch (imbalanced datasets).
Ma trận Nhầm lẫn (Confusion Matrix)
Cách tốt hơn để đánh giá là nhìn vào Ma trận nhầm lẫn. Nó cho biết có bao nhiêu mẫu thuộc lớp A bị phân loại nhầm thành lớp B. Để tính ma trận này, trước tiên ta cần các dự đoán.
from sklearn.model_selection import cross_val_predict
# Trả về kết quả dự đoán trên từng fold thay vì điểm số
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_train_5, y_train_pred)
cm
array([[53892, 687],
[ 1891, 3530]])
Mỗi hàng đại diện cho một lớp thực tế, mỗi cột đại diện cho một lớp dự đoán:
- Hàng đầu tiên: Các ảnh không phải số 5 (Negative Class).
- Hàng thứ hai: Các ảnh là số 5 (Positive Class).
Một mô hình hoàn hảo sẽ chỉ có các giá trị nằm trên đường chéo chính (True Negative và True Positive).
y_train_perfect_predictions = y_train_5 # Giả sử chúng ta dự đoán đúng hoàn toàn
confusion_matrix(y_train_5, y_train_perfect_predictions)
array([[54579, 0],
[ 0, 5421]])
Precision (Độ chính xác của dự báo dương) và Recall (Độ phủ)
Ma trận nhầm lẫn chứa rất nhiều thông tin, nhưng đôi khi chúng ta muốn các chỉ số cô đọng hơn:
- Precision = : Trong số các mẫu mô hình đoán là “5”, bao nhiêu phần trăm là đúng?
- Recall = : Trong số tất cả các số “5” thực tế, mô hình tìm ra được bao nhiêu phần trăm?
from sklearn.metrics import precision_score, recall_score
precision_score(y_train_5, y_train_pred) # Tương đương 3530 / (687 + 3530)
0.8370879772350012
# extra code – tính thủ công để kiểm chứng: TP / (FP + TP)
cm[1, 1] / (cm[0, 1] + cm[1, 1])
np.float64(0.8370879772350012)
recall_score(y_train_5, y_train_pred) # Tương đương 3530 / (1891 + 3530)
0.6511713705958311
# extra code – tính thủ công để kiểm chứng: TP / (FN + TP)
cm[1, 1] / (cm[1, 0] + cm[1, 1])
np.float64(0.6511713705958311)
Chúng ta có thể kết hợp Precision và Recall thành một chỉ số duy nhất gọi là Score. Đây là trung bình điều hòa của Precision và Recall.
from sklearn.metrics import f1_score
f1_score(y_train_5, y_train_pred)
0.7325171197343847
# extra code – tính thủ công F1 score
cm[1, 1] / (cm[1, 1] + (cm[1, 0] + cm[0, 1]) / 2)
np.float64(0.7325171197343847)
Sự đánh đổi giữa Precision và Recall (Precision/Recall Trade-off)
Không may là bạn không thể tối ưu hóa cả hai chỉ số này cùng lúc. Tăng Precision sẽ làm giảm Recall và ngược lại. SGDClassifier đưa ra quyết định dựa trên một hàm quyết định (decision_function). Nếu điểm số lớn hơn ngưỡng (threshold), nó gán nhãn Positive.
Hãy xem điểm số của mẫu some_digit:
y_scores = sgd_clf.decision_function([some_digit])
y_scores
array([2164.22030239])
Mặc định, ngưỡng là 0.
threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([ True])
# extra code – minh họa rằng y_scores > 0 cho kết quả giống hàm predict()
y_scores > 0
array([ True])
Nếu chúng ta tăng ngưỡng lên 3000, Recall sẽ giảm (vì mô hình khắt khe hơn, dễ bỏ sót số 5).
threshold = 3000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred
array([False])
Để quyết định ngưỡng nào là tốt nhất, chúng ta cần lấy điểm quyết định của tất cả các mẫu trong tập huấn luyện, sau đó vẽ biểu đồ Precision và Recall theo các ngưỡng khác nhau.
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
method="decision_function")
from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
Đoạn mã dưới đây vẽ biểu đồ Precision và Recall dưới dạng các hàm số của ngưỡng quyết định.
plt.figure(figsize=(8, 4)) # extra code – định dạng hình ảnh
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
plt.vlines(threshold, 0, 1.0, "k", "dotted", label="threshold")
# extra code – trang trí cho Hình 3–5 thêm đẹp và dễ hiểu
idx = (thresholds >= threshold).argmax() # tìm index đầu tiên >= threshold
plt.plot(thresholds[idx], precisions[idx], "bo")
plt.plot(thresholds[idx], recalls[idx], "go")
plt.axis([-50000, 50000, 0, 1])
plt.grid()
plt.xlabel("Threshold")
plt.legend(loc="center right")
plt.show()

Một cách khác để chọn điểm cân bằng tốt là vẽ Precision trực tiếp theo Recall.
import matplotlib.patches as patches # extra code – dùng để vẽ mũi tên cong
plt.figure(figsize=(6, 5)) # extra code
plt.plot(recalls, precisions, linewidth=2, label="Precision/Recall curve")
# extra code – trang trí cho Hình 3–6
plt.plot([recalls[idx], recalls[idx]], [0., precisions[idx]], "k:")
plt.plot([0.0, recalls[idx]], [precisions[idx], precisions[idx]], "k:")
plt.plot([recalls[idx]], [precisions[idx]], "ko",
label="Point at threshold 3,000")
plt.gca().add_patch(patches.FancyArrowPatch(
(0.79, 0.60), (0.61, 0.78),
connectionstyle="arc3,rad=.2",
arrowstyle="Simple, tail_width=1.5, head_width=8, head_length=10",
color="#444444"))
plt.text(0.56, 0.62, "Higher\nthreshold", color="#333333")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.axis([0, 1, 0, 1])
plt.grid()
plt.legend(loc="lower left")
plt.show()

Giả sử bạn muốn đạt Precision là 90%. Chúng ta có thể tìm ngưỡng cần thiết để đạt được điều này.
idx_for_90_precision = (precisions >= 0.90).argmax()
threshold_for_90_precision = thresholds[idx_for_90_precision]
threshold_for_90_precision
np.float64(3370.0194991439557)
Sau đó, chúng ta áp dụng ngưỡng này để đưa ra dự đoán và kiểm tra kết quả.
y_train_pred_90 = (y_scores >= threshold_for_90_precision)
precision_score(y_train_5, y_train_pred_90)
0.9000345901072293
Tuy nhiên, khi Precision tăng thì Recall sẽ giảm.
recall_at_90_precision = recall_score(y_train_5, y_train_pred_90)
recall_at_90_precision
0.4799852425751706
Đường cong ROC (The ROC Curve)
Đường cong Receiver Operating Characteristic (ROC) là một công cụ phổ biến khác. Nó biểu diễn Tỷ lệ Dương tính Thật (True Positive Rate - TPR, cũng chính là Recall) so với Tỷ lệ Dương tính Giả (False Positive Rate - FPR).
from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
Hãy vẽ biểu đồ ROC. Đường chấm chấm biểu diễn một bộ phân loại ngẫu nhiên (hoàn toàn không có khả năng phân loại). Một bộ phân loại tốt sẽ có đường cong càng xa đường chấm này (về phía góc trên bên trái) càng tốt.
idx_for_threshold_at_90 = (thresholds <= threshold_for_90_precision).argmax()
tpr_90, fpr_90 = tpr[idx_for_threshold_at_90], fpr[idx_for_threshold_at_90]
plt.figure(figsize=(6, 5)) # extra code
plt.plot(fpr, tpr, linewidth=2, label="ROC curve")
plt.plot([0, 1], [0, 1], 'k:', label="Random classifier's ROC curve")
plt.plot([fpr_90], [tpr_90], "ko", label="Threshold for 90% precision")
# extra code – trang trí cho Hình 3–7
plt.gca().add_patch(patches.FancyArrowPatch(
(0.20, 0.89), (0.07, 0.70),
connectionstyle="arc3,rad=.4",
arrowstyle="Simple, tail_width=1.5, head_width=8, head_length=10",
color="#444444"))
plt.text(0.12, 0.71, "Higher\nthreshold", color="#333333")
plt.xlabel('False Positive Rate (Fall-Out)')
plt.ylabel('True Positive Rate (Recall)')
plt.grid()
plt.axis([0, 1, 0, 1])
plt.legend(loc="lower right", fontsize=13)
plt.show()

Một cách để so sánh các bộ phân loại là đo Diện tích dưới đường cong (Area Under the Curve - AUC). Bộ phân loại hoàn hảo có ROC AUC bằng 1.
from sklearn.metrics import roc_auc_score
roc_auc_score(y_train_5, y_scores)
np.float64(0.9604938554008616)
Bây giờ hãy thử huấn luyện một RandomForestClassifier và so sánh đường cong ROC cũng như ROC AUC của nó với SGDClassifier.
Lưu ý: Đoạn mã sau có thể mất vài phút để chạy.
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
# RandomForest không có decision_function mà dùng predict_proba
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,
method="predict_proba")
Hàm predict_proba trả về một mảng chứa xác suất cho mỗi lớp (ví dụ: 10% là số 5, 90% không phải).
y_probas_forest[:2]
array([[0.11, 0.89],
[0.99, 0.01]])
Một điều thú vị là đây là xác suất ước lượng. Trong số các ảnh mà mô hình dự đoán có xác suất từ 50% đến 60% là số 5, thực tế có khoảng 94% là số 5 thật.
# Not in the code - kiểm chứng xác suất ước lượng
idx_50_to_60 = (y_probas_forest[:, 1] > 0.50) & (y_probas_forest[:, 1] < 0.60)
print(f"{(y_train_5[idx_50_to_60]).sum() / idx_50_to_60.sum():.1%}")
94.0%
Để vẽ đường cong ROC, chúng ta cần điểm số (scores) chứ không phải xác suất. Tuy nhiên, chúng ta có thể dùng xác suất của lớp dương tính làm điểm số.
y_scores_forest = y_probas_forest[:, 1]
precisions_forest, recalls_forest, thresholds_forest = precision_recall_curve(
y_train_5, y_scores_forest)
plt.figure(figsize=(6, 5)) # extra code
plt.plot(recalls_forest, precisions_forest, "b-", linewidth=2,
label="Random Forest")
plt.plot(recalls, precisions, "--", linewidth=2, label="SGD")
# extra code – trang trí cho Hình 3–8
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.axis([0, 1, 0, 1])
plt.grid()
plt.legend(loc="lower left")
plt.show()

Chúng ta thấy đường cong của Random Forest tốt hơn nhiều so với SGD (nằm gần góc trên bên phải hơn). Hãy kiểm tra các chỉ số cụ thể.
y_train_pred_forest = y_probas_forest[:, 1] >= 0.5 # xác suất >= 50% thì coi là dương tính
f1_score(y_train_5, y_train_pred_forest)
0.9274509803921569
roc_auc_score(y_train_5, y_scores_forest)
np.float64(0.9983436731328145)
precision_score(y_train_5, y_train_pred_forest)
0.9897468089558485
recall_score(y_train_5, y_train_pred_forest)
0.8725327430363402
Phân loại Đa lớp (Multiclass Classification)
Phân loại đa lớp là khi chúng ta muốn phân biệt nhiều hơn hai lớp (ví dụ: số 0 đến số 9). Một số thuật toán (như Random Forest, Naive Bayes) có khả năng xử lý trực tiếp nhiều lớp. Một số khác (như SVM, Linear Classifiers) vốn dĩ là nhị phân.
Tuy nhiên, chúng ta có thể dùng nhiều bộ phân loại nhị phân để giải quyết bài toán đa lớp:
- OvR (One-versus-the-Rest): Huấn luyện 10 bộ phân loại (0 vs tất cả, 1 vs tất cả…).
- OvO (One-versus-One): Huấn luyện bộ phân loại cho từng cặp (0 vs 1, 0 vs 2…).
SVM (Support Vector Machine) hoạt động kém với dữ liệu lớn, nên chúng ta sẽ chỉ thử nghiệm trên 2,000 mẫu đầu tiên để tiết kiệm thời gian.
from sklearn.svm import SVC
svm_clf = SVC(random_state=42)
svm_clf.fit(X_train[:2000], y_train[:2000]) # Dùng y_train (nhãn gốc 0-9), không phải y_train_5
SVC(random_state=42)
SVC(random_state=42)
Hãy thử dự đoán mẫu số 5 ban đầu.
svm_clf.predict([some_digit])
array(['5'], dtype=object)
Khi gọi hàm decision_function, Scikit-Learn thực chất đang chạy chiến lược OvO (với SVM) và trả về điểm số cho từng lớp. Lớp nào có điểm cao nhất sẽ được chọn.
some_digit_scores = svm_clf.decision_function([some_digit])
some_digit_scores.round(2)
array([[ 3.79, 0.73, 6.06, 8.3 , -0.29, 9.3 , 1.75, 2.77, 7.21,
4.82]])
class_id = some_digit_scores.argmax() # Tìm vị trí có điểm cao nhất
class_id
np.int64(5)
svm_clf.classes_
array(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype=object)
svm_clf.classes_[class_id]
'5'
Nếu bạn muốn ép buộc SVM trả về điểm số của tất cả các cặp đấu (OvO), bạn có thể cấu hình tham số decision_function_shape.
# extra code – hiển thị 45 điểm số cho 45 cặp đấu (10 chọn 2)
svm_clf.decision_function_shape = "ovo"
some_digit_scores_ovo = svm_clf.decision_function([some_digit])
some_digit_scores_ovo.round(2)
array([[ 0.11, -0.21, -0.97, 0.51, -1.01, 0.19, 0.09, -0.31, -0.04,
-0.45, -1.28, 0.25, -1.01, -0.13, -0.32, -0.9 , -0.36, -0.93,
0.79, -1. , 0.45, 0.24, -0.24, 0.25, 1.54, -0.77, 1.11,
1.13, 1.04, 1.2 , -1.42, -0.53, -0.45, -0.99, -0.95, 1.21,
1. , 1. , 1.08, -0.02, -0.67, -0.14, -0.3 , -0.13, 0.25]])
Hoặc bạn có thể dùng OneVsRestClassifier để ép buộc sử dụng chiến lược OvR cho SVM.
from sklearn.multiclass import OneVsRestClassifier
ovr_clf = OneVsRestClassifier(SVC(random_state=42))
ovr_clf.fit(X_train[:2000], y_train[:2000])
OneVsRestClassifier(estimator=SVC(random_state=42))
OneVsRestClassifier(estimator=SVC(random_state=42))
SVC(random_state=42)
SVC(random_state=42)
ovr_clf.predict([some_digit])
array(['5'], dtype='<U1')
len(ovr_clf.estimators_) # Kiểm tra xem có đúng là 10 mô hình con không
10
Với SGDClassifier, Scikit-Learn mặc định sử dụng chiến lược OvR vì nó hiệu quả hơn cho thuật toán này.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])
array(['3'], dtype='<U1')
sgd_clf.decision_function([some_digit]).round()
array([[-31893., -34420., -9531., 1824., -22320., -1386., -26189.,
-16148., -4604., -12051.]])
Cảnh báo: Các ô mã sau có thể mất vài phút để chạy.
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
array([0.87365, 0.85835, 0.8689 ])
Kết quả trên 85% là khá tốt. Tuy nhiên, chúng ta có thể cải thiện hơn nữa bằng cách chuẩn hóa dữ liệu đầu vào (Scaling). Các mô hình tuyến tính như SGD rất nhạy cảm với độ lớn của đặc trưng.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype("float64"))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
array([0.8983, 0.891 , 0.9018])
Phân tích Lỗi (Error Analysis)
Nếu đây là một dự án thực tế, bạn sẽ cần thực hiện nhiều bước như chọn mô hình, tinh chỉnh siêu tham số… Ở đây, giả sử chúng ta đã chọn được mô hình tốt nhất, bây giờ chúng ta muốn tìm cách cải thiện nó bằng cách phân tích các lỗi mà nó mắc phải.
Đầu tiên, hãy nhìn vào Ma trận nhầm lẫn.
Cảnh báo: Ô mã sau sẽ mất vài phút để chạy.
from sklearn.metrics import ConfusionMatrixDisplay
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
plt.rc('font', size=9) # extra code – chỉnh font nhỏ lại cho vừa
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred)
plt.show()

Ma trận trên hơi khó nhìn vì số lượng mẫu ở mỗi lớp khác nhau. Hãy chuẩn hóa (normalize) nó để xem tỷ lệ phần trăm sai sót.
plt.rc('font', size=10) # extra code
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,
normalize="true", values_format=".0%")
plt.show()

Để làm rõ hơn những nơi mô hình mắc lỗi, chúng ta có thể đặt trọng số bằng 0 cho các dự đoán đúng (đường chéo chính), chỉ tập trung hiển thị các lỗi.
sample_weight = (y_train_pred != y_train)
plt.rc('font', size=10) # extra code
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred,
sample_weight=sample_weight,
normalize="true", values_format=".0%")
plt.show()

Chúng ta có thể vẽ lại các biểu đồ này cạnh nhau để dễ so sánh cho mục đích trình bày trong sách.
# extra code – tạo Hình 3–9
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
plt.rc('font', size=9)
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0])
axs[0].set_title("Confusion matrix")
plt.rc('font', size=10)
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],
normalize="true", values_format=".0%")
axs[1].set_title("CM normalized by row")
plt.show()

# extra code – tạo Hình 3–10
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
plt.rc('font', size=10)
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[0],
sample_weight=sample_weight,
normalize="true", values_format=".0%")
axs[0].set_title("Errors normalized by row")
ConfusionMatrixDisplay.from_predictions(y_train, y_train_pred, ax=axs[1],
sample_weight=sample_weight,
normalize="pred", values_format=".0%")
axs[1].set_title("Errors normalized by column")
plt.show()
plt.rc('font', size=14) # trả lại font chữ lớn

Nhìn vào ma trận, ta thấy có sự nhầm lẫn đáng kể giữa số 3 và số 5. Hãy thử hiển thị các hình ảnh cụ thể để xem tại sao mô hình lại nhầm lẫn.
cl_a, cl_b = '3', '5'
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)] # 3 đoán là 3
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)] # 3 đoán là 5 (Lỗi)
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)] # 5 đoán là 3 (Lỗi)
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)] # 5 đoán là 5
# extra code – tạo Hình 3–11
size = 5
pad = 0.2
plt.figure(figsize=(size, size))
for images, (label_col, label_row) in [(X_ba, (0, 0)), (X_bb, (1, 0)),
(X_aa, (0, 1)), (X_ab, (1, 1))]:
for idx, image_data in enumerate(images[:size*size]):
x = idx % size + label_col * (size + pad)
y = idx // size + label_row * (size + pad)
plt.imshow(image_data.reshape(28, 28), cmap="binary",
extent=(x, x + 1, y, y + 1))
plt.xticks([size / 2, size + pad + size / 2], [str(cl_a), str(cl_b)])
plt.yticks([size / 2, size + pad + size / 2], [str(cl_b), str(cl_a)])
plt.plot([size + pad / 2, size + pad / 2], [0, 2 * size + pad], "k:")
plt.plot([0, 2 * size + pad], [size + pad / 2, size + pad / 2], "k:")
plt.axis([0, 2 * size + pad, 0, 2 * size + pad])
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()

Phân loại Đa nhãn (Multilabel Classification)
Đôi khi, chúng ta muốn gán nhiều nhãn cho một mẫu. Ví dụ: trong một bức ảnh có thể có cả “Alice” và “Bob”.
Hãy thử tạo một hệ thống phân loại đa nhãn:
- Số này có lớn hơn hoặc bằng 7 không?
- Số này có phải là số lẻ không?
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
y_train_large = (y_train >= '7')
y_train_odd = (y_train.astype('int8') % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd] # Kết hợp 2 nhãn lại
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
KNeighborsClassifier()
KNeighborsClassifier()
Thử dự đoán với số 5: Nó không lớn hơn 7 (False), và nó là số lẻ (True).
knn_clf.predict([some_digit])
array([[False, True]])
Cảnh báo: Ô mã sau có thể mất vài phút để chạy.
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3)
f1_score(y_multilabel, y_train_knn_pred, average="macro")
0.9764102655606048
# extra code – kiểm tra average="weighted"
f1_score(y_multilabel, y_train_knn_pred, average="weighted")
0.9778357403921755
Nếu bạn muốn sử dụng một bộ phân loại không hỗ trợ đa nhãn (như SVM), bạn có thể dùng ClassifierChain để xâu chuỗi các mô hình lại với nhau.
from sklearn.multioutput import ClassifierChain
chain_clf = ClassifierChain(SVC(), cv=3, random_state=42)
chain_clf.fit(X_train[:2000], y_multilabel[:2000])
ClassifierChain(base_estimator=SVC(), cv=3, random_state=42)
ClassifierChain(base_estimator=SVC(), cv=3, random_state=42)
SVC()
SVC()
chain_clf.predict([some_digit])
array([[0., 1.]])
Phân loại Đa đầu ra (Multioutput Classification)
Đây là dạng tổng quát nhất, trong đó mỗi nhãn có thể có nhiều hơn 2 giá trị (multiclass).
Ví dụ: Làm sạch nhiễu (noise) khỏi ảnh. Đầu vào là ảnh nhiễu, đầu ra là ảnh sạch (784 pixel, mỗi pixel có giá trị từ 0-255).
rng = np.random.default_rng(seed=42)
noise_train = rng.integers(0, 100, (len(X_train), 784)) # Tạo nhiễu ngẫu nhiên
X_train_mod = X_train + noise_train
noise_test = rng.integers(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise_test
y_train_mod = X_train
y_test_mod = X_test
# extra code – tạo Hình 3–12
plt.subplot(121); plot_digit(X_test_mod[0])
plt.subplot(122); plot_digit(y_test_mod[0])
plt.show()

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train_mod, y_train_mod)
clean_digit = knn_clf.predict([X_test_mod[0]])
plot_digit(clean_digit)
plt.show()

Phần mở rộng
1. Bộ phân loại MNIST với độ chính xác trên 97%
Bài tập: Hãy cố gắng xây dựng một bộ phân loại cho tập dữ liệu MNIST đạt độ chính xác trên 97% trên tập kiểm tra. Gợi ý: KNeighborsClassifier hoạt động khá tốt cho nhiệm vụ này; bạn chỉ cần tìm các giá trị siêu tham số tốt (thử tìm kiếm lưới weights và n_neighbors).
Chúng ta bắt đầu với một mô hình KNN đơn giản làm mốc so sánh (baseline).
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train)
baseline_accuracy = knn_clf.score(X_test, y_test)
baseline_accuracy
0.9688
Kết quả đã rất gần 97%. Hãy dùng GridSearchCV để tinh chỉnh siêu tham số. Để tiết kiệm thời gian, chúng ta chỉ huấn luyện trên 10,000 mẫu đầu.
from sklearn.model_selection import GridSearchCV
param_grid = [{'weights': ["uniform", "distance"], 'n_neighbors': [3, 4, 5, 6]}]
knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, cv=5)
grid_search.fit(X_train[:10_000], y_train[:10_000])
GridSearchCV(cv=5, estimator=KNeighborsClassifier(),
param_grid=[{'n_neighbors': [3, 4, 5, 6],
'weights': ['uniform', 'distance']}])GridSearchCV(cv=5, estimator=KNeighborsClassifier(),
param_grid=[{'n_neighbors': [3, 4, 5, 6],
'weights': ['uniform', 'distance']}])KNeighborsClassifier(n_neighbors=4, weights='distance')
KNeighborsClassifier(n_neighbors=4, weights='distance')
grid_search.best_params_
{'n_neighbors': 4, 'weights': 'distance'}
grid_search.best_score_
np.float64(0.9441999999999998)
Bây giờ hãy lấy mô hình tốt nhất này và huấn luyện lại trên toàn bộ tập dữ liệu.
grid_search.best_estimator_.fit(X_train, y_train)
tuned_accuracy = grid_search.score(X_test, y_test)
tuned_accuracy
0.9714
Chúng ta đã đạt mục tiêu 97%! 🥳
2. Tăng cường dữ liệu (Data Augmentation)
Bài tập: Viết một hàm để dịch chuyển hình ảnh MNIST theo bất kỳ hướng nào (trái, phải, lên, xuống) một pixel. Sau đó, với mỗi ảnh trong tập huấn luyện, tạo ra 4 bản sao đã dịch chuyển và thêm vào tập huấn luyện. Cuối cùng, huấn luyện mô hình tốt nhất của bạn trên tập dữ liệu mở rộng này.
Kỹ thuật này gọi là Data Augmentation.
from scipy.ndimage import shift
def shift_image(image, dx, dy):
image = image.reshape((28, 28))
# Dịch chuyển ảnh và lấp đầy khoảng trống bằng 0 (màu trắng/đen tùy nền)
shifted_image = shift(image, [dy, dx], cval=0, mode="constant")
return shifted_image.reshape([-1])
Hãy kiểm tra xem hàm hoạt động như thế nào.
image = X_train[1000] # lấy một chữ số ngẫu nhiên
shifted_image_down = shift_image(image, 0, 5)
shifted_image_left = shift_image(image, -5, 0)
plt.figure(figsize=(12, 3))
plt.subplot(131)
plt.title("Original")
plt.imshow(image.reshape(28, 28),
interpolation="nearest", cmap="Greys")
plt.subplot(132)
plt.title("Shifted down")
plt.imshow(shifted_image_down.reshape(28, 28),
interpolation="nearest", cmap="Greys")
plt.subplot(133)
plt.title("Shifted left")
plt.imshow(shifted_image_left.reshape(28, 28),
interpolation="nearest", cmap="Greys")
plt.show()

Bây giờ hãy tạo tập huấn luyện mở rộng.
X_train_augmented = [image for image in X_train]
y_train_augmented = [label for label in y_train]
for dx, dy in ((-1, 0), (1, 0), (0, 1), (0, -1)): # 4 hướng
for image, label in zip(X_train, y_train):
X_train_augmented.append(shift_image(image, dx, dy))
y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
Chúng ta cần xáo trộn dữ liệu vì hiện tại các ảnh dịch chuyển đang nằm cụm lại với nhau.
rng = np.random.default_rng(seed=42)
shuffle_idx = rng.permutation(len(X_train_augmented))
X_train_augmented = X_train_augmented[shuffle_idx]
y_train_augmented = y_train_augmented[shuffle_idx]
Tiến hành huấn luyện mô hình KNN với tham số tốt nhất đã tìm được.
knn_clf = KNeighborsClassifier(**grid_search.best_params_)
knn_clf.fit(X_train_augmented, y_train_augmented)
KNeighborsClassifier(n_neighbors=4, weights='distance')
KNeighborsClassifier(n_neighbors=4, weights='distance')
Cảnh báo: Ô mã sau có thể mất vài phút để chạy.
augmented_accuracy = knn_clf.score(X_test, y_test)
augmented_accuracy
0.9763
Độ chính xác tăng thêm khoảng 0.5%. Nghe có vẻ nhỏ, nhưng nó giúp giảm tỷ lệ lỗi (error rate) đáng kể.
error_rate_change = (1 - augmented_accuracy) / (1 - tuned_accuracy) - 1
print(f"error_rate_change = {error_rate_change:.0%}")
error_rate_change = -17%
3. Thử thách với bộ dữ liệu Titanic
Bài tập: Mục tiêu là xây dựng mô hình dự đoán hành khách nào sống sót dựa trên các thông tin như tuổi, giới tính, hạng vé, v.v.
Đầu tiên, tải và nạp dữ liệu.
from pathlib import Path
import pandas as pd
import tarfile
import urllib.request
def load_titanic_data():
tarball_path = Path("datasets/titanic.tgz")
if not tarball_path.is_file():
Path("datasets").mkdir(parents=True, exist_ok=True)
url = "https://github.com/ageron/data/raw/main/titanic.tgz"
urllib.request.urlretrieve(url, tarball_path)
with tarfile.open(tarball_path) as titanic_tarball:
titanic_tarball.extractall(path="datasets", filter="data")
return [pd.read_csv(Path("datasets/titanic") / filename)
for filename in ("train.csv", "test.csv")]
train_data, test_data = load_titanic_data()
Dữ liệu đã được chia sẵn train/test. Hãy xem qua vài dòng dữ liệu huấn luyện.
train_data.head()
PassengerId Survived Pclass \
0 1 0 3
1 2 1 1
2 3 1 3
3 4 1 1
4 5 0 3
Name Sex Age SibSp \
0 Braund, Mr. Owen Harris male 22.0 1
1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1
2 Heikkinen, Miss. Laina female 26.0 0
3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1
4 Allen, Mr. William Henry male 35.0 0
Parch Ticket Fare Cabin Embarked
0 0 A/5 21171 7.2500 NaN S
1 0 PC 17599 71.2833 C85 C
2 0 STON/O2. 3101282 7.9250 NaN S
3 0 113803 53.1000 C123 S
4 0 373450 8.0500 NaN S
Đặt PassengerId làm index.
train_data = train_data.set_index("PassengerId")
test_data = test_data.set_index("PassengerId")
Kiểm tra thông tin dữ liệu (missing values, data types).
train_data.info()
<class 'pandas.core.frame.DataFrame'>
Index: 891 entries, 1 to 891
Data columns (total 11 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Survived 891 non-null int64
1 Pclass 891 non-null int64
2 Name 891 non-null object
3 Sex 891 non-null object
4 Age 714 non-null float64
5 SibSp 891 non-null int64
6 Parch 891 non-null int64
7 Ticket 891 non-null object
8 Fare 891 non-null float64
9 Cabin 204 non-null object
10 Embarked 889 non-null object
dtypes: float64(2), int64(4), object(5)
memory usage: 83.5+ KB
train_data[train_data["Sex"]=="female"]["Age"].median()
27.0
Ta thấy Age, Cabin và Embarked bị thiếu dữ liệu. Chúng ta sẽ xây dựng các Pipeline để xử lý việc này (ví dụ: điền tuổi bằng giá trị trung vị).
Hãy xem thống kê các thuộc tính số.
train_data.describe()
Survived Pclass Age SibSp Parch Fare
count 891.000000 891.000000 714.000000 891.000000 891.000000 891.000000
mean 0.383838 2.308642 29.699113 0.523008 0.381594 32.204208
std 0.486592 0.836071 14.526507 1.102743 0.806057 49.693429
min 0.000000 1.000000 0.416700 0.000000 0.000000 0.000000
25% 0.000000 2.000000 20.125000 0.000000 0.000000 7.910400
50% 0.000000 3.000000 28.000000 0.000000 0.000000 14.454200
75% 1.000000 3.000000 38.000000 1.000000 0.000000 31.000000
max 1.000000 3.000000 80.000000 8.000000 6.000000 512.329200
Chỉ khoảng 38% hành khách sống sót. Kiểm tra phân phối các nhãn.
train_data["Survived"].value_counts()
| count | |
|---|---|
| Survived | |
| 0 | 549 |
| 1 | 342 |
Kiểm tra các thuộc tính phân loại (categorical).
train_data["Pclass"].value_counts()
| count | |
|---|---|
| Pclass | |
| 3 | 491 |
| 1 | 216 |
| 2 | 184 |
train_data["Sex"].value_counts()
| count | |
|---|---|
| Sex | |
| male | 577 |
| female | 314 |
train_data["Embarked"].value_counts()
| count | |
|---|---|
| Embarked | |
| S | 644 |
| C | 168 |
| Q | 77 |
Xây dựng Pipeline xử lý dữ liệu.
- Với dữ liệu số: Điền giá trị thiếu bằng trung vị (Median), sau đó chuẩn hóa (StandardScaler).
- Với dữ liệu phân loại: Điền giá trị thiếu bằng giá trị phổ biến nhất, sau đó mã hóa OneHot.
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
num_pipeline = Pipeline([
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler())
])
from sklearn.preprocessing import OrdinalEncoder, OneHotEncoder
cat_pipeline = Pipeline([
("ordinal_encoder", OrdinalEncoder()),
("imputer", SimpleImputer(strategy="most_frequent")),
("cat_encoder", OneHotEncoder(sparse_output=False)),
])
from sklearn.compose import ColumnTransformer
num_attribs = ["Age", "SibSp", "Parch", "Fare"]
cat_attribs = ["Pclass", "Sex", "Embarked"]
preprocess_pipeline = ColumnTransformer([
("num", num_pipeline, num_attribs),
("cat", cat_pipeline, cat_attribs),
])
Áp dụng Pipeline để chuẩn bị dữ liệu huấn luyện.
X_train = preprocess_pipeline.fit_transform(train_data)
X_train
array([[-0.56573582, 0.43279337, -0.47367361, ..., 0. ,
0. , 1. ],
[ 0.6638609 , 0.43279337, -0.47367361, ..., 1. ,
0. , 0. ],
[-0.25833664, -0.4745452 , -0.47367361, ..., 0. ,
0. , 1. ],
...,
[-0.10463705, 0.43279337, 2.00893337, ..., 0. ,
0. , 1. ],
[-0.25833664, -0.4745452 , -0.47367361, ..., 1. ,
0. , 0. ],
[ 0.20276213, -0.4745452 , -0.47367361, ..., 0. ,
1. , 0. ]])
y_train = train_data["Survived"]
Huấn luyện với RandomForestClassifier.
forest_clf = RandomForestClassifier(n_estimators=100, random_state=42)
forest_clf.fit(X_train, y_train)
RandomForestClassifier(random_state=42)
RandomForestClassifier(random_state=42)
Dự đoán trên tập test và đánh giá bằng Cross Validation.
X_test = preprocess_pipeline.transform(test_data)
y_pred = forest_clf.predict(X_test)
forest_scores = cross_val_score(forest_clf, X_train, y_train, cv=10)
forest_scores.mean()
np.float64(0.8137578027465668)
Thử với SVM xem sao.
from sklearn.svm import SVC
svm_clf = SVC(gamma="auto")
svm_scores = cross_val_score(svm_clf, X_train, y_train, cv=10)
svm_scores.mean()
np.float64(0.8249313358302123)
Vẽ biểu đồ hộp (Box plot) để so sánh phân phối điểm số của hai mô hình.
plt.figure(figsize=(8, 4))
plt.plot([1]*10, svm_scores, ".")
plt.plot([2]*10, forest_scores, ".")
plt.boxplot([svm_scores, forest_scores], labels=("SVM", "Random Forest"))
plt.ylabel("Accuracy")
plt.show()
/tmp/ipython-input-3995193800.py:4: MatplotlibDeprecationWarning: The 'labels' parameter of boxplot() has been renamed 'tick_labels' since Matplotlib 3.9; support for the old name will be dropped in 3.11.
plt.boxplot([svm_scores, forest_scores], labels=("SVM", "Random Forest"))

Chúng ta cũng có thể tạo thêm các đặc trưng mới (Feature Engineering) như nhóm tuổi (AgeBucket) hay số lượng người thân đi cùng.
train_data["AgeBucket"] = train_data["Age"] // 15 * 15
train_data[["AgeBucket", "Survived"]].groupby(['AgeBucket']).mean()
Survived
AgeBucket
0.0 0.576923
15.0 0.362745
30.0 0.423256
45.0 0.404494
60.0 0.240000
75.0 1.000000
train_data["RelativesOnboard"] = train_data["SibSp"] + train_data["Parch"]
train_data[["RelativesOnboard", "Survived"]].groupby(
['RelativesOnboard']).mean()
Survived
RelativesOnboard
0 0.303538
1 0.552795
2 0.578431
3 0.724138
4 0.200000
5 0.136364
6 0.333333
7 0.000000
10 0.000000
4. Bộ phân loại Spam (Spam Classifier)
Bài tập: Xây dựng một bộ phân loại email spam sử dụng dataset của Apache SpamAssassin.
Quy trình bao gồm:
- Tải dữ liệu.
- Phân tích cấu trúc email (Header, Body, HTML).
- Tiền xử lý (Chuyển HTML sang text, bỏ dấu câu, stemming…).
- Chuyển đổi thành vector (Bag of Words).
- Huấn luyện mô hình.
import tarfile
def fetch_spam_data():
spam_root = "http://spamassassin.apache.org/old/publiccorpus/"
ham_url = spam_root + "20030228_easy_ham.tar.bz2"
spam_url = spam_root + "20030228_spam.tar.bz2"
spam_path = Path() / "datasets" / "spam"
spam_path.mkdir(parents=True, exist_ok=True)
for dir_name, tar_name, url in (("easy_ham", "ham", ham_url),
("spam", "spam", spam_url)):
if not (spam_path / dir_name).is_dir():
path = (spam_path / tar_name).with_suffix(".tar.bz2")
print("Downloading", path)
urllib.request.urlretrieve(url, path)
tar_bz2_file = tarfile.open(path)
tar_bz2_file.extractall(path=spam_path, filter="data")
tar_bz2_file.close()
return [spam_path / dir_name for dir_name in ("easy_ham", "spam")]
ham_dir, spam_dir = fetch_spam_data()
Downloading datasets/spam/ham.tar.bz2
Downloading datasets/spam/spam.tar.bz2
Tải danh sách file.
ham_filenames = [f for f in sorted(ham_dir.iterdir()) if len(f.name) > 20]
spam_filenames = [f for f in sorted(spam_dir.iterdir()) if len(f.name) > 20]
len(ham_filenames)
2500
len(spam_filenames)
500
Sử dụng thư viện email của Python để phân tích (parse) nội dung.
import email
import email.policy
def load_email(filepath):
with open(filepath, "rb") as f:
return email.parser.BytesParser(policy=email.policy.default).parse(f)
ham_emails = [load_email(filepath) for filepath in ham_filenames]
spam_emails = [load_email(filepath) for filepath in spam_filenames]
Xem thử nội dung một email ham (thư sạch) và spam.
print(ham_emails[1].get_content().strip())
Martin A posted:
Tassos Papadopoulos, the Greek sculptor behind the plan, judged that the
limestone of Mount Kerdylio, 70 miles east of Salonika and not far from the
Mount Athos monastic community, was ideal for the patriotic sculpture.
As well as Alexander's granite features, 240 ft high and 170 ft wide, a
museum, a restored amphitheatre and car park for admiring crowds are
planned
---------------------
So is this mountain limestone or granite?
If it's limestone, it'll weather pretty fast.
------------------------ Yahoo! Groups Sponsor ---------------------~-->
4 DVDs Free +s&p Join Now
http://us.click.yahoo.com/pt6YBB/NXiEAA/mG3HAA/7gSolB/TM
---------------------------------------------------------------------~->
To unsubscribe from this group, send an email to:
forteana-unsubscribe@egroups.com
Your use of Yahoo! Groups is subject to http://docs.yahoo.com/info/terms/
print(spam_emails[6].get_content().strip())
Help wanted. We are a 14 year old fortune 500 company, that is
growing at a tremendous rate. We are looking for individuals who
want to work from home.
This is an opportunity to make an excellent income. No experience
is required. We will train you.
So if you are looking to be employed from home with a career that has
vast opportunities, then go:
http://www.basetel.com/wealthnow
We are looking for energetic and self motivated people. If that is you
than click on the link and fill out the form, and one of our
employement specialist will contact you.
To be removed from our link simple go to:
http://www.basetel.com/remove.html
4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40
Một số email có cấu trúc phức tạp (multipart). Hãy kiểm tra các loại cấu trúc.
def get_email_structure(email):
if isinstance(email, str):
return email
payload = email.get_payload()
if isinstance(payload, list):
multipart = ", ".join([get_email_structure(sub_email)
for sub_email in payload])
return f"multipart({multipart})"
else:
return email.get_content_type()
from collections import Counter
def structures_counter(emails):
structures = Counter()
for email in emails:
structure = get_email_structure(email)
structures[structure] += 1
return structures
structures_counter(ham_emails).most_common()
[('text/plain', 2408),
('multipart(text/plain, application/pgp-signature)', 66),
('multipart(text/plain, text/html)', 8),
('multipart(text/plain, text/plain)', 4),
('multipart(text/plain)', 3),
('multipart(text/plain, application/octet-stream)', 2),
('multipart(text/plain, text/enriched)', 1),
('multipart(text/plain, application/ms-tnef, text/plain)', 1),
('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',
1),
('multipart(text/plain, video/mng)', 1),
('multipart(text/plain, multipart(text/plain))', 1),
('multipart(text/plain, application/x-pkcs7-signature)', 1),
('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',
1),
('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',
1),
('multipart(text/plain, application/x-java-applet)', 1)]
structures_counter(spam_emails).most_common()
[('text/plain', 218),
('text/html', 183),
('multipart(text/plain, text/html)', 45),
('multipart(text/html)', 20),
('multipart(text/plain)', 19),
('multipart(multipart(text/html))', 5),
('multipart(text/plain, image/jpeg)', 3),
('multipart(text/html, application/octet-stream)', 2),
('multipart(text/plain, application/octet-stream)', 1),
('multipart(text/html, text/plain)', 1),
('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),
('multipart(multipart(text/plain, text/html), image/gif)', 1),
('multipart/alternative', 1)]
Kiểm tra Header của email.
for header, value in spam_emails[0].items():
print(header, ":", value)
Return-Path : <12a1mailbot1@web.de>
Delivered-To : zzzz@localhost.spamassassin.taint.org
Received : from localhost (localhost [127.0.0.1]) by phobos.labs.spamassassin.taint.org (Postfix) with ESMTP id 136B943C32 for <zzzz@localhost>; Thu, 22 Aug 2002 08:17:21 -0400 (EDT)
Received : from mail.webnote.net [193.120.211.219] by localhost with POP3 (fetchmail-5.9.0) for zzzz@localhost (single-drop); Thu, 22 Aug 2002 13:17:21 +0100 (IST)
Received : from dd_it7 ([210.97.77.167]) by webnote.net (8.9.3/8.9.3) with ESMTP id NAA04623 for <zzzz@spamassassin.taint.org>; Thu, 22 Aug 2002 13:09:41 +0100
From : 12a1mailbot1@web.de
Received : from r-smtp.korea.com - 203.122.2.197 by dd_it7 with Microsoft SMTPSVC(5.5.1775.675.6); Sat, 24 Aug 2002 09:42:10 +0900
To : dcek1a1@netsgo.com
Subject : Life Insurance - Why Pay More?
Date : Wed, 21 Aug 2002 20:31:57 -1600
MIME-Version : 1.0
Message-ID : <0103c1042001882DD_IT7@dd_it7>
Content-Type : text/html; charset="iso-8859-1"
Content-Transfer-Encoding : quoted-printable
spam_emails[0]["Subject"]
'Life Insurance - Why Pay More?'
Chia dữ liệu train/test.
import numpy as np
from sklearn.model_selection import train_test_split
X = np.array(ham_emails + spam_emails, dtype=object)
y = np.array([0] * len(ham_emails) + [1] * len(spam_emails))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=42)
Hàm chuyển đổi HTML sang văn bản thuần (plain text) bằng Regular Expression.
import re
from html import unescape
def html_to_plain_text(html):
text = re.sub(r'<head.*?>.*?</head>', '', html, flags=re.M | re.S | re.I)
text = re.sub(r'<a\s.*?>', ' HYPERLINK ', text, flags=re.M | re.S | re.I)
text = re.sub(r'<.*?>', '', text, flags=re.M | re.S)
text = re.sub(r'(\s*\n)+', '\n', text, flags=re.M | re.S)
return unescape(text)
Thử nghiệm hàm trên.
html_spam_emails = [email for email in X_train[y_train==1]
if get_email_structure(email) == "text/html"]
sample_html_spam = html_spam_emails[7]
print(sample_html_spam.get_content().strip()[:1000], "...")
<HTML><HEAD><TITLE></TITLE><META http-equiv="Content-Type" content="text/html; charset=windows-1252"><STYLE>A:link {TEX-DECORATION: none}A:active {TEXT-DECORATION: none}A:visited {TEXT-DECORATION: none}A:hover {COLOR: #0033ff; TEXT-DECORATION: underline}</STYLE><META content="MSHTML 6.00.2713.1100" name="GENERATOR"></HEAD>
<BODY text="#000000" vLink="#0033ff" link="#0033ff" bgColor="#CCCC99"><TABLE borderColor="#660000" cellSpacing="0" cellPadding="0" border="0" width="100%"><TR><TD bgColor="#CCCC99" valign="top" colspan="2" height="27">
<font size="6" face="Arial, Helvetica, sans-serif" color="#660000">
<b>OTC</b></font></TD></TR><TR><TD height="2" bgcolor="#6a694f">
<font size="5" face="Times New Roman, Times, serif" color="#FFFFFF">
<b> Newsletter</b></font></TD><TD height="2" bgcolor="#6a694f"><div align="right"><font color="#FFFFFF">
<b>Discover Tomorrow's Winners </b></font></div></TD></TR><TR><TD height="25" colspan="2" bgcolor="#CCCC99"><table width="100%" border="0" ...
print(html_to_plain_text(sample_html_spam.get_content())[:1000], "...")
OTC
Newsletter
Discover Tomorrow's Winners
For Immediate Release
Cal-Bay (Stock Symbol: CBYI)
Watch for analyst "Strong Buy Recommendations" and several advisory newsletters picking CBYI. CBYI has filed to be traded on the OTCBB, share prices historically INCREASE when companies get listed on this larger trading exchange. CBYI is trading around 25 cents and should skyrocket to $2.66 - $3.25 a share in the near future.
Put CBYI on your watch list, acquire a position TODAY.
REASONS TO INVEST IN CBYI
A profitable company and is on track to beat ALL earnings estimates!
One of the FASTEST growing distributors in environmental & safety equipment instruments.
Excellent management team, several EXCLUSIVE contracts. IMPRESSIVE client list including the U.S. Air Force, Anheuser-Busch, Chevron Refining and Mitsubishi Heavy Industries, GE-Energy & Environmental Research.
RAPIDLY GROWING INDUSTRY
Industry revenues exceed $900 million, estimates indicate that there could be as much as $25 billi ...
Hàm tổng hợp để lấy nội dung text từ bất kỳ email nào.
def email_to_text(email):
html = None
for part in email.walk():
ctype = part.get_content_type()
if not ctype in ("text/plain", "text/html"):
continue
try:
content = part.get_content()
except: # phòng trường hợp lỗi encoding
content = str(part.get_payload())
if ctype == "text/plain":
return content
else:
html = content
if html:
return html_to_plain_text(html)
print(email_to_text(sample_html_spam)[:100], "...")
OTC
Newsletter
Discover Tomorrow's Winners
For Immediate Release
Cal-Bay (Stock Symbol: CBYI)
Wat ...
Sử dụng NLTK để thực hiện Stemming (đưa từ về dạng gốc, ví dụ: computing -> comput).
import nltk
stemmer = nltk.PorterStemmer()
for word in ("Computations", "Computation", "Computing", "Computed", "Compute",
"Compulsive"):
print(word, "=>", stemmer.stem(word))
Computations => comput
Computation => comput
Computing => comput
Computed => comput
Compute => comput
Compulsive => compuls
Sử dụng urlextract để phát hiện URL.
# Kiểm tra xem notebook đang chạy trên Colab hay Kaggle
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules
# Nếu chạy trên Colab/Kaggle thì cài đặt urlextract
if IS_COLAB or IS_KAGGLE:
%pip install -q -U urlextract
import urlextract # có thể yêu cầu kết nối mạng để tải danh sách domain
url_extractor = urlextract.URLExtract()
some_text = "Will it detect github.com and https://youtu.be/7Pq-S557XQU?t=3m32s"
print(url_extractor.find_urls(some_text))
['github.com', 'https://youtu.be/7Pq-S557XQU?t=3m32s']
Xây dựng Transformer để biến đổi email thành bộ đếm từ (Word Count).
from sklearn.base import BaseEstimator, TransformerMixin
class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):
def __init__(self, strip_headers=True, lower_case=True,
remove_punctuation=True, replace_urls=True,
replace_numbers=True, stemming=True):
self.strip_headers = strip_headers
self.lower_case = lower_case
self.remove_punctuation = remove_punctuation
self.replace_urls = replace_urls
self.replace_numbers = replace_numbers
self.stemming = stemming
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
X_transformed = []
for email in X:
text = email_to_text(email) or ""
if self.lower_case:
text = text.lower()
if self.replace_urls and url_extractor is not None:
urls = list(set(url_extractor.find_urls(text)))
urls.sort(key=lambda url: len(url), reverse=True)
for url in urls:
text = text.replace(url, " URL ")
if self.replace_numbers:
text = re.sub(r'\d+(?:\.\d*)?(?:[eE][+-]?\d+)?', 'NUMBER', text)
if self.remove_punctuation:
text = re.sub(r'\W+', ' ', text, flags=re.M)
word_counts = Counter(text.split())
if self.stemming and stemmer is not None:
stemmed_word_counts = Counter()
for word, count in word_counts.items():
stemmed_word = stemmer.stem(word)
stemmed_word_counts[stemmed_word] += count
word_counts = stemmed_word_counts
X_transformed.append(word_counts)
return np.array(X_transformed)
X_few = X_train[:3]
X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)
X_few_wordcounts
array([Counter({'chuck': 1, 'murcko': 1, 'wrote': 1, 'stuff': 1, 'yawn': 1, 'r': 1}),
Counter({'the': 11, 'of': 9, 'and': 8, 'all': 3, 'christian': 3, 'to': 3, 'by': 3, 'jefferson': 2, 'i': 2, 'have': 2, 'superstit': 2, 'one': 2, 'on': 2, 'been': 2, 'ha': 2, 'half': 2, 'rogueri': 2, 'teach': 2, 'jesu': 2, 'some': 1, 'interest': 1, 'quot': 1, 'url': 1, 'thoma': 1, 'examin': 1, 'known': 1, 'word': 1, 'do': 1, 'not': 1, 'find': 1, 'in': 1, 'our': 1, 'particular': 1, 'redeem': 1, 'featur': 1, 'they': 1, 'are': 1, 'alik': 1, 'found': 1, 'fabl': 1, 'mytholog': 1, 'million': 1, 'innoc': 1, 'men': 1, 'women': 1, 'children': 1, 'sinc': 1, 'introduct': 1, 'burnt': 1, 'tortur': 1, 'fine': 1, 'imprison': 1, 'what': 1, 'effect': 1, 'thi': 1, 'coercion': 1, 'make': 1, 'world': 1, 'fool': 1, 'other': 1, 'hypocrit': 1, 'support': 1, 'error': 1, 'over': 1, 'earth': 1, 'six': 1, 'histor': 1, 'american': 1, 'john': 1, 'e': 1, 'remsburg': 1, 'letter': 1, 'william': 1, 'short': 1, 'again': 1, 'becom': 1, 'most': 1, 'pervert': 1, 'system': 1, 'that': 1, 'ever': 1, 'shone': 1, 'man': 1, 'absurd': 1, 'untruth': 1, 'were': 1, 'perpetr': 1, 'upon': 1, 'a': 1, 'larg': 1, 'band': 1, 'dupe': 1, 'import': 1, 'led': 1, 'paul': 1, 'first': 1, 'great': 1, 'corrupt': 1}),
Counter({'url': 4, 's': 3, 'group': 3, 'to': 3, 'in': 2, 'forteana': 2, 'martin': 2, 'an': 2, 'and': 2, 'we': 2, 'is': 2, 'yahoo': 2, 'unsubscrib': 2, 'y': 1, 'adamson': 1, 'wrote': 1, 'for': 1, 'altern': 1, 'rather': 1, 'more': 1, 'factual': 1, 'base': 1, 'rundown': 1, 'on': 1, 'hamza': 1, 'career': 1, 'includ': 1, 'hi': 1, 'belief': 1, 'that': 1, 'all': 1, 'non': 1, 'muslim': 1, 'yemen': 1, 'should': 1, 'be': 1, 'murder': 1, 'outright': 1, 'know': 1, 'how': 1, 'unbias': 1, 'memri': 1, 'don': 1, 't': 1, 'html': 1, 'rob': 1, 'sponsor': 1, 'number': 1, 'dvd': 1, 'free': 1, 'p': 1, 'join': 1, 'now': 1, 'from': 1, 'thi': 1, 'send': 1, 'email': 1, 'egroup': 1, 'com': 1, 'your': 1, 'use': 1, 'of': 1, 'subject': 1})],
dtype=object)
Xây dựng Transformer để biến bộ đếm từ thành Vector (Ma trận thưa).
from scipy.sparse import csr_matrix
class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):
def __init__(self, vocabulary_size=1000):
self.vocabulary_size = vocabulary_size
def fit(self, X, y=None):
total_count = Counter()
for word_count in X:
for word, count in word_count.items():
total_count[word] += min(count, 10)
most_common = total_count.most_common()[:self.vocabulary_size]
self.vocabulary_ = {word: index + 1
for index, (word, count) in enumerate(most_common)}
return self
def transform(self, X, y=None):
rows = []
cols = []
data = []
for row, word_count in enumerate(X):
for word, count in word_count.items():
rows.append(row)
cols.append(self.vocabulary_.get(word, 0))
data.append(count)
return csr_matrix((data, (rows, cols)),
shape=(len(X), self.vocabulary_size + 1))
vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)
X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)
X_few_vectors
<Compressed Sparse Row sparse matrix of dtype 'int64'
with 20 stored elements and shape (3, 11)>
X_few_vectors.toarray()
array([[ 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[99, 11, 9, 8, 3, 1, 3, 1, 3, 2, 3],
[67, 0, 1, 2, 3, 4, 1, 2, 0, 1, 0]])
vocab_transformer.vocabulary_
{'the': 1,
'of': 2,
'and': 3,
'to': 4,
'url': 5,
'all': 6,
'in': 7,
'christian': 8,
'on': 9,
'by': 10}
Cuối cùng, tạo Pipeline hoàn chỉnh và huấn luyện LogisticRegression.
from sklearn.pipeline import Pipeline
preprocess_pipeline = Pipeline([
("email_to_wordcount", EmailToWordCounterTransformer()),
("wordcount_to_vector", WordCounterToVectorTransformer()),])
X_train_transformed = preprocess_pipeline.fit_transform(X_train)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
log_clf = LogisticRegression(max_iter=1000, random_state=42)
score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3)
score.mean()
np.float64(0.985)
Kết quả trên 98.5%! Hãy kiểm tra lần cuối trên tập test.
from sklearn.metrics import precision_score, recall_score
X_test_transformed = preprocess_pipeline.transform(X_test)
log_clf = LogisticRegression(max_iter=1000, random_state=42)
log_clf.fit(X_train_transformed, y_train)
y_pred = log_clf.predict(X_test_transformed)
print(f"Precision: {precision_score(y_test, y_pred):.2%}")
print(f"Recall: {recall_score(y_test, y_pred):.2%}")
Precision: 96.88%
Recall: 97.89%