Écriture de noyaux Triton
Principes
Correction d'abord
- Ne jamais benchmarker avant que la vérification réussisse.
- Toujours masquer les lectures et écritures pour les formes non divisibles.
- Inclure les exports
kernel_fn,reference_fnetget_inputs()pour les scripts compagnons. - Toujours exécuter
scripts/verify_kernel.pypour valider par rapport à la référence.
Règles de précision FP16/BF16 (LIBERTÉ FAIBLE -- suivre exactement)
Les fonctions transcendantales (tl.exp, tl.log, tl.math.erf, tl.math.tanh) nécessitent des entrées fp32.
# FAUX -- erreur de compilation ou résultats incorrects avec fp16/bf16 :
result = tl.exp(x)
# CORRECT -- convertir en fp32, calculer, reconvertir :
x_fp32 = x.to(tl.float32)
result = tl.exp(x_fp32).to(x.dtype)
Règle : toute fonction mathématique au-delà de l'arithmétique basique (+, -, *, /) nécessite une conversion fp32 en entrée, et une conversion au dtype d'origine en sortie.
Contraintes de précision supplémentaires :
tl.sigmoid()n'est pas disponible dans certaines versions de Triton. Utiliser1.0 / (1.0 + tl.exp(-x_fp32)).- Toujours reconvertir en
x.dtypeavanttl.store-- les incompatibilités causent "Type mismatch, store Float32 to Float16". - Contrairement à PyTorch, Triton N'élève PAS automatiquement fp16/bf16 en fp32 pour l'accumulation. Toujours utiliser des accumulateurs
tl.float32pourtl.dot. - TF32 pour matmul : Sur Ampere+/Hopper,
tl.dotutilise TF32 par défaut pour les entrées fp32 (commetorch.mm). Ne PAS ajouterinput_precision="ieee"-- c'est 3-8x plus lent. TF32 est le défaut correct. Si la vérification échoue à cause de la précision TF32 (~0,01-0,1 diff abs), assurer quereference_fnutilise aussi TF32 (torch.mmsimple, sansallow_tf32=False).
Éviter la synchronisation CPU-GPU (LIBERTÉ FAIBLE)
Ne jamais appeler .item() dans les wrappers de noyau. Cela force une synchronisation CPU-GPU (~50-100us par appel).
| Piège | Correction |
|---|---|
tensor.item() pour seed |
x.data_ptr() % (2**31) |
torch.randint(...).item() |
Utiliser les métadonnées tensor pour un seed pseudo-aléatoire |
| Allocation de sortie à chaque appel | Accepter la sortie pré-allouée comme paramètre |
| Boucles Python appelant le noyau | Opérations batch |
Sémantique de division d'entiers C (CRITIQUE)
Triton utilise la sémantique C (arrondir vers zéro) pour // et %, NOT la sémantique Python (arrondir vers l'infini négatif). Cela n'a d'importance que quand les opérandes peuvent être négatifs.
| Expression | Python | Triton/C |
|---|---|---|
-7 // 2 |
-4 |
-3 |
-7 % 2 |
1 |
-1 |
Motif sûr : Assurer que tous les index/offset sont non-négatifs. Si des valeurs négatives sont possibles, utiliser (idx % BLOCK + BLOCK) % BLOCK.
Voir references/concepts-semantics.md pour les règles complètes et l'exception scalaire uniquement.
Modèle mental de conception de noyau
- Axe de parallélisation : Les noyaux élément par élément se parallélisent sur les éléments aplatis. Les noyaux ligne par ligne (LayerNorm, softmax) se parallélisent sur les lignes. Les noyaux matmul se divisent en tuiles 2D (M, N).
- Taille de bloc : Puissance de 2 uniquement (256, 512, 1024, 2048). Commencer avec 1024 pour H100, 512 pour V100.
- Coalescence mémoire : Les threads adjacents doivent accéder à des adresses mémoire adjacentes. Le compilateur gère cela automatiquement à partir de l'arithmétique des pointeurs au niveau du bloc.
- Grille : Utiliser
triton.cdiv(n_elements, BLOCK_SIZE). Avec autotune, la grille doit être une lambda :lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),). - Ordre des décorateurs :
@triton.autotune(le plus extérieur) ->@triton.heuristics->@triton.jit(le plus intérieur). reset_to_zero: Requis pour autotune sur les noyaux qui accumulent (par ex., sortie matmul). Sans cela, les configurations ultérieures voient des valeurs résiduelles des essais antérieurs.
Flux de travail
Chemin rapide : Si l'utilisateur demande explicitement un noyau Triton (par ex., "Écris un noyau Triton pour X", "Implémente softmax en Triton"), commencer à la Phase 2. N'utiliser les Phases 0-1 que quand la demande est ambiguë sur l'appropriatesse de Triton.
Phase 0 : Acheminer l'opérateur (uniquement pour les demandes ambiguës)
Passer cette phase si l'utilisateur demande explicitement un noyau Triton. N'utiliser que quand la demande est ambiguë (par ex., "optimise cette opération").
Triton gagne quand 2+ opérations peuvent partager des registres au lieu d'écrire/lire la mémoire globale. Règles rapides :
| Motif | Décision |
|---|---|
Opération élément par élément simple (relu, sigmoid) |
SKIP -- PyTorch est déjà optimal |
| Matmul autonome | SKIP -- cuBLAS est optimisé |
| Attention standard | SKIP -- Utiliser FlashAttention |
| Chaîne élément par élément (2+ ops), réduction, matmul + épilogue | UTILISER TRITON |
Si SKIP, recommander l'alternative et ARRÊTER. Voir references/operator-routing.md pour les cas limites.
Phase 1 : Analyser l'opérateur (uniquement pour les demandes ambiguës)
À partir de la demande de l'utilisateur, identifier : (1) type d'opération, (2) stratégie de parallélisation, (3) formes et dtypes des entrées.
Phase 2 : Concevoir le noyau
Choisir le squelette ci-dessous qui correspond à votre opération. Ces squelettes suffisent pour les noyaux élément par élément, réduction, matmul et fusion -- NE PAS lire les fichiers de référence pour ces motifs courants. Consulter references/ uniquement lors de l'implémentation de motifs peu courants (grouped GEMM, TMA, fonctions extern) ou du débogage de problèmes.
Squelette élément par élément (GELU, dropout, opérations fusionnées sur tenseurs plats) :
@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
# ... calculer ...
tl.store(out_ptr + offsets, result, mask=mask)
Squelette ligne par ligne (softmax, LayerNorm, RMSNorm -- un programme par ligne) :
@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
# ... réduire / normaliser ...
tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)
Squelette matmul en tuiles (GEMM avec tuilage 2D, ordre groupé et autotune) :
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
pid = tl.program_id(0)
num_m_blocks = tl.cdiv(M, BLOCK_M)
num_n_blocks = tl.cdiv(N, BLOCK_N)
# Ordre groupé pour la localité du cache L2
num_pid_in_group = GROUP_SIZE_M * num_n_blocks
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
acc += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
offs_k += BLOCK_K
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)
Phase 3 : Écrire le noyau
Créer un répertoire de sortie, puis écrire le fichier de noyau dans {output_dir}/kernel.py.
Le fichier de noyau DOIT inclure :
- Fonction de noyau décorée avec
@triton.jit @triton.autotunepour les noyaux de production (voir references/api-core.md)- Fonction wrapper Python (nom descriptif pour l'import externe)
- Exports de contrat fixe (les scripts compagnons comptent sur ces noms exacts) :
kernel_fn-- alias à la fonction wrapperreference_fn(*args)-- référence PyTorch avec signature identiqueget_inputs()-- retournelistde tensors CUDA frais pour test/benchmark
Exemple concis (GELU fusionné + dropout) :
import triton
import triton.language as tl
import torch
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
],
key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
x_ptr, out_ptr, n_elements, p, seed,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
x_fp32 = x.to(tl.float32)
x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)
random = tl.rand(seed, offsets)
x = tl.where(random > p, x / (1.0 - p), 0.0)
tl.store(out_ptr + offsets, x, mask=mask)
def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
n_elements = x.numel()
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
seed = (x.data_ptr() % (2**31)) ^ n_elements # seed sync-free
fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
return out
# --- Contrat fixe (les scripts compagnons comptent sur ces noms) ---
kernel_fn = fused_gelu_dropout_triton
def reference_fn(x, p=0.1):
torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
return torch.nn.functional.dropout(
torch.nn.functional.gelu(x), p, training=True
)
def get_inputs():
return [torch.randn(128 * 1024 * 1024, device="cuda")]
Pour plus de motifs (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), voir references/patterns-fusion.md. Pour les motifs GEMM, voir references/patterns-gemm.md.
Phase 4 : Vérifier la correction
Exécuter le script de vérification compagnon :
python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3
Sortie :
{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}
Arrêter si correct: false. Corriger le noyau avant le benchmark.
Guide de tolérance :
| Dtype | rtol | atol | Notes |
|---|---|---|---|
| float16 | 1e-3 | 1e-3 | |
| bfloat16 | 1e-2 | 1e-2 | |
| float32 | 1e-5 | 1e-5 | Opérations élément par élément |
| float32 (matmul) | 1e-2 | 1e-1 | L'ordre d'accumulation TF32 diffère entre les tuiles Triton et cuBLAS |
Phase 5 : Benchmark de performance (optionnel)
Ne benchmark que si l'utilisateur demande explicitement les chiffres de performance. Sauter cette phase pour les demandes axées sur la correction.
python scripts/benchmark_kernel.py {output_dir}/kernel.py
Sortie :
{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}
Références (consulter uniquement en cas de problème)
Les squelettes et principes ci-dessus couvrent les noyaux élément par élément, réduction, matmul et fusion. NE PAS lire les fichiers de référence pour ces motifs courants.
Consulter references/ uniquement quand :
- Implémenter des motifs peu courants (grouped GEMM, TMA, matmul persistant, fonctions extern)
- Déboguer une erreur de compilation ou un résultat incorrect non couvert par la table d'erreurs ci-dessous
- Avoir besoin de détails API pour une opération
tl.*peu familière
Comment chercher : Grep pour votre mot-clé dans references/. Lire uniquement le fichier vers lequel Grep pointe.
| Fichier | Quand l'utiliser |
|---|---|
references/api-core.md |
Options triton.autotune / triton.Config peu familières |
references/api-language.md |
Opérations tl.* peu familières |
references/patterns-gemm.md |
Grouped GEMM, matmul persistant, TMA, formats MX |
references/patterns-advanced.md |
Détails Flash attention, passes arrière, libdevice |
references/troubleshooting.md |
Ops de débogage, mode interpréteur, variables d'env |
Gestion des erreurs et dépannage
Erreurs courantes
| Erreur / Symptôme | Cause | Correction |
|---|---|---|
| "Type mismatch, store Float32 to Float16" | Manque .to(x.dtype) avant store |
Reconvertir le résultat fp32 |
BLOCK_SIZE is not a constexpr |
Taille de bloc passée comme valeur runtime | Ajouter l'annotation : tl.constexpr |
shape mismatch dans opération binaire |
Les formes de tenseur ne se broadcastent pas | Vérifier avec tl.static_print; utiliser [:, None] / [None, :] |
| Grandes différences partout | Mauvais dtype dans tl.load |
Vérifier que le dtype de chargement correspond à l'entrée |
| Matmul 3-8x plus lent que prévu | input_precision="ieee" sur tl.dot |
Retirer; utiliser le défaut TF32. Assurer que reference_fn utilise aussi TF32 |
| Matmul ~0,01-0,1 diff abs vs référence | Incompatibilité TF32 vs IEEE | Utiliser la même précision dans le noyau et la référence (TF32 pour les deux) |
| Diffs aux limites | Masque manquant | Ajouter le masque à toutes les opérations load/store |
| Diffs aléatoires | Race condition | Vérifier les atomiques et l'ordre |
| NaN/Inf | Division par zéro ou overflow fp16 | Garder avec epsilon; utiliser l'accumulateur tl.float32 |
grid must be a tuple |
Lambda de grille retourne int, pas tuple | Retourner (value,) avec virgule finale |
expected constexpr dans tl.arange |
Argument non-constexpr | Les deux args de tl.arange(start, end) doivent être constexpr |
triton.OutOfResources |
Pression registre/mémoire partagée | Réduire BLOCK_SIZE ou num_stages |
| Noyau ne se mettant pas à jour après édition | Cache de compilation obsolète | rm -rf ~/.triton/cache/ |
| Résultats mal appairés vs PyTorch | Sémantique de division d'entiers C | Triton utilise la troncature; voir references/concepts-semantics.md |
Pour la table d'erreurs étendue, les problèmes de mode interpréteur et les variables d'environnement, voir references/troubleshooting.md.
Quand abandonner
Arrêter et rapporter l'échec si :
- Pas un bon choix -- Matmul pur ou flux de contrôle complexe (la Phase 0 doit le détecter).
- Vérification échoue après 3 tentatives -- Problèmes numériques trop sévères à corriger.
- Aucun speedup -- La référence est déjà bien optimisée (cuBLAS, cuDNN).
- Incompatibilité matériel -- GPU cible non disponible pour test.