kernel-cute-writing

Par nvidia · skills

Écrire et implémenter des kernels GPU avec le DSL CuTe de NVIDIA (API Python CUTLASS 4.x) — PAS pour Triton, CUDA C++, ou des explications conceptuelles. Se déclenche uniquement quand l'utilisateur souhaite écrire ou implémenter un kernel, pas quand il pose des questions sur les concepts ou layouts du DSL CuTe. Le DSL CuTe utilise les décorateurs `cute.jit`/`cute.kernel` et les imports `cutlass.cute`. Couvre les kernels élément par élément, les patterns GEMM, les réductions, la hiérarchie mémoire (global/shared/register/TMA), les opérations MMA sur tensor cores, le software pipelining et l'intégration avec les frameworks.

npx skills add https://github.com/nvidia/skills --skill kernel-cute-writing

CuTe DSL

CuTe DSL est un langage spécialisé basé sur Python pour le développement de kernels GPU, faisant partie de CUTLASS 4.x. Il fournit des abstractions Python sur les templates C++ de CUTLASS avec compilation JIT vers des kernels CUDA optimisés via MLIR et ptxas.

Quand l'utiliser

Déclencheurs :

  • Écrire des kernels CUDA en Python (element-wise, GEMM, ops personnalisées)
  • Optimiser les patterns d'accès mémoire GPU (chargements vectorisés, TMA, mémoire partagée)
  • Construire des kernels de tensor core (MMA) pour Ampere/Hopper/Blackwell
  • Intégrer des kernels GPU personnalisés avec PyTorch ou JAX
  • Prototyper des kernels haute performance sans metaprogrammation C++

Symptômes (mauvais outil sinon) :

  • Besoin de coordination de mémoire partagée ou MMA tensor core → utiliser CuTe DSL (pas Triton pour les patterns complexes)
  • Besoin d'ops element-wise simples sans mémoire partagée → CuTe DSL ou Triton conviennent tous les deux
  • Besoin d'appeler des kernels C++ CUTLASS existants → utiliser les APIs CUTLASS C++ à la place
  • Besoin de réductions, scans ou ops collectives non-GEMM → envisager CUB/Thrust

Mots-clés : cute, cutlass, cute.jit, cute.kernel, from_dlpack, zipped_divide, TiledMMA, TiledCopy, TMA, WGMMA, tcgen05, pipeline, mbarrier

Requirements

Requirement Detail
Platform Linux x86_64 uniquement
Python 3.10–3.13
GPU NVIDIA Ampere+ (SM80, SM90, SM100)
CUDA Driver ≥ 575.51.03 (compat Toolkit 12.9)
Install pip install nvidia-cutlass-dsl
Optional apache-tvm-ffi, torch-c-dlpack-ext

Workflows

Workflow 0: Démarrer à partir d'exemples (Recommandé)

Pour tout kernel non trivial (GEMM, attention, pipeliné, ops fusionnées), commencez par trouver l'exemple existant le plus similaire à utiliser comme point de départ — étudiez sa structure, puis reworkez-la pour votre cas d'usage. Ne copiez pas les exemples textuellement ; ils ciblent des dtypes, architectures et formes de problème spécifiques qui diffèrent probablement.

  1. Choisissez l'exemple le plus proche dans l'index ci-dessous. Préférez les exemples correspondant à l'architecture GPU cible (vérifiez avec torch.cuda.get_device_capability()) quand l'opération est similaire.

    Récupérez via web_fetch avec URL de base https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL

    Operation Arch Chemin d'exemple (ajouter à l'URL de base)
    Element-wise add SM80 ampere/elementwise_add.py
    Element-wise + autotune SM80 ampere/elementwise_add_autotune.py
    Element-wise apply SM80 ampere/elementwise_apply.py
    SGEMM (scalar) SM80 ampere/sgemm.py
    Tensor-core GEMM SM80 ampere/tensorop_gemm.py
    Flash Attention v2 SM80 ampere/flash_attention_v2.py
    HSTU Attention SM80 ampere/hstu_attention.py
    Shared memory allocator SM80 ampere/smem_allocator.py
    CTA norm (LayerNorm) SM90 hopper/cta_norm.py
    Dense GEMM SM90 hopper/dense_gemm.py
    Dense GEMM persistent SM90 hopper/dense_gemm_persistent.py
    Flash MHA SM90 hopper/fmha.py
    Dense GEMM SM100 blackwell/dense_gemm.py
    Dense GEMM persistent SM100 blackwell/dense_gemm_persistent.py
    Dense GEMM + alpha/beta SM100 blackwell/dense_gemm_alpha_beta_persistent.py
    RMSNorm SM100 blackwell/rmsnorm.py
    Reduce SM100 blackwell/reduce.py
    Flash MHA SM100 blackwell/fmha.py
    Grouped GEMM SM100 blackwell/grouped_gemm.py
    Mamba2 SSD SM100 blackwell/mamba2_ssd/
    GEMM tutorial (notebook) SM100 notebooks/tour_to_sol_gemm.ipynb

    Exemple : Pour récupérer le Hopper dense GEMM :

    web_fetch https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL/hopper/dense_gemm.py
  2. Lisez d'abord les matériaux de référence — avant de plonger dans le code d'exemple, lisez la documentation references/ pertinente pour comprendre les patterns et les APIs :

    • Pour GEMM : references/patterns-gemm.md (tiling 3 niveaux, fusion épilogue, cute.compile avec mark_layout_dynamic, layouts de mémoire partagée)
    • Pour les réductions : references/patterns-reduction.md (réductions warp, pattern cache cute.compile)
    • Pour element-wise : references/patterns-elementwise.md (variations A–E)
    • Toujours : references/api-arch.md (APIs disponibles, mises en garde spécifiques à l'arch)

    Cela vous donne la base conceptuelle pour reworker l'exemple intelligemment plutôt que de tenter de copier-coller des pipelines complexes.

  3. Récupérez et étudiez le code source de l'exemple — lisez pour la structure, pas pour copier :

    • Identifiez : décorateurs, stratégie de tiling, utilisation de la mémoire partagée, flow de mainloop
    • Notez quels dtype/arch il cible (beaucoup d'exemples sont spécifiques fp16/bf16)
    • Vérifiez s'il utilise des APIs liées à une arch spécifique (TMA → SM90+, tcgen05 → SM100)
  4. Reworkez pour la charge de travail de l'utilisateur (ne pas copier-coller) :

    • Changez les formes, data types, tile sizes pour correspondre aux requirements
    • Remplacez la logique de calcul (épilogue, fusion activation) au besoin
    • Si le dtype diffère (ex. l'exemple est fp16, besoin fp32), attendez-vous à des changements de vectorisation et layout — les patterns de boucle scalaire dans references/ peuvent être un meilleur point de départ qu'adapter un exemple vectorisé
    • Le wrapper runtime doit être lightweight : kernel_fn() ne devrait que appeler from_dlpack() + le kernel compilé. Ne jamais allouer des tenseurs intermédiaires, copier des données, ou recompiler par appel — cela appartient à la setup une seule fois
    • Appliquez les optimisations des docs de référence de cette skill

    ⛔ Blackwell/Hopper GEMM + tenseurs supplémentaires — ARRÊT : Si le GPU cible est SM90+ (Hopper/Blackwell) et le GEMM nécessite des tenseurs supplémentaires au-delà de A, B, C dans l'épilogue (ex. vecteur bias, entrées activation), ne pas tenter. Ces exemples utilisent des descripteurs TMA pour tout mouvement de données — ajouter des tenseurs nécessite de modifier la setup des descripteurs TMA, ce qui est prohibitivement complexe. À la place, dites à l'utilisateur cette limitation et suggérez une approche deux kernels : exécutez le kernel GEMM tel quel, puis appliquez bias + activation dans un kernel element-wise séparé (Workflow 1). Plain GEMM (juste A×B→C avec alpha/beta scalaire) sur Hopper/Blackwell est OK.

  5. Validez et benchmarkez en utilisant les scripts d'accompagnement :

    python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3
    python scripts/benchmark_kernel.py kernel.py

    Le fichier kernel doit exporter kernel_fn, reference_fn, et get_inputs().

Quand sauter les exemples : Les opérations purely element-wise (Workflow 1) ont des patterns complets dans references/patterns-elementwise.md — pas besoin de récupérer des exemples externes.

Kernels de réduction (softmax, layernorm, RMSNorm) : Utilisez references/patterns-reduction.md qui fournit des patterns complets et éprouvés pour les réductions float32 utilisant boucles scalaires + butterfly shuffle + mémoire partagée.

Workflow 1: Kernel Element-wise

Pour les opérations unaires/binaires/in-place qui mappent des entrées vers des sorties 1:1.

  1. Déterminez la structure du kernel : nombre d'entrées/sorties, rang du tenseur, arch cible

  2. Sélectionnez le pattern depuis references/patterns-elementwise.md (Variations A–E)

  3. Écrivez le kernel en appliquant tous quatre principes invariants :

    • P1 : from_dlpack(tensor, assumed_align=16) pour les chargements vectorisés
    • P2 : Dérivez vec_size de element_type.width
    • P3 : cute.zipped_divide(mA, tiler) pour l'accès coalescent
    • P4 : cutlass.dynamic_expr(thread_idx < total) pour les bornes
  4. Règles critiques : Pas de early return, pas de a * 2 (utilisez a + a), pas de cute.math.sigmoid

  5. Pré-compilez avec cute.compile() : Toujours pré-compiler le kernel une seule fois en utilisant cute.compile() pour que kernel_fn appelle l'objet compilé, pas @cute.jit directement. Sans pré-compilation, chaque appel recompile (~20-50ms overhead). Utilisez .mark_layout_dynamic() pour qu'un seul kernel compilé gère des formes d'entrée arbitraires sans recompilation :

    # Compilez une seule fois avec des layouts dynamiques — fonctionne pour toute forme
    fake_x = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"),
                          assumed_align=16).mark_layout_dynamic()
    fake_out = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"),
                            assumed_align=16).mark_layout_dynamic()
    compiled_kernel = cute.compile(host_fn, fake_x, fake_out)
    
    def kernel_fn(x):
        out = torch.empty_like(x)
        compiled_kernel(from_dlpack(x, assumed_align=16).mark_layout_dynamic(),
                        from_dlpack(out, assumed_align=16).mark_layout_dynamic())
        return out
  6. Vérifiez la correctness en utilisant le script d'accompagnement :

    python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3

    Le fichier kernel doit exporter kernel_fn, reference_fn, et get_inputs().

  7. Benchmarkez en utilisant le script d'accompagnement :

    python scripts/benchmark_kernel.py kernel.py

Workflow 2: Kernel GEMM

Pour la multiplication matricielle avec tiling, mémoire partagée, et tensor cores.

  1. Définissez le problème : formes (M, N, K), data types, architecture cible
  2. Choisissez le tiling : CTA tile (bM, bN, bK), étapes pipeline, cluster shape
  3. Partitionnement trois niveaux (voir references/patterns-gemm.md) :
    • Level 1 : Tiling CTA avec local_tile()
    • Level 2 : Partitionnement copy (global → partagée) avec TiledCopy
    • Level 3 : Partitionnement compute (partagée → registre) avec TiledMMA
  4. Mémoire partagée : Utilisez des layouts swizzled (make_smem_layout_atom) pour éviter les conflits de bank
  5. Mainloop : Boucle K-tile avec copy → sync → MMA → sync
  6. Pipeline : Utilisez PipelineTmaAsync (Hopper) ou PipelineTmaUmma (Blackwell). ⚠️ Les pipelines basés sur TMA gèrent le mouvement de données via des descripteurs TMA — ajouter des tenseurs supplémentaires (bias, entrées activation) à l'épilogue nécessite de modifier la setup des descripteurs, ce qui est prohibitivement complexe. Voir la condition d'arrêt dans Workflow 0 étape 4.
  7. Épilogue : Store prédicaté avec scaling alpha/beta
  8. Pré-compilez avec cute.compile() : Toujours pré-compiler le kernel GEMM pour que kernel_fn appelle l'objet compilé, pas @cute.jit directement. Sans pré-compilation, chaque appel recompile (~20-50ms overhead).
  9. Autotune : Cherchez sur les tile sizes, cluster shapes, profondeurs pipeline

Workflow 3: Intégration Framework

Pour envelopper les kernels CuTe DSL comme opérateurs personnalisés PyTorch/JAX.

  1. Écrivez un kernel en utilisant Workflow 1 ou 2
  2. Créez un wrapper : Acceptez torch.Tensor, convertissez via from_dlpack, appelez la fonction host
  3. Pour la production : Compilez avec TVM FFI pour un passage de tenseur zéro-overhead :
    compiled = cute.compile(host_fn, *fake_tensors, options="--enable-tvm-ffi")
    compiled(torch_a, torch_b)  # Direct torch.Tensor, pas de from_dlpack
  4. Pour le déploiement : Utilisez la compilation AOT → export vers .o → charger à l'exécution

Workflow 4: Debugging & Profiling

  1. Définissez l'environnement : CUTE_DSL_PRINT_IR=1, CUTE_DSL_KEEP_PTX=1
  2. Utilisez cute.printf() pour les valeurs runtime (pas Python print)
  3. Inspectez le code généré : compiled.__ptx__, compiled.__mlir__
  4. Profilez : Activez CUTE_DSL_LINEINFO=1, utilisez Nsight Compute/Systems
  5. Debuggez la mémoire : Lancez avec compute-sanitizer python script.py

Output Formats

Un projet de kernel CuTe DSL typique :

kernel_dir/
  kernel.py          # @cute.kernel + @cute.jit functions
  test_kernel.py     # Test de correctness vs référence PyTorch
  bench_kernel.py    # Benchmark avec setup cute.compile()

Indicateurs de succès :

  • Le test de correctness passe (torch.testing.assert_close)
  • Nsight montre des chargements vectorisés (LDG.128/LDG.256), pas des chargements scalaires
  • Pour GEMM : utilisation du tensor core > 80% dans Nsight Compute

Companion Script Contract

Les fichiers kernel utilisés avec scripts/verify_kernel.py et scripts/benchmark_kernel.py doivent exporter trois noms :

  • kernel_fn(*inputs) — le wrapper kernel CuTe DSL (appelle cute.compile + exécute le kernel)
  • reference_fn(*inputs) — implémentation de référence PyTorch (même signature)
  • get_inputs() — retourne une liste de tenseurs CUDA pour les tests
# Exemple kernel.py contract
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

def kernel_fn(x):
    out = torch.empty_like(x)
    # ... appel du kernel cute compilé ...
    return out

def reference_fn(x):
    return torch.nn.functional.gelu(x)

def get_inputs():
    return [torch.randn(1024, 512, dtype=torch.float16, device="cuda")]

Examples

Example: 2D Unary Element-wise (ReLU)

import torch, cutlass, cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.kernel
def relu_kernel(gA: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()
    idx = bidx * bdim + tidx
    m, n = gA.shape[1]
    total = m * n
    if cutlass.dynamic_expr(idx < total):
        a = gA[(None, (idx // n, idx % n))].load()
        gC[(None, (idx // n, idx % n))] = cute.where(a > 0, a, 0)

@cute.jit
def relu_host(mA: cute.Tensor, mC: cute.Tensor):
    vec = 16 // (mA.element_type.width // 8)
    gA = cute.zipped_divide(mA, (1, vec))
    gC = cute.zipped_divide(mC, (1, vec))
    T = 256
    N = cute.size(gA.shape[1])
    relu_kernel(gA, gC).launch(grid=((N+T-1)//T,1,1), block=(T,1,1))

x = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
relu_host(from_dlpack(x, assumed_align=16), from_dlpack(out, assumed_align=16))

Error Handling

Error Cause Fix
MLIR function requires a Context Appelé @kernel depuis Python Lancez via la fonction host @cute.jit
DSLAstPreprocessorError on return Early return dans @kernel Utilisez if cutlass.dynamic_expr(cond):
Type mismatch on store a * 2 promeut FP16→FP32 Utilisez a + a ou .to(cutlass.Float16)
could not get source code Kernel dans le contexte exec() Écrivez dans un fichier et importez
Scalar loads dans Nsight Hint d'alignement manquant Ajoutez assumed_align=16 à from_dlpack
Missing required argument Pas tous les params @jit passés Passez TOUS les paramètres déclarés
AttributeError: sigmoid Pas de cute.math.sigmoid Utilisez 1.0/(1.0+cute.math.exp(-x))

Voir references/troubleshooting.md pour la table d'erreurs complète et les limitations.

Règle de debugging : Ne supprimez jamais kernel.py pendant le debugging. Utilisez backup_file pour sauvegarder un checkpoint, puis edit_file pour itérer. Si vous êtes bloqué, revert_file pour restaurer la sauvegarde. Un kernel partiellement fonctionnant est toujours mieux que pas de kernel.

Finding More Information

Tier 1: Ce fichier (SKILL.md)

Les workflows ci-dessus couvrent les kernels element-wise, GEMM, intégration framework, et debugging. Cherchez d'abord dans ce fichier pour les questions de procédure.

Tier 2: Répertoire references/

Grep pour les mots-clés à travers references/. Les en-têtes sont grep-friendly.

File Content
concepts-architecture.md Abstractions de base, terminologie, pipeline de compilation
concepts-layouts.md Layout algebra : composition, complément, divide, swizzle
concepts-tensors.md Types de tenseurs, partitionnement, tiling, prédication
concepts-mma.md Atomes MMA, TiledMMA, ops de tensor core par architecture
patterns-getting-started.md Installation, décorateurs, première walkthrough de kernel
patterns-elementwise.md Principes invariants, variations de pattern, impl de référence
patterns-gemm.md Tiling 3 niveaux, mémoire partagée, pipelining, autotuning
patterns-memory.md from_dlpack, TMA, cp.async, TMEM, copy atoms
patterns-compilation.md Control flow, JIT caching, TVM FFI, compilation AOT
patterns-pipeline.md Producteur-consommateur, classes pipeline, barrières, warp specialization
api-core.md cute module : layouts, tenseurs, math, copy, gemm, printing
api-arch.md cute.arch : thread indexing, sync, atomics, memory ops
api-nvgpu.md cute.nvgpu : warp/warpgroup/cpasync/tcgen05 MMA et copy
api-runtime-utils.md Runtime : from_dlpack, tenseurs fake, utils, schedulers
troubleshooting.md Debugging, env vars, erreurs courantes, limitations, FAQ

Comment chercher : Grep pour votre mot-clé à travers references/. Lisez uniquement le fichier et la section pointée par Grep.

Tier 3: Documentation Originale

Si les Tiers 1–2 ne répondent pas, consultez la source :

Skills similaires