Programmation de Noyaux Triton
Vue d'ensemble
Cette skill fournit une référence pratique pour construire des noyaux Triton en production. Elle couvre l'API triton.language, les décorateurs d'autotuning, le modèle de compilation @triton.jit, les workflows de débogage/interpréteur et les benchmarks triton.testing.
Quand utiliser
Utilisez cette skill quand :
- Implémenter un noyau GPU custom en Triton
- Optimiser la latence d'inférence pour des opérations transformer petit batch
- Fusionner des opérations (ex : matmul + activation, attention avec softmax)
- Porter des noyaux CUDA vers Triton pour une maintenance plus facile
Ne pas utiliser pour :
- Les opérations PyTorch standard qui tournent déjà vite (utiliser
torch.compile)
- Les patterns de parallélisme distribué ou multi-GPU
- Les charges de travail liées au CPU
Installation
# Triton est livré avec PyTorch ≥2.0. Installez la dernière version :
pip install -U triton
# Ou compilez depuis la source pour les dernières fonctionnalités :
git clone https://github.com/triton-lang/triton.git
cd triton
pip install -r python/requirements.txt
pip install .
# Pour le profiling :
pip install nvitools # Helpers de profiling NVIDIA
pip install torch_tb_profiler # Profiling PyTorch
Référence de l'API Principale
Décorateur @triton.jit
Compile une fonction Python en noyau GPU. Tout le code doit être du Triton valide (subset de Python + opérations triton.language).
@triton.jit
def kernel( # ← noyau compilé
ptr, # arguments runtime : pointeurs, scalaires
BLOCK: tl.constexpr, # constexpr : baked in au moment de la compilation
):
pid = tl.program_id(axis=0) # Indice programme SPMD
...
triton.language (tl) — Opérations Clés
| Catégorie |
Opération |
Description |
| Indexation |
tl.program_id(axis) |
Indice programme SPMD selon l'axe 0, 1 ou 2 |
| Plages |
tl.arange(start, end) |
Tenseur plage 1D pour l'adressage vectorisé |
| Arithmétique |
tl.sum, tl.max, tl.min, tl.argmax |
Réduction de bloc selon l'axe |
| Arithmétique |
tl.dot(a, b) |
Multiplication matricielle de bloc (déclenche les tensor cores) |
| Activation |
tl.exp, tl.log, tl.sigmoid, tl.tanh |
Math élément par élément |
| Activation |
tl.sqrt, tl.abs, tl.where |
Opérations élément par élément |
| Mémoire |
tl.load(ptr, mask=, other=) |
Chargement vectoriel de la mémoire globale |
| Mémoire |
tl.store(ptr, val, mask=) |
Stockage vectoriel en mémoire globale |
| Mémoire |
tl.atomic_add(ptr, val) |
Ajout atomique (pour les réductions) |
| Cast |
tensor.to(tl.float16) |
Conversion de type |
| Cast |
tl.cast(tensor, tl.float32) |
Conversion de type explicite |
| Débogage |
tl.device_print("x:", x) |
Impression runtime |
| Débogage |
tl.device_assert(cond, "msg") |
Assertion runtime |
| Débogage |
tl.static_print(x) |
Impression compile-time |
| Débogage |
tl.static_assert(cond, "msg") |
Assertion compile-time |
Opérations Mémoire — Bonne Pratique de Masquage
# Toujours masquer les loads/stores pour la sécurité :
mask = offsets < n_elements
x = tl.load(ptr + offsets, mask=mask, other=0.0)
# 'other' fournit une valeur par défaut sûre pour les positions hors limites
# Pour la boucle interne matmul, utiliser other=0.0 pour les tuiles partielles :
a = tl.load(a_ptrs, mask=offsets_k[None, :] < K - k, other=0.0)
b = tl.load(b_ptrs, mask=offsets_k[:, None] < K - k, other=0.0)
Modèles de Noyaux Complets
Modèle 1 : Fusion Élément par Élément (ex : LayerNorm)
@triton.jit
def layernorm_kernel(
input_ptr, output_ptr, weight_ptr, bias_ptr,
row_stride, n_cols, eps,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
row_start = pid * row_stride
offsets = row_start + tl.arange(0, BLOCK_SIZE)
mask = tl.arange(0, BLOCK_SIZE) < n_cols
x = tl.load(input_ptr + offsets, mask=mask, other=0.0)
# Moyenne
mean = tl.sum(x, axis=0) / n_cols
# Variance
x_shifted = x - mean
var = tl.sum(x_shifted * x_shifted, axis=0) / n_cols
# Normalisation
x_norm = x_shifted / tl.sqrt(var + eps)
# Mise à l'échelle + décalage
w = tl.load(weight_ptr + tl.arange(0, BLOCK_SIZE), mask=mask)
b = tl.load(bias_ptr + tl.arange(0, BLOCK_SIZE), mask=mask)
y = x_norm * w + b
tl.store(output_ptr + offsets, y, mask=mask)
Modèle 2 : Softmax de Style Flash Attention avec Calcul Sûr en Ligne
@triton.jit
def fused_attention_kernel(
q_ptr, k_ptr, v_ptr, output_ptr,
stride_qh, stride_qd,
stride_kh, stride_kd,
stride_vh, stride_vd,
stride_oh, stride_od,
H, D,
BLOCK_D: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_h = tl.program_id(0) # indice head
offs_d = tl.arange(0, BLOCK_D)
offs_n = tl.arange(0, BLOCK_N)
# Charger le bloc Q pour cette head
q_ptrs = q_ptr + pid_h * stride_qh + offs_d[:, None] * stride_qd
q = tl.load(q_ptrs) # (BLOCK_D, 1)
# Softmax sûr en ligne sur la séquence KV
m_i = tl.full([BLOCK_N], -float('inf'), dtype=tl.float32)
z_i = tl.zeros([BLOCK_N], dtype=tl.float32)
acc = tl.zeros([BLOCK_D, BLOCK_N], dtype=tl.float32)
for start_n in range(0, N, BLOCK_N):
k_ptrs = k_ptr + pid_h * stride_kh + offs_n[None, :] * stride_kd + start_n * stride_kd
k = tl.load(k_ptrs, mask=offs_n[None, :] < N - start_n, other=0.0)
# S = Q @ K^T
s = tl.dot(q.T, k) # (1, BLOCK_N)
# Softmax sûr en ligne
m_ij = tl.maximum(m_i, s)
p = tl.exp(s - m_ij)
alpha = tl.exp(m_i - m_ij)
acc = acc * alpha + p * k.T # accumulation pondérée
z_i = z_i * alpha + p
m_i = m_i * 0 + m_ij # mise à jour broadcast
output = acc / z_i
# Stocker
out_ptrs = output_ptr + pid_h * stride_oh + offs_d[:, None] * stride_od
tl.store(out_ptrs, output)
Modèle 3 : GEMM FP8 avec Split-K (Optimisé pour l'Inférence)
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 4}, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 8}, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'SPLIT_K': 16}, num_warps=8),
],
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': lambda configs, named_args: [
c for c in configs if c.kwargs['BLOCK_SIZE_M'] * c.kwargs['SPLIT_K'] <= 128
],
},
)
@triton.jit
def fp8_gemm_splitk_kernel(
a_ptr, b_ptr, c_ptr, partial_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
SPLIT_K: tl.constexpr,
):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
k_block_id = pid // num_pid_m
pid_m = pid % num_pid_m
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_N)
offs_k = k_block_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K // SPLIT_K, BLOCK_SIZE_K):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K // SPLIT_K - k, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K // SPLIT_K - k, other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# Écrire la somme partielle
partial_idx = k_block_id * M + pid_m * BLOCK_SIZE_M
partial_ptrs = partial_ptr + partial_idx
tl.store(partial_ptrs, tl.sum(acc, axis=1)[:, None])
Stratégie d'Autotuning
Quand l'Autotuning Est Essentiel
| Scénario |
Impact Autotune |
| Formes d'entrée variables (VLLM, serving) |
Critique — cache par forme |
| Formes de production fixes |
Exécuter une fois, fixer la config |
| Opérations liées à la mémoire (softmax, normes) |
Moins critique — le pattern d'accès mémoire domine |
| Opérations liées au calcul (GEMM) |
Critique — différence de perf 2–5x entre configs |
Heuristiques de Design de Config
# Règle générale : le produit des dimensions de tuile doit tenir dans les registres
# BLOCK_SIZE_M * BLOCK_SIZE_N * element_size <= register_budget
# Pour NVIDIA A100/H100 (fp16 matmul) :
configs = [
# Équilibré : bon partout
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=3),
# Débit : grandes tuiles pour compute-bound
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=4),
# Latence : petites tuiles pour memory-bound / petit M
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2),
# AMD MI300X : utiliser moins de warps, peut nécessiter waves_per_eu
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=0),
]
Workflow de Profiling
Étape par Étape : Profiler et Optimiser
# 1. Warmup : exécuter une fois pour déclencher la compilation JIT
output_triton = my_kernel(x, y)
# 2. Benchmark avec triton.testing
import triton.testing
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: my_kernel(x, y),
quantiles=[0.5, 0.2, 0.8],
warmup=100, # itérations
rep=100, # itérations de mesure
)
# 3. Comparer avec la référence
ms_torch, _, _ = triton.testing.do_bench(lambda: torch.matmul(a, b))
# 4. Calculer TFLOPS
tflops = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
print(f"Triton: {tflops(ms):.2f} TFLOPS | Torch: {tflops(ms_torch):.2f} TFLOPS")
Intégration CUDA Graph (Production)
# Après que l'autotuning a sélectionné la meilleure config, capturer un CUDA graph :
import torch
def capture_gemm_graph(a, b):
# Warmup avec la forme de production
_ = triton_matmul(a, b)
torch.cuda.synchronize()
# Capturer le graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
c = triton_matmul(a, b)
return graph, c
# Rejouer pour l'inférence — élimine 1-2ms de surcharge JIT par lancement
graph.replay()
Cheatsheet de Débogage
| Problème |
Symptôme |
Fix |
| Sortie incorrecte |
Off-by-one dans les offsets |
Vérifier la logique mask, utiliser % modulo pour les limites |
| Sortie NaN |
Instabilité numérique |
Soustraire le max avant exp ; vérifier division par zéro |
| Noyau lent (memory-bound) |
Util bande faible |
Augmenter les tailles de tuile, vérifier _b128 dans l'ISA |
| Noyau lent (compute-bound) |
TFLOPS faible |
Vérifier l'utilisation tensor core dans le PTX ; essayer le tuning num_stages |
| Erreur de compilation |
Problème fonction @triton.jit |
Vérifier les constructions Python non supportées (pas de dictionnaires, pas d'indexation dynamique) |
Erreurs compute-sanitizer |
Accès hors limites |
Vérifier la couverture du mask pour les tuiles partielles |
| Surcharge de lancement élevée |
Latence CPU-side |
Utiliser CUDA Graphs pour l'inférence en production |
Portes de Qualité
| Porte |
Commande/Vérification |
Attendu |
| Exactitude |
torch.max(torch.abs(ref - triton_out)) |
< 0,01 (fp16) ou < 0,5 (fp8) |
| Autotuning |
Variable env TRITON_PRINT_AUTOTUNING=1 |
Meilleure config imprimée |
| Utilisation tensor core |
Vérifier PTX pour wgmma/mma |
Présent pour les noyaux matmul |
| Coalescence mémoire |
Vérifier ISA pour global_load_dwordx4 |
Présent dans la boucle critique |
| Utilisation LDS |
grep "triton_gpu.shared" depuis dump MLIR |
< 64 KB |
| Occupancy |
Calculer depuis les comptages VGPR/LDS |
> 50 % pour compute-bound |
| Speedup |
triton.testing.do_bench |
> 1,5x par rapport à PyTorch naïf |
Références Croisées
- Guideline
triton-kernel-build-design — patterns de design complets, hiérarchie mémoire et référence d'optimisation
- Tutoriels officiels : https://triton-lang.org/main/getting-started/tutorials/
dataset-curation-manifest — quand construire des noyaux de chargement de données
embedding-analysis — pour comprendre les patterns de calcul d'embeddings
Références