[DL101] Chương 9: Tăng tốc Transformers
Các kỹ thuật tối ưu hóa và tăng tốc mô hình Transformer
Bài viết có tham khảo, sử dụng và sửa đổi tài nguyên từ kho lưu trữ handson-mlp, tuân thủ giấy phép Apache‑2.0. Chúng tôi chân thành cảm ơn tác giả Aurélien Géron (@aureliengeron) vì sự chia sẻ kiến thức tuyệt vời và những đóng góp quý giá cho cộng đồng.
Ở chương này, chúng ta sẽ đi sâu vào những phương án tối ưu hóa giúp các mô hình ngôn ngữ lớn (LLMs) hoạt động nhanh hơn và hiệu quả hơn.
Transformer đã cách mạng hóa AI, nhưng chi phí tính toán và bộ nhớ của nó là rào cản lớn. Tại sao chúng lại tốn kém đến vậy?
- Cơ chế Attention : Khi chuỗi đầu vào dài gấp đôi, chi phí tính toán tăng gấp 4 lần.
- Autoregressive Decoding: Việc sinh từng từ một (tuần tự) không tận dụng hết khả năng song song của GPU.
- Kích thước khổng lồ: Các mô hình như GPT-3 có 175 tỷ tham số, yêu cầu hàng trăm GB VRAM chỉ để chứa trọng số.
Trong chương này, chúng ta sẽ giải quyết từng vấn đề bằng các giải pháp kỹ thuật tiên tiến.
Bạn có thể chạy trực tiếp các đoạn mã code tại: Google Colab.
Thiết lập môi trường (Setup)
Dự án này yêu cầu Python phiên bản 3.10 trở lên để đảm bảo tương thích với các thư viện mới nhất:
import sys
# Kiểm tra phiên bản Python
assert sys.version_info >= (3, 10)
Chúng ta cần xác định xem code đang chạy trên môi trường nào (Google Colab hay Kaggle) để có các hướng dẫn cấu hình phần cứng phù hợp:
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules
Chúng ta cũng cần PyTorch phiên bản ≥ 2.6.0 để sử dụng các tính năng tối ưu hóa mới nhất như scaled_dot_product_attention (SDPA):
from packaging.version import Version
import torch
# Đảm bảo PyTorch đủ mới
assert Version(torch.__version__) >= Version("2.6.0")
Chương này liên quan nhiều đến tính toán ma trận nặng, vì vậy việc có GPU là rất quan trọng. Đoạn code dưới đây sẽ tự động phát hiện thiết bị (CUDA cho NVIDIA GPU, MPS cho Apple Silicon, hoặc CPU):
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"Using device: {device}")
output:
Using device: cuda
Cảnh báo người dùng nếu không tìm thấy bộ tăng tốc phần cứng, vì việc thực thi trên CPU sẽ rất chậm:
if device == "cpu":
print("Cảnh báo: Mạng nơ-ron có thể rất chậm nếu không có bộ tăng tốc phần cứng.")
if IS_COLAB:
print("Vào Runtime > Change runtime type và chọn GPU.")
if IS_KAGGLE:
print("Vào Settings > Accelerator và chọn GPU.")
Thiết lập kích thước phông chữ mặc định cho thư viện matplotlib để các biểu đồ hiển thị rõ ràng hơn:
import matplotlib.pyplot as plt
plt.rc('font', size=14)
plt.rc('axes', labelsize=14, titlesize=14)
plt.rc('legend', fontsize=14)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
1. Tăng tốc độ giải mã tại thời điểm suy luận (Faster Decoding at Inference Time)
Quá trình sinh văn bản (text generation) trong các mô hình ngôn ngữ là tự hồi quy (autoregressive). Nghĩa là, để sinh ra token thứ , mô hình cần đầu vào là toàn bộ chuỗi từ token đến . Điều này tạo ra một nút thắt cổ chai về hiệu năng khi độ dài chuỗi tăng lên.
1.1. Key/Value Caching (KV Caching)
Lý thuyết
Trong cơ chế Self-Attention, tại mỗi bước thời gian , chúng ta tính toán:
Nếu không sử dụng caching, tại bước , chúng ta phải tính lại và cho tất cả các token từ đến , mặc dù chúng không thay đổi. Điều này dẫn đến độ phức tạp tính toán là cho toàn bộ quá trình sinh.
KV Caching giải quyết vấn đề này bằng cách lưu trữ các ma trận Key và Value của các token trước đó trong bộ nhớ GPU. Tại bước , chúng ta chỉ cần tính và cho token mới nhất, sau đó nối (concat) vào cache. Điều này giảm độ phức tạp tính toán xuống , nhưng đổi lại là chi phí bộ nhớ tăng lên (VRAM usage).
Thực hành
Dưới đây, chúng ta so sánh tốc độ sinh văn bản khi bật và tắt use_cache.
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "facebook/opt-125m"
# Tải mô hình và tokenizer, tự động phân bổ vào thiết bị (GPU/CPU)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Once upon a time there lived"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# So sánh hiệu năng giữa việc dùng cache và không dùng cache
for use_cache in (True, False):
print(f"{use_cache=}")
# %time là magic command của Jupyter để đo thời gian thực thi
%time model.generate(**inputs, max_new_tokens=500, do_sample=False, use_cache=use_cache)
output:
use_cache=True
CPU times: user 4.6 s, sys: 247 µs, total: 4.6 s
Wall time: 4.6 s
use_cache=False
CPU times: user 6.16 s, sys: 0 ns, total: 6.16 s
Wall time: 6.15 s
1.2. Speculative Decoding (Giải mã đầu cơ)
Lý thuyết
Các mô hình lớn (Target Model) thường bị giới hạn bởi băng thông bộ nhớ (memory-bound) khi sinh từng token một. Ý tưởng của Speculative Decoding là sử dụng một mô hình nhỏ hơn, chạy nhanh hơn (Draft Model) để sinh ra trước một đoạn ngắn (ví dụ: 5 token). Sau đó, mô hình lớn sẽ kiểm tra xem các token này có hợp lệ không trong một lần chạy (parallel verification).
Giả sử mô hình nhỏ sinh ra token với xác suất và mô hình lớn xác nhận với xác suất .
- Nếu , token được chấp nhận.
- Nếu , token bị từ chối với xác suất và quy trình dừng lại để lấy mẫu lại từ mô hình lớn.
Phương pháp này không làm thay đổi phân phối xác suất đầu ra của mô hình lớn nhưng có thể tăng tốc độ lên 2-3 lần.
Thực hành
Chúng ta sẽ sử dụng opt-350m làm mô hình đích và opt-125m làm mô hình nháp.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import set_seed
set_seed(42)
# Mô hình lớn (Target Model)
target_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m",
device_map="auto")
# Mô hình nhỏ (Draft Model)
draft_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
prompt = "Once upon a time there lived"
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)
# Sử dụng tham số `assistant_model` để kích hoạt Speculative Decoding
outputs = target_model.generate(**inputs, max_new_tokens=100, do_sample=True,
temperature=1, assistant_model=draft_model)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
output:
Once upon a time there lived in an enchanted land known as Africa...
2. Tăng tốc Multi-Head Attention (Boosting Multi-Head Attention)
Độ phức tạp của Attention tiêu chuẩn là với là độ dài chuỗi. Điều này khiến việc xử lý văn bản dài trở nên cực kỳ tốn kém. Các phương pháp xấp xỉ (Approximate Attention) ra đời để giảm độ phức tạp này.
2.1. BigBird: Sparse Attention (Attention thưa)
Lý thuyết
BigBird thay thế cơ chế Full Attention (tất cả token nhìn thấy nhau) bằng sự kết hợp của 3 thành phần thưa:
- Random Attention: Mỗi token chú ý đến một số token ngẫu nhiên khác.
- Window Attention: Mỗi token chú ý đến các token lân cận (cửa sổ trượt).
- Global Attention: Một số token đặc biệt (như
[CLS]) được phép chú ý đến toàn bộ chuỗi và ngược lại.
Cấu trúc này biến ma trận Attention thành ma trận thưa, giảm độ phức tạp xuống . BigBird có khả năng xử lý chuỗi dài gấp 8 lần so với BERT tiêu chuẩn.
from transformers import pipeline
model_id = "google/bigbird-roberta-base"
# Pipeline tự động xử lý việc tokenization và inference
fill_mask = pipeline(task="fill-mask", model=model_id)
# Thử nghiệm điền từ vào chỗ trống
fill_mask("She was feeling unwell so she took some [MASK] medicine.")
output:
[{'score': 0.28187209367752075,
'token': 2457,
'token_str': 'pain',
'sequence': 'She was feeling unwell so she took some pain medicine.'},
...]
2.2. Reformer: LSH Attention (Locality-Sensitive Hashing)
Lý thuyết
Reformer dựa trên quan sát rằng bị chi phối bởi các phần tử lớn nhất. Nghĩa là vector Query chỉ cần chú ý đến các vector Key gần nó nhất trong không gian vector.
Để tìm các láng giềng gần nhất (Nearest Neighbors) nhanh chóng mà không cần duyệt toàn bộ, Reformer sử dụng Locality-Sensitive Hashing (LSH). Ý tưởng là sử dụng các phép xoay ngẫu nhiên (Random Rotations) để gán các vector gần nhau vào cùng một “bucket” (thùng chứa) với xác suất cao.
Hàm băm góc (Angular LSH) được định nghĩa như sau: Trong đó là ma trận ngẫu nhiên.
Thực hành
Dưới đây là cài đặt đơn giản của Angular LSH:
import torch
import torch.nn as nn
import torch.nn.functional as F
def angular_lsh(vectors, k):
# Tạo ma trận xoay ngẫu nhiên R
# vectors.size(-1): kích thước chiều vector (d_model)
# k // 2: số lượng siêu phẳng phân chia
R = torch.randn(vectors.size(-1), k // 2, device=vectors.device)
# Chuẩn hóa vector đầu vào về độ dài đơn vị
normalized_vectors = F.normalize(vectors, p=2.0, dim=1)
# Chiếu vector lên các siêu phẳng ngẫu nhiên
V_proj = normalized_vectors @ R
# Nối với phần đối để xác định vùng (bucket)
V_concat = torch.cat([V_proj, -V_proj], dim=1)
# Lấy chỉ số của thành phần lớn nhất làm mã băm (bucket id)
return torch.argmax(V_concat, dim=1)
torch.manual_seed(42)
vectors = torch.rand(16, 512) # 16 vector, mỗi vector 512 chiều
angular_lsh(vectors, k=4) # Phân vào 4 bucket
output:
tensor([2, 2, 0, 3, 0, 2, 2, 2, 2, 1, 1, 3, 3, 1, 2, 1])
2.3. Performer: Kernel Attention
Lý thuyết
Attention tiêu chuẩn tính . Chúng ta phải tính (kích thước ) trước rồi mới nhân với . Performer sử dụng phương pháp Kernel Trick để xấp xỉ hàm softmax: Nhờ tính chất kết hợp của phép nhân ma trận, ta có thể tính trước (kích thước ), giảm độ phức tạp từ xuống .
Performer sử dụng tính chất: Kỳ vọng của với là vector ngẫu nhiên Gaussian chính là hạt nhân Gaussian:
Từ đó, hàm đặc trưng được xây dựng dựa trên các đặc trưng ngẫu nhiên trực giao dương (Positive Orthogonal Random Features - FAVOR+).
Hãy kiểm chứng lý thuyết trên bằng PyTorch:
torch.manual_seed(42)
d, m = 64, 1024 # d: chiều dữ liệu, m: số lượng mẫu ngẫu nhiên
W = torch.randn(d, m)
X = torch.randn(5, d) / d ** 0.5
# Tính exp(XW) thực tế
R = torch.exp(X @ W)
# So sánh trung bình của R với giá trị lý thuyết exp(0.5 * ||x||^2)
print("Thực nghiệm:", R.mean(axis=-1))
print("Lý thuyết: ", torch.exp(0.5 * (X.norm(dim=-1)**2)))
output:
Thực nghiệm: tensor([1.5851, 1.6703, 1.6728, 1.9729, 1.7934])
Lý thuyết: tensor([1.6516, 1.7578, 1.6594, 1.8050, 1.7751])
Kết quả rất gần nhau! Bây giờ, chúng ta định nghĩa hàm để xấp xỉ softmax attention:
Việc trừ đi giá trị max là để ổn định số học (tránh tràn số).
def phi(X, W, dim_subtract_max=(-2, -1)):
# Tính bình phương norm của x
squared_norms = X.square().sum(dim=-1, keepdim=True)
# Chiếu x lên các chiều ngẫu nhiên W
X_proj = X @ W
# Trừ max để ổn định số học
max_vals = X_proj.amax(dim=dim_subtract_max, keepdim=True)
# Công thức FAVOR+ (cơ bản)
return torch.exp(X_proj - max_vals - squared_norms / 2) / W.size(-1) ** 0.5
Về mặt lý thuyết, . Hãy kiểm tra độ chính xác của xấp xỉ này đối với Attention matrix (chưa chuẩn hóa):
torch.manual_seed(42)
batch_size = 32
d_model = 512
n_heads = 8
Lq = 200 # Độ dài Query
Lk = 100 # Độ dài Key
m = 256 # Số lượng đặc trưng ngẫu nhiên
d_head = d_model // n_heads
# Tạo W ngẫu nhiên
W = torch.randn(n_heads, d_head, m)
Q = torch.randn(batch_size, n_heads, Lq, d_head)
K = torch.randn(batch_size, n_heads, Lk, d_head)
scale = 1 / d_head ** 0.5
# Áp dụng hàm phi
Qp = phi(Q * scale ** 0.5, W, dim_subtract_max=-1)
Kp = phi(K * scale ** 0.5, W)
# Tính Attention xấp xỉ
A = Qp @ Kp.transpose(-2, -1)
D = A.sum(dim=-1, keepdim=True) # Mẫu số để chuẩn hóa (tương tự softmax)
result = A / (D + 1e-6)
# Tính Attention chính xác (Ground Truth)
expected = torch.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
# Tính sai số RMSE
rmse = F.mse_loss(result, expected) ** 0.5
print(f"RMSE ban đầu: {rmse:.4f}")
output:
RMSE ban đầu: 0.0171
Sai số tương đối tốt. Để cải thiện hơn nữa, Performer sử dụng Trực giao hóa (Orthogonalization) ma trận . Điều này giúp các vector ngẫu nhiên quét không gian hiệu quả hơn, giảm phương sai của bộ ước lượng.
def orthogonalize(W):
d_head = W.size(-2)
# Sử dụng phân rã QR để trực giao hóa từng khối của W
W_orth = torch.cat([torch.linalg.qr(W_chunk)[0]
for W_chunk in W.split(d_head, dim=-1)], dim=-1)
return W_orth * d_head ** 0.5
# Tạo W trực giao
W_orth = orthogonalize(W)
Hãy xem sai số thay đổi thế nào sau khi trực giao hóa:
Qp2 = phi(Q * scale ** 0.5, W_orth, dim_subtract_max=-1)
Kp2 = phi(K * scale ** 0.5, W_orth)
A2 = Qp2 @ Kp2.transpose(-2, -1)
D2 = A2.sum(dim=-1, keepdim=True)
result2 = A2 / (D2 + 1e-6)
rmse2 = F.mse_loss(result2, expected) ** 0.5
print(f"RMSE sau khi trực giao hóa: {rmse2:.4f}")
output:
RMSE sau khi trực giao hóa: 0.0160
Sai số đã giảm xuống. Bây giờ chúng ta sẽ đóng gói toàn bộ logic này vào một module PyTorch hoàn chỉnh mang tên FavorAttention. Điểm mấu chốt ở đây là thay đổi thứ tự tính toán:
class FavorAttention(nn.Module):
def __init__(self, d_model, n_heads, n_features):
super().__init__()
self.d_head = d_model // n_heads
# Khởi tạo trọng số ngẫu nhiên W cố định (buffer)
W = torch.randn(n_heads, self.d_head, n_features) # kích thước: [h, d, m]
W = orthogonalize(W)
self.register_buffer("W", W)
def forward(self, Q, K, V):
scale = self.d_head ** -0.25
# Tính phi(Q) và phi(K)
Qp = phi(Q * scale, self.W, dim_subtract_max=-1)
Kp = phi(K * scale, self.W)
# Tính mẫu số chuẩn hóa: D = Q' @ sum(K')
# Kp.sum(dim=-2) giúp giảm chiều Lk, đưa về O(N)
D = Qp @ Kp.sum(dim=-2).unsqueeze(-1) # B, h, Lq, 1
# Tính tử số: (Q' @ (K'^T @ V))
# Tính K'^T @ V trước: kích thước [B, h, m, d], không phụ thuộc N
Kp_T_V = Kp.transpose(-2, -1) @ V
# Nhân Q' với kết quả trên
return (Qp @ Kp_T_V) / (D + 1e-6)
Kiểm tra module FavorAttention:
torch.manual_seed(42)
Q = torch.randn(batch_size, n_heads, Lq, d_head)
K = torch.randn(batch_size, n_heads, Lk, d_head)
V = torch.randn(batch_size, n_heads, Lk, d_head)
favor_attn = FavorAttention(d_model, n_heads, 256)
approx_attn = favor_attn(Q, K, V)
import torch.nn.functional as F
# So sánh với Scaled Dot Product Attention chuẩn của PyTorch
attn = F.scaled_dot_product_attention(Q, K, V)
attn_rmse = F.mse_loss(approx_attn, attn) ** 0.5
print(f"Final RMSE: {attn_rmse:.4f}")
output:
Final RMSE: 0.1599
3. Chia sẻ Projections trong Multi-Head Attention
Trong Multi-Head Attention (MHA) tiêu chuẩn, mỗi đầu (head) có một bộ ma trận riêng biệt. Điều này tạo áp lực lớn lên băng thông bộ nhớ khi tải trọng số trong quá trình suy luận (memory bandwidth bound).
3.1. Multi-Query Attention (MQA)
Lý thuyết
Multi-Query Attention đề xuất sử dụng duy nhất một đầu cho Key và Value (), nhưng vẫn giữ nhiều đầu cho Query.
- MHA: có kích thước
- MQA: có kích thước
Điều này giảm đáng kể kích thước KV Cache và tốc độ truy cập bộ nhớ, giúp tăng tốc độ suy luận, dù có thể giảm nhẹ chất lượng mô hình.
Thực hành
Trong PyTorch, ta có thể mô phỏng điều này bằng cách broadcast hoặc sử dụng tham số trong scaled_dot_product_attention.
batch_size, Lq, Lk, d_head = 32, 100, 90, 64
n_heads = 8
n_groups = 1 # Chỉ có 1 nhóm K, V cho tất cả các đầu Query
query = torch.randn(batch_size, n_heads, Lq, d_head)
key = torch.randn(batch_size, n_groups, Lk, d_head)
value = torch.randn(batch_size, n_groups, Lk, d_head)
# enable_gqa=True cho phép tự động broadcast K, V nếu kích thước đầu không khớp
attn = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print("MQA Output shape:", attn.shape)
output:
MQA Output shape: torch.Size([32, 8, 100, 64])
3.2. Grouped-Query Attention (GQA)
Lý thuyết
Grouped-Query Attention là điểm cân bằng giữa MHA và MQA. Thay vì chia sẻ cho tất cả các đầu Query (như MQA) hoặc không chia sẻ gì (như MHA), GQA chia các đầu Query thành nhóm, và mỗi nhóm chia sẻ chung một cặp .
Ví dụ: 8 đầu Query chia làm 2 nhóm (). Mỗi nhóm 4 đầu Query dùng chung 1 Key/Value head. Đây là phương pháp được sử dụng trong Llama-2 và Llama-3 để cân bằng giữa chất lượng và tốc độ.
batch_size, Lq, Lk, d_head = 32, 100, 90, 64
n_heads = 8
n_groups = 2 # Chia thành 2 nhóm
query = torch.randn(batch_size, n_heads, Lq, d_head)
key = torch.randn(batch_size, n_groups, Lk, d_head)
value = torch.randn(batch_size, n_groups, Lk, d_head)
# PyTorch xử lý việc khớp 8 queries với 2 keys/values
attn = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
print("GQA Output shape:", attn.shape)
output:
GQA Output shape: torch.Size([32, 8, 100, 64])
3.3. FlashAttention
Lý thuyết
FlashAttention là một bước đột phá về IO-aware (nhận thức In/Out). Vấn đề của Attention truyền thống không chỉ là phép tính, mà là lần truy cập bộ nhớ HBM (High Bandwidth Memory - bộ nhớ chậm) để đọc/ghi ma trận Attention score ().
FlashAttention sử dụng kỹ thuật Tiling (chia khối) để tính toán Attention từng phần ngay trong bộ nhớ SRAM (bộ nhớ cache nhanh của GPU) mà không bao giờ ghi ma trận ra HBM.
Để làm được điều này, nó sử dụng Online Softmax, cho phép cập nhật giá trị softmax dần dần khi có dữ liệu mới mà không cần nhìn thấy toàn bộ hàng.
Thực hành
Dưới đây là một bản cài đặt đơn giản bằng Python mô phỏng thuật toán FlashAttention (trên thực tế FlashAttention được viết bằng CUDA C++).
def flash_attention(Q, K, V, block_size_q, block_size_k):
Lq, d = Q.shape
Lk, _ = K.shape
O = torch.zeros_like(Q)
scale = d ** -0.5
# Vòng lặp ngoài: Duyệt qua các khối Query
for i_start in range(0, Lq, block_size_q):
i_end = min(i_start + block_size_q, Lq)
Q_block = Q[i_start:i_end]
# Khởi tạo các biến tích lũy cho khối Query hiện tại
O_block = torch.zeros(block_size_q, d, device=Q.device)
l_block = torch.zeros(block_size_q, 1, device=Q.device) # Mẫu số (li)
m_block = -torch.inf * torch.ones(block_size_q, 1, device=Q.device) # Max (mi)
# Vòng lặp trong: Duyệt qua các khối Key/Value
for j_start in range(0, Lk, block_size_k):
j_end = min(j_start + block_size_k, Lk)
K_block = K[j_start:j_end]
V_block = V[j_start:j_end]
# 1. Tính Attention Score cho khối (Q_i, K_j)
S_ij = Q_block @ K_block.T * scale
# 2. Cập nhật giá trị Max (m_new) dùng cho ổn định số học
m_ij_new, _ = torch.max(S_ij, dim=1, keepdim=True)
m_block_new = torch.maximum(m_block, m_ij_new)
# 3. Tính P_ij (phần tử mũ chưa chuẩn hóa)
# Ta trừ m_block_new để tránh tràn số (exp overflow)
P_ij = torch.exp(S_ij - m_block_new)
# Hệ số điều chỉnh cho các giá trị tích lũy cũ (do m_block thay đổi)
correction_factor = torch.exp(m_block - m_block_new)
# 4. Cập nhật mẫu số (l) và tử số (O)
l_block_new = ((l_block * correction_factor)
+ torch.sum(P_ij, dim=1, keepdim=True))
O_block = (O_block * correction_factor) + (P_ij @ V_block)
# Lưu trạng thái cho vòng lặp tiếp theo
l_block = l_block_new
m_block = m_block_new
# 5. Kết thúc duyệt Key, chuẩn hóa kết quả cuối cùng
O[i_start:i_end] = O_block / l_block
return O
Thử nghiệm code FlashAttention “tự chế” và so sánh với kết quả chuẩn:
torch.manual_seed(42)
block_size_q, block_size_k = 64, 64
Lq, Lk, d = 1280, 1152, 512
Q = torch.randn(Lq, d)
K = torch.randn(Lk, d)
V = torch.randn(Lk, d)
# Chạy FlashAttention thủ công
R1 = flash_attention(Q, K, V, block_size_q, block_size_k)
# Chạy chuẩn PyTorch
R2 = F.scaled_dot_product_attention(Q, K, V)
# Tính sai số
mse = F.mse_loss(R1, R2)
print(f"MSE Loss: {mse:.10f}")
output:
MSE Loss: 0.0000000000
4. Mở rộng quy mô với Mixture of Experts (MoE) và LoRA
4.1. Lý thuyết: Mixture of Experts (MoE)
Để mở rộng mô hình lên kích thước khổng lồ (nghìn tỷ tham số) mà chi phí tính toán không tăng tương ứng, chúng ta sử dụng Mixture of Experts (MoE).
Thay vì một mạng Feed-Forward (FFN) dày đặc, MoE thay thế bằng nhiều “chuyên gia” (experts) nhỏ. Một mạng cổng (Router/Gate) sẽ quyết định gửi token đến chuyên gia nào. Ví dụ: chỉ 2 trong số 8 chuyên gia được kích hoạt cho mỗi token (Sparse Activation). Điều này giúp mô hình có dung lượng tri thức lớn (nhiều tham số) nhưng chi phí suy luận thấp (ít tính toán).
4.2. Parameter-Efficient Fine-Tuning (PEFT) với LoRA
Một cách khác để xử lý các mô hình lớn một cách hiệu quả là sử dụng các kỹ thuật tinh chỉnh tham số hiệu quả (PEFT). LoRA (Low-Rank Adaptation) là kỹ thuật phổ biến nhất.
Thay vì cập nhật toàn bộ ma trận trọng số (rất tốn kém), LoRA giả định rằng sự thay đổi trọng số có hạng thấp (low-rank). Ta phân rã: Trong đó với . Ta chỉ huấn luyện và . Code dưới đây minh họa cách áp dụng LoRA:
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "EleutherAI/gpt-neo-125M"
# Load mô hình với độ chính xác float16 để tiết kiệm bộ nhớ
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",
dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# Cấu hình LoRA
lora_config = LoraConfig(
r=16, # Rank của ma trận cập nhật (càng lớn càng mạnh nhưng tốn VRAM)
lora_alpha=32, # Hệ số scale
target_modules=["q_proj", "v_proj"], # Chỉ áp dụng LoRA cho Query và Value projection
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Áp dụng LoRA vào mô hình
peft_model = get_peft_model(model, lora_config)
# In ra số lượng tham số cần huấn luyện (thường < 1% tổng tham số)
peft_model.print_trainable_parameters()
output:
trainable params: 589,824 || all params: 125,788,416 || trainable%: 0.4689
import torch
from datasets import load_dataset
from transformers import (TrainingArguments, Trainer,
DataCollatorForLanguageModeling)
# Phần này là import thư viện chuẩn bị cho việc huấn luyện (Trainer API)
5. Huấn luyện nhanh hơn (Faster Training)
5.1. Gradient Accumulation (Tích lũy Gradient)
Lý thuyết
Để huấn luyện mô hình tốt, chúng ta thường cần kích thước batch (Batch Size) lớn. Tuy nhiên, GPU thường bị giới hạn bộ nhớ VRAM. Gradient Accumulation là kỹ thuật giúp mô phỏng batch size lớn trên GPU nhỏ.
Thay vì cập nhật trọng số (optimizer.step()) sau mỗi batch nhỏ, chúng ta:
- Tính loss và gradient cho batch nhỏ.
- Cộng dồn gradient đó lại nhưng chưa cập nhật trọng số.
- Lặp lại bước 1-2 trong bước (accumulation steps).
- Cập nhật trọng số một lần bằng gradient tổng hợp.
Lưu ý: Cần chia nhỏ loss (loss / accumulation_steps) để trung bình gradient đúng.
# Mô hình đồ chơi đơn giản
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
# Dữ liệu giả lập
data_loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(100)]
model.train()
accumulation_steps = 4 # Tích lũy 4 bước nhỏ thành 1 bước lớn
optimizer.zero_grad() # Reset gradient trước khi bắt đầu
for batch_index, (X_batch, y_batch) in enumerate(data_loader):
X_batch, y_batch = X_batch.to(device), y_batch.to(device)
# Forward pass
y_pred = model(X_batch)
loss = criterion(y_pred, y_batch)
# CHÌA KHÓA: Chia nhỏ loss trước khi backward
loss = loss / accumulation_steps
loss.backward()
# Chỉ cập nhật trọng số sau mỗi accumulation_steps
if (batch_index + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()