kernel-triton-writing

Par nvidia · skills

UNIQUEMENT pour le développement de kernels OpenAI Triton (@triton.jit). NE JAMAIS utiliser pour les kernels CUDA C++, TileIR, ou les outils de profilage (ncu, nsys). La demande de l'utilisateur doit impliquer Triton explicitement. Couvre les patterns spécifiques à Triton : elementwise fusionné, réductions (softmax, LayerNorm, RMSNorm), GEMM tuilé avec triton.autotune, et flash attention. Workflow : conception, écriture, vérification (avec fast-path pour les demandes explicites).

npx skills add https://github.com/nvidia/skills --skill kernel-triton-writing

Écriture de noyaux Triton

Principes

Correction d'abord

  1. Ne jamais benchmarker avant que la vérification réussisse.
  2. Toujours masquer les lectures et écritures pour les formes non divisibles.
  3. Inclure les exports kernel_fn, reference_fn et get_inputs() pour les scripts compagnons.
  4. Toujours exécuter scripts/verify_kernel.py pour valider par rapport à la référence.

Règles de précision FP16/BF16 (LIBERTÉ FAIBLE -- suivre exactement)

Les fonctions transcendantales (tl.exp, tl.log, tl.math.erf, tl.math.tanh) nécessitent des entrées fp32.

# FAUX -- erreur de compilation ou résultats incorrects avec fp16/bf16 :
result = tl.exp(x)

# CORRECT -- convertir en fp32, calculer, reconvertir :
x_fp32 = x.to(tl.float32)
result = tl.exp(x_fp32).to(x.dtype)

Règle : toute fonction mathématique au-delà de l'arithmétique basique (+, -, *, /) nécessite une conversion fp32 en entrée, et une conversion au dtype d'origine en sortie.

Contraintes de précision supplémentaires :

  • tl.sigmoid() n'est pas disponible dans certaines versions de Triton. Utiliser 1.0 / (1.0 + tl.exp(-x_fp32)).
  • Toujours reconvertir en x.dtype avant tl.store -- les incompatibilités causent "Type mismatch, store Float32 to Float16".
  • Contrairement à PyTorch, Triton N'élève PAS automatiquement fp16/bf16 en fp32 pour l'accumulation. Toujours utiliser des accumulateurs tl.float32 pour tl.dot.
  • TF32 pour matmul : Sur Ampere+/Hopper, tl.dot utilise TF32 par défaut pour les entrées fp32 (comme torch.mm). Ne PAS ajouter input_precision="ieee" -- c'est 3-8x plus lent. TF32 est le défaut correct. Si la vérification échoue à cause de la précision TF32 (~0,01-0,1 diff abs), assurer que reference_fn utilise aussi TF32 (torch.mm simple, sans allow_tf32=False).

Éviter la synchronisation CPU-GPU (LIBERTÉ FAIBLE)

Ne jamais appeler .item() dans les wrappers de noyau. Cela force une synchronisation CPU-GPU (~50-100us par appel).

Piège Correction
tensor.item() pour seed x.data_ptr() % (2**31)
torch.randint(...).item() Utiliser les métadonnées tensor pour un seed pseudo-aléatoire
Allocation de sortie à chaque appel Accepter la sortie pré-allouée comme paramètre
Boucles Python appelant le noyau Opérations batch

Sémantique de division d'entiers C (CRITIQUE)

Triton utilise la sémantique C (arrondir vers zéro) pour // et %, NOT la sémantique Python (arrondir vers l'infini négatif). Cela n'a d'importance que quand les opérandes peuvent être négatifs.

Expression Python Triton/C
-7 // 2 -4 -3
-7 % 2 1 -1

Motif sûr : Assurer que tous les index/offset sont non-négatifs. Si des valeurs négatives sont possibles, utiliser (idx % BLOCK + BLOCK) % BLOCK.

Voir references/concepts-semantics.md pour les règles complètes et l'exception scalaire uniquement.

Modèle mental de conception de noyau

  • Axe de parallélisation : Les noyaux élément par élément se parallélisent sur les éléments aplatis. Les noyaux ligne par ligne (LayerNorm, softmax) se parallélisent sur les lignes. Les noyaux matmul se divisent en tuiles 2D (M, N).
  • Taille de bloc : Puissance de 2 uniquement (256, 512, 1024, 2048). Commencer avec 1024 pour H100, 512 pour V100.
  • Coalescence mémoire : Les threads adjacents doivent accéder à des adresses mémoire adjacentes. Le compilateur gère cela automatiquement à partir de l'arithmétique des pointeurs au niveau du bloc.
  • Grille : Utiliser triton.cdiv(n_elements, BLOCK_SIZE). Avec autotune, la grille doit être une lambda : lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),).
  • Ordre des décorateurs : @triton.autotune (le plus extérieur) -> @triton.heuristics -> @triton.jit (le plus intérieur).
  • reset_to_zero : Requis pour autotune sur les noyaux qui accumulent (par ex., sortie matmul). Sans cela, les configurations ultérieures voient des valeurs résiduelles des essais antérieurs.

Flux de travail

Chemin rapide : Si l'utilisateur demande explicitement un noyau Triton (par ex., "Écris un noyau Triton pour X", "Implémente softmax en Triton"), commencer à la Phase 2. N'utiliser les Phases 0-1 que quand la demande est ambiguë sur l'appropriatesse de Triton.

Phase 0 : Acheminer l'opérateur (uniquement pour les demandes ambiguës)

Passer cette phase si l'utilisateur demande explicitement un noyau Triton. N'utiliser que quand la demande est ambiguë (par ex., "optimise cette opération").

Triton gagne quand 2+ opérations peuvent partager des registres au lieu d'écrire/lire la mémoire globale. Règles rapides :

Motif Décision
Opération élément par élément simple (relu, sigmoid) SKIP -- PyTorch est déjà optimal
Matmul autonome SKIP -- cuBLAS est optimisé
Attention standard SKIP -- Utiliser FlashAttention
Chaîne élément par élément (2+ ops), réduction, matmul + épilogue UTILISER TRITON

Si SKIP, recommander l'alternative et ARRÊTER. Voir references/operator-routing.md pour les cas limites.

Phase 1 : Analyser l'opérateur (uniquement pour les demandes ambiguës)

À partir de la demande de l'utilisateur, identifier : (1) type d'opération, (2) stratégie de parallélisation, (3) formes et dtypes des entrées.

Phase 2 : Concevoir le noyau

Choisir le squelette ci-dessous qui correspond à votre opération. Ces squelettes suffisent pour les noyaux élément par élément, réduction, matmul et fusion -- NE PAS lire les fichiers de référence pour ces motifs courants. Consulter references/ uniquement lors de l'implémentation de motifs peu courants (grouped GEMM, TMA, fonctions extern) ou du débogage de problèmes.

Squelette élément par élément (GELU, dropout, opérations fusionnées sur tenseurs plats) :

@triton.jit
def kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # ... calculer ...
    tl.store(out_ptr + offsets, result, mask=mask)

Squelette ligne par ligne (softmax, LayerNorm, RMSNorm -- un programme par ligne) :

@triton.jit
def kernel(x_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row_idx * n_cols + col_offsets, mask=mask, other=0.0)
    # ... réduire / normaliser ...
    tl.store(out_ptr + row_idx * n_cols + col_offsets, result, mask=mask)

Squelette matmul en tuiles (GEMM avec tuilage 2D, ordre groupé et autotune) :

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(0)
    num_m_blocks = tl.cdiv(M, BLOCK_M)
    num_n_blocks = tl.cdiv(N, BLOCK_N)
    # Ordre groupé pour la localité du cache L2
    num_pid_in_group = GROUP_SIZE_M * num_n_blocks
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_m_blocks - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K

    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)

Phase 3 : Écrire le noyau

Créer un répertoire de sortie, puis écrire le fichier de noyau dans {output_dir}/kernel.py.

Le fichier de noyau DOIT inclure :

  • Fonction de noyau décorée avec @triton.jit
  • @triton.autotune pour les noyaux de production (voir references/api-core.md)
  • Fonction wrapper Python (nom descriptif pour l'import externe)
  • Exports de contrat fixe (les scripts compagnons comptent sur ces noms exacts) :
    • kernel_fn -- alias à la fonction wrapper
    • reference_fn(*args) -- référence PyTorch avec signature identique
    • get_inputs() -- retourne list de tensors CUDA frais pour test/benchmark

Exemple concis (GELU fusionné + dropout) :

import triton
import triton.language as tl
import torch

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def fused_gelu_dropout_kernel(
    x_ptr, out_ptr, n_elements, p, seed,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)
    x_fp32 = x.to(tl.float32)
    x = (0.5 * x_fp32 * (1.0 + tl.math.erf(x_fp32 * 0.7071067811865476))).to(x.dtype)

    random = tl.rand(seed, offsets)
    x = tl.where(random > p, x / (1.0 - p), 0.0)

    tl.store(out_ptr + offsets, x, mask=mask)


def fused_gelu_dropout_triton(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    n_elements = x.numel()
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    seed = (x.data_ptr() % (2**31)) ^ n_elements  # seed sync-free
    fused_gelu_dropout_kernel[grid](x, out, n_elements, p, seed)
    return out


# --- Contrat fixe (les scripts compagnons comptent sur ces noms) ---
kernel_fn = fused_gelu_dropout_triton

def reference_fn(x, p=0.1):
    torch.manual_seed((x.data_ptr() % (2**31)) ^ x.numel())
    return torch.nn.functional.dropout(
        torch.nn.functional.gelu(x), p, training=True
    )

def get_inputs():
    return [torch.randn(128 * 1024 * 1024, device="cuda")]

Pour plus de motifs (SiLU+mul, RMSNorm, linear+GELU, add+LayerNorm), voir references/patterns-fusion.md. Pour les motifs GEMM, voir references/patterns-gemm.md.

Phase 4 : Vérifier la correction

Exécuter le script de vérification compagnon :

python scripts/verify_kernel.py {output_dir}/kernel.py --rtol 1e-3 --atol 1e-3

Sortie :

{"correct": true, "max_abs_diff": 1.2e-7, "max_rel_diff": 3.4e-6, "details": "..."}

Arrêter si correct: false. Corriger le noyau avant le benchmark.

Guide de tolérance :

Dtype rtol atol Notes
float16 1e-3 1e-3
bfloat16 1e-2 1e-2
float32 1e-5 1e-5 Opérations élément par élément
float32 (matmul) 1e-2 1e-1 L'ordre d'accumulation TF32 diffère entre les tuiles Triton et cuBLAS

Phase 5 : Benchmark de performance (optionnel)

Ne benchmark que si l'utilisateur demande explicitement les chiffres de performance. Sauter cette phase pour les demandes axées sur la correction.

python scripts/benchmark_kernel.py {output_dir}/kernel.py

Sortie :

{"kernel_time_ms": 0.45, "reference_time_ms": 1.23, "speedup": 2.73, "warmup_iters": 10, "benchmark_iters": 40}

Références (consulter uniquement en cas de problème)

Les squelettes et principes ci-dessus couvrent les noyaux élément par élément, réduction, matmul et fusion. NE PAS lire les fichiers de référence pour ces motifs courants.

Consulter references/ uniquement quand :

  • Implémenter des motifs peu courants (grouped GEMM, TMA, matmul persistant, fonctions extern)
  • Déboguer une erreur de compilation ou un résultat incorrect non couvert par la table d'erreurs ci-dessous
  • Avoir besoin de détails API pour une opération tl.* peu familière

Comment chercher : Grep pour votre mot-clé dans references/. Lire uniquement le fichier vers lequel Grep pointe.

Fichier Quand l'utiliser
references/api-core.md Options triton.autotune / triton.Config peu familières
references/api-language.md Opérations tl.* peu familières
references/patterns-gemm.md Grouped GEMM, matmul persistant, TMA, formats MX
references/patterns-advanced.md Détails Flash attention, passes arrière, libdevice
references/troubleshooting.md Ops de débogage, mode interpréteur, variables d'env

Gestion des erreurs et dépannage

Erreurs courantes

Erreur / Symptôme Cause Correction
"Type mismatch, store Float32 to Float16" Manque .to(x.dtype) avant store Reconvertir le résultat fp32
BLOCK_SIZE is not a constexpr Taille de bloc passée comme valeur runtime Ajouter l'annotation : tl.constexpr
shape mismatch dans opération binaire Les formes de tenseur ne se broadcastent pas Vérifier avec tl.static_print; utiliser [:, None] / [None, :]
Grandes différences partout Mauvais dtype dans tl.load Vérifier que le dtype de chargement correspond à l'entrée
Matmul 3-8x plus lent que prévu input_precision="ieee" sur tl.dot Retirer; utiliser le défaut TF32. Assurer que reference_fn utilise aussi TF32
Matmul ~0,01-0,1 diff abs vs référence Incompatibilité TF32 vs IEEE Utiliser la même précision dans le noyau et la référence (TF32 pour les deux)
Diffs aux limites Masque manquant Ajouter le masque à toutes les opérations load/store
Diffs aléatoires Race condition Vérifier les atomiques et l'ordre
NaN/Inf Division par zéro ou overflow fp16 Garder avec epsilon; utiliser l'accumulateur tl.float32
grid must be a tuple Lambda de grille retourne int, pas tuple Retourner (value,) avec virgule finale
expected constexpr dans tl.arange Argument non-constexpr Les deux args de tl.arange(start, end) doivent être constexpr
triton.OutOfResources Pression registre/mémoire partagée Réduire BLOCK_SIZE ou num_stages
Noyau ne se mettant pas à jour après édition Cache de compilation obsolète rm -rf ~/.triton/cache/
Résultats mal appairés vs PyTorch Sémantique de division d'entiers C Triton utilise la troncature; voir references/concepts-semantics.md

Pour la table d'erreurs étendue, les problèmes de mode interpréteur et les variables d'environnement, voir references/troubleshooting.md.

Quand abandonner

Arrêter et rapporter l'échec si :

  1. Pas un bon choix -- Matmul pur ou flux de contrôle complexe (la Phase 0 doit le détecter).
  2. Vérification échoue après 3 tentatives -- Problèmes numériques trop sévères à corriger.
  3. Aucun speedup -- La référence est déjà bien optimisée (cuBLAS, cuDNN).
  4. Incompatibilité matériel -- GPU cible non disponible pour test.

Skills similaires