triton-kernel-programming

Par mkurman · zorai

Modèle d'implémentation pratique et référence API pour écrire, optimiser, déboguer et benchmarker des kernels GPU Triton. Couvre l'ensemble de la surface API `triton.language`, les patterns d'autotuning, les workflows de profilage et l'intégration en production.

npx skills add https://github.com/mkurman/zorai --skill triton-kernel-programming

Programmation de Noyaux Triton

Vue d'ensemble

Cette skill fournit une référence pratique pour construire des noyaux Triton en production. Elle couvre l'API triton.language, les décorateurs d'autotuning, le modèle de compilation @triton.jit, les workflows de débogage/interpréteur et les benchmarks triton.testing.

Quand utiliser

Utilisez cette skill quand :

  • Implémenter un noyau GPU custom en Triton
  • Optimiser la latence d'inférence pour des opérations transformer petit batch
  • Fusionner des opérations (ex : matmul + activation, attention avec softmax)
  • Porter des noyaux CUDA vers Triton pour une maintenance plus facile

Ne pas utiliser pour :

  • Les opérations PyTorch standard qui tournent déjà vite (utiliser torch.compile)
  • Les patterns de parallélisme distribué ou multi-GPU
  • Les charges de travail liées au CPU

Installation

# Triton est livré avec PyTorch ≥2.0. Installez la dernière version :
pip install -U triton

# Ou compilez depuis la source pour les dernières fonctionnalités :
git clone https://github.com/triton-lang/triton.git
cd triton
pip install -r python/requirements.txt
pip install .

# Pour le profiling :
pip install nvitools  # Helpers de profiling NVIDIA
pip install torch_tb_profiler  # Profiling PyTorch

Référence de l'API Principale

Décorateur @triton.jit

Compile une fonction Python en noyau GPU. Tout le code doit être du Triton valide (subset de Python + opérations triton.language).

@triton.jit
def kernel(  # ← noyau compilé
    ptr,              # arguments runtime : pointeurs, scalaires
    BLOCK: tl.constexpr,  # constexpr : baked in au moment de la compilation
):
    pid = tl.program_id(axis=0)  # Indice programme SPMD
    ...

triton.language (tl) — Opérations Clés

Catégorie Opération Description
Indexation tl.program_id(axis) Indice programme SPMD selon l'axe 0, 1 ou 2
Plages tl.arange(start, end) Tenseur plage 1D pour l'adressage vectorisé
Arithmétique tl.sum, tl.max, tl.min, tl.argmax Réduction de bloc selon l'axe
Arithmétique tl.dot(a, b) Multiplication matricielle de bloc (déclenche les tensor cores)
Activation tl.exp, tl.log, tl.sigmoid, tl.tanh Math élément par élément
Activation tl.sqrt, tl.abs, tl.where Opérations élément par élément
Mémoire tl.load(ptr, mask=, other=) Chargement vectoriel de la mémoire globale
Mémoire tl.store(ptr, val, mask=) Stockage vectoriel en mémoire globale
Mémoire tl.atomic_add(ptr, val) Ajout atomique (pour les réductions)
Cast tensor.to(tl.float16) Conversion de type
Cast tl.cast(tensor, tl.float32) Conversion de type explicite
Débogage tl.device_print("x:", x) Impression runtime
Débogage tl.device_assert(cond, "msg") Assertion runtime
Débogage tl.static_print(x) Impression compile-time
Débogage tl.static_assert(cond, "msg") Assertion compile-time

Opérations Mémoire — Bonne Pratique de Masquage

# Toujours masquer les loads/stores pour la sécurité :
mask = offsets < n_elements
x = tl.load(ptr + offsets, mask=mask, other=0.0)
# 'other' fournit une valeur par défaut sûre pour les positions hors limites

# Pour la boucle interne matmul, utiliser other=0.0 pour les tuiles partielles :
a = tl.load(a_ptrs, mask=offsets_k[None, :] < K - k, other=0.0)
b = tl.load(b_ptrs, mask=offsets_k[:, None] < K - k, other=0.0)

Modèles de Noyaux Complets

Modèle 1 : Fusion Élément par Élément (ex : LayerNorm)

@triton.jit
def layernorm_kernel(
    input_ptr, output_ptr, weight_ptr, bias_ptr,
    row_stride, n_cols, eps,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    row_start = pid * row_stride
    offsets = row_start + tl.arange(0, BLOCK_SIZE)
    mask = tl.arange(0, BLOCK_SIZE) < n_cols

    x = tl.load(input_ptr + offsets, mask=mask, other=0.0)

    # Moyenne
    mean = tl.sum(x, axis=0) / n_cols
    # Variance
    x_shifted = x - mean
    var = tl.sum(x_shifted * x_shifted, axis=0) / n_cols
    # Normalisation
    x_norm = x_shifted / tl.sqrt(var + eps)
    # Mise à l'échelle + décalage
    w = tl.load(weight_ptr + tl.arange(0, BLOCK_SIZE), mask=mask)
    b = tl.load(bias_ptr + tl.arange(0, BLOCK_SIZE), mask=mask)
    y = x_norm * w + b

    tl.store(output_ptr + offsets, y, mask=mask)

Modèle 2 : Softmax de Style Flash Attention avec Calcul Sûr en Ligne

@triton.jit
def fused_attention_kernel(
    q_ptr, k_ptr, v_ptr, output_ptr,
    stride_qh, stride_qd,
    stride_kh, stride_kd,
    stride_vh, stride_vd,
    stride_oh, stride_od,
    H, D,
    BLOCK_D: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_h = tl.program_id(0)  # indice head

    offs_d = tl.arange(0, BLOCK_D)
    offs_n = tl.arange(0, BLOCK_N)

    # Charger le bloc Q pour cette head
    q_ptrs = q_ptr + pid_h * stride_qh + offs_d[:, None] * stride_qd
    q = tl.load(q_ptrs)  # (BLOCK_D, 1)

    # Softmax sûr en ligne sur la séquence KV
    m_i = tl.full([BLOCK_N], -float('inf'), dtype=tl.float32)
    z_i = tl.zeros([BLOCK_N], dtype=tl.float32)
    acc = tl.zeros([BLOCK_D, BLOCK_N], dtype=tl.float32)

    for start_n in range(0, N, BLOCK_N):
        k_ptrs = k_ptr + pid_h * stride_kh + offs_n[None, :] * stride_kd + start_n * stride_kd
        k = tl.load(k_ptrs, mask=offs_n[None, :] < N - start_n, other=0.0)

        # S = Q @ K^T
        s = tl.dot(q.T, k)  # (1, BLOCK_N)

        # Softmax sûr en ligne
        m_ij = tl.maximum(m_i, s)
        p = tl.exp(s - m_ij)
        alpha = tl.exp(m_i - m_ij)
        acc = acc * alpha + p * k.T  # accumulation pondérée
        z_i = z_i * alpha + p
        m_i = m_i * 0 + m_ij  # mise à jour broadcast

    output = acc / z_i

    # Stocker
    out_ptrs = output_ptr + pid_h * stride_oh + offs_d[:, None] * stride_od
    tl.store(out_ptrs, output)

Modèle 3 : GEMM FP8 avec Split-K (Optimisé pour l'Inférence)

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 4}, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 8}, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'SPLIT_K': 16}, num_warps=8),
    ],
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': lambda configs, named_args: [
            c for c in configs if c.kwargs['BLOCK_SIZE_M'] * c.kwargs['SPLIT_K'] <= 128
        ],
    },
)
@triton.jit
def fp8_gemm_splitk_kernel(
    a_ptr, b_ptr, c_ptr, partial_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    k_block_id = pid // num_pid_m
    pid_m = pid % num_pid_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = k_block_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, K // SPLIT_K, BLOCK_SIZE_K):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K // SPLIT_K - k, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K // SPLIT_K - k, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Écrire la somme partielle
    partial_idx = k_block_id * M + pid_m * BLOCK_SIZE_M
    partial_ptrs = partial_ptr + partial_idx
    tl.store(partial_ptrs, tl.sum(acc, axis=1)[:, None])

Stratégie d'Autotuning

Quand l'Autotuning Est Essentiel

Scénario Impact Autotune
Formes d'entrée variables (VLLM, serving) Critique — cache par forme
Formes de production fixes Exécuter une fois, fixer la config
Opérations liées à la mémoire (softmax, normes) Moins critique — le pattern d'accès mémoire domine
Opérations liées au calcul (GEMM) Critique — différence de perf 2–5x entre configs

Heuristiques de Design de Config

# Règle générale : le produit des dimensions de tuile doit tenir dans les registres
# BLOCK_SIZE_M * BLOCK_SIZE_N * element_size <= register_budget

# Pour NVIDIA A100/H100 (fp16 matmul) :
configs = [
    # Équilibré : bon partout
    triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=3),
    # Débit : grandes tuiles pour compute-bound
    triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_warps=8, num_stages=4),
    # Latence : petites tuiles pour memory-bound / petit M
    triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=2),
    # AMD MI300X : utiliser moins de warps, peut nécessiter waves_per_eu
    triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4, num_stages=0),
]

Workflow de Profiling

Étape par Étape : Profiler et Optimiser

# 1. Warmup : exécuter une fois pour déclencher la compilation JIT
output_triton = my_kernel(x, y)

# 2. Benchmark avec triton.testing
import triton.testing
ms, min_ms, max_ms = triton.testing.do_bench(
    lambda: my_kernel(x, y),
    quantiles=[0.5, 0.2, 0.8],
    warmup=100,  # itérations
    rep=100,     # itérations de mesure
)

# 3. Comparer avec la référence
ms_torch, _, _ = triton.testing.do_bench(lambda: torch.matmul(a, b))

# 4. Calculer TFLOPS
tflops = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
print(f"Triton: {tflops(ms):.2f} TFLOPS | Torch: {tflops(ms_torch):.2f} TFLOPS")

Intégration CUDA Graph (Production)

# Après que l'autotuning a sélectionné la meilleure config, capturer un CUDA graph :
import torch

def capture_gemm_graph(a, b):
    # Warmup avec la forme de production
    _ = triton_matmul(a, b)
    torch.cuda.synchronize()

    # Capturer le graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        c = triton_matmul(a, b)

    return graph, c

# Rejouer pour l'inférence — élimine 1-2ms de surcharge JIT par lancement
graph.replay()

Cheatsheet de Débogage

Problème Symptôme Fix
Sortie incorrecte Off-by-one dans les offsets Vérifier la logique mask, utiliser % modulo pour les limites
Sortie NaN Instabilité numérique Soustraire le max avant exp ; vérifier division par zéro
Noyau lent (memory-bound) Util bande faible Augmenter les tailles de tuile, vérifier _b128 dans l'ISA
Noyau lent (compute-bound) TFLOPS faible Vérifier l'utilisation tensor core dans le PTX ; essayer le tuning num_stages
Erreur de compilation Problème fonction @triton.jit Vérifier les constructions Python non supportées (pas de dictionnaires, pas d'indexation dynamique)
Erreurs compute-sanitizer Accès hors limites Vérifier la couverture du mask pour les tuiles partielles
Surcharge de lancement élevée Latence CPU-side Utiliser CUDA Graphs pour l'inférence en production

Portes de Qualité

Porte Commande/Vérification Attendu
Exactitude torch.max(torch.abs(ref - triton_out)) < 0,01 (fp16) ou < 0,5 (fp8)
Autotuning Variable env TRITON_PRINT_AUTOTUNING=1 Meilleure config imprimée
Utilisation tensor core Vérifier PTX pour wgmma/mma Présent pour les noyaux matmul
Coalescence mémoire Vérifier ISA pour global_load_dwordx4 Présent dans la boucle critique
Utilisation LDS grep "triton_gpu.shared" depuis dump MLIR < 64 KB
Occupancy Calculer depuis les comptages VGPR/LDS > 50 % pour compute-bound
Speedup triton.testing.do_bench > 1,5x par rapport à PyTorch naïf

Références Croisées

  • Guideline triton-kernel-build-design — patterns de design complets, hiérarchie mémoire et référence d'optimisation
  • Tutoriels officiels : https://triton-lang.org/main/getting-started/tutorials/
  • dataset-curation-manifest — quand construire des noyaux de chargement de données
  • embedding-analysis — pour comprendre les patterns de calcul d'embeddings

Références

Ressource Lien
API Python Triton https://triton-lang.org/main/python-api/
Triton Autotune https://triton-lang.org/main/python-api/generated/triton.autotune.html
Tutoriels Triton https://triton-lang.org/main/getting-started/tutorials/
PyTorch User-Defined Triton https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html
Exercices Triton https://lweitkamp.github.io/triton_exercises/print.html
TK-GEMM (Llama3 FP8) https://pytorch.org/blog/accelerating-llama3

Skills similaires