CuTe DSL
CuTe DSL est un langage spécialisé basé sur Python pour le développement de kernels GPU, faisant partie de CUTLASS 4.x. Il fournit des abstractions Python sur les templates C++ de CUTLASS avec compilation JIT vers des kernels CUDA optimisés via MLIR et ptxas.
Quand l'utiliser
Déclencheurs :
- Écrire des kernels CUDA en Python (element-wise, GEMM, ops personnalisées)
- Optimiser les patterns d'accès mémoire GPU (chargements vectorisés, TMA, mémoire partagée)
- Construire des kernels de tensor core (MMA) pour Ampere/Hopper/Blackwell
- Intégrer des kernels GPU personnalisés avec PyTorch ou JAX
- Prototyper des kernels haute performance sans metaprogrammation C++
Symptômes (mauvais outil sinon) :
- Besoin de coordination de mémoire partagée ou MMA tensor core → utiliser CuTe DSL (pas Triton pour les patterns complexes)
- Besoin d'ops element-wise simples sans mémoire partagée → CuTe DSL ou Triton conviennent tous les deux
- Besoin d'appeler des kernels C++ CUTLASS existants → utiliser les APIs CUTLASS C++ à la place
- Besoin de réductions, scans ou ops collectives non-GEMM → envisager CUB/Thrust
Mots-clés : cute, cutlass, cute.jit, cute.kernel, from_dlpack, zipped_divide, TiledMMA, TiledCopy, TMA, WGMMA, tcgen05, pipeline, mbarrier
Requirements
| Requirement | Detail |
|---|---|
| Platform | Linux x86_64 uniquement |
| Python | 3.10–3.13 |
| GPU | NVIDIA Ampere+ (SM80, SM90, SM100) |
| CUDA Driver | ≥ 575.51.03 (compat Toolkit 12.9) |
| Install | pip install nvidia-cutlass-dsl |
| Optional | apache-tvm-ffi, torch-c-dlpack-ext |
Workflows
Workflow 0: Démarrer à partir d'exemples (Recommandé)
Pour tout kernel non trivial (GEMM, attention, pipeliné, ops fusionnées), commencez par trouver l'exemple existant le plus similaire à utiliser comme point de départ — étudiez sa structure, puis reworkez-la pour votre cas d'usage. Ne copiez pas les exemples textuellement ; ils ciblent des dtypes, architectures et formes de problème spécifiques qui diffèrent probablement.
-
Choisissez l'exemple le plus proche dans l'index ci-dessous. Préférez les exemples correspondant à l'architecture GPU cible (vérifiez avec
torch.cuda.get_device_capability()) quand l'opération est similaire.Récupérez via
web_fetchavec URL de basehttps://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSLOperation Arch Chemin d'exemple (ajouter à l'URL de base) Element-wise add SM80 ampere/elementwise_add.pyElement-wise + autotune SM80 ampere/elementwise_add_autotune.pyElement-wise apply SM80 ampere/elementwise_apply.pySGEMM (scalar) SM80 ampere/sgemm.pyTensor-core GEMM SM80 ampere/tensorop_gemm.pyFlash Attention v2 SM80 ampere/flash_attention_v2.pyHSTU Attention SM80 ampere/hstu_attention.pyShared memory allocator SM80 ampere/smem_allocator.pyCTA norm (LayerNorm) SM90 hopper/cta_norm.pyDense GEMM SM90 hopper/dense_gemm.pyDense GEMM persistent SM90 hopper/dense_gemm_persistent.pyFlash MHA SM90 hopper/fmha.pyDense GEMM SM100 blackwell/dense_gemm.pyDense GEMM persistent SM100 blackwell/dense_gemm_persistent.pyDense GEMM + alpha/beta SM100 blackwell/dense_gemm_alpha_beta_persistent.pyRMSNorm SM100 blackwell/rmsnorm.pyReduce SM100 blackwell/reduce.pyFlash MHA SM100 blackwell/fmha.pyGrouped GEMM SM100 blackwell/grouped_gemm.pyMamba2 SSD SM100 blackwell/mamba2_ssd/GEMM tutorial (notebook) SM100 notebooks/tour_to_sol_gemm.ipynbExemple : Pour récupérer le Hopper dense GEMM :
web_fetch https://raw.githubusercontent.com/NVIDIA/cutlass/main/examples/python/CuTeDSL/hopper/dense_gemm.py -
Lisez d'abord les matériaux de référence — avant de plonger dans le code d'exemple, lisez la documentation
references/pertinente pour comprendre les patterns et les APIs :- Pour GEMM :
references/patterns-gemm.md(tiling 3 niveaux, fusion épilogue,cute.compileavecmark_layout_dynamic, layouts de mémoire partagée) - Pour les réductions :
references/patterns-reduction.md(réductions warp, pattern cachecute.compile) - Pour element-wise :
references/patterns-elementwise.md(variations A–E) - Toujours :
references/api-arch.md(APIs disponibles, mises en garde spécifiques à l'arch)
Cela vous donne la base conceptuelle pour reworker l'exemple intelligemment plutôt que de tenter de copier-coller des pipelines complexes.
- Pour GEMM :
-
Récupérez et étudiez le code source de l'exemple — lisez pour la structure, pas pour copier :
- Identifiez : décorateurs, stratégie de tiling, utilisation de la mémoire partagée, flow de mainloop
- Notez quels dtype/arch il cible (beaucoup d'exemples sont spécifiques fp16/bf16)
- Vérifiez s'il utilise des APIs liées à une arch spécifique (TMA → SM90+, tcgen05 → SM100)
-
Reworkez pour la charge de travail de l'utilisateur (ne pas copier-coller) :
- Changez les formes, data types, tile sizes pour correspondre aux requirements
- Remplacez la logique de calcul (épilogue, fusion activation) au besoin
- Si le dtype diffère (ex. l'exemple est fp16, besoin fp32), attendez-vous à des changements de vectorisation et layout — les patterns de boucle scalaire dans
references/peuvent être un meilleur point de départ qu'adapter un exemple vectorisé - Le wrapper runtime doit être lightweight :
kernel_fn()ne devrait que appelerfrom_dlpack()+ le kernel compilé. Ne jamais allouer des tenseurs intermédiaires, copier des données, ou recompiler par appel — cela appartient à la setup une seule fois - Appliquez les optimisations des docs de référence de cette skill
⛔ Blackwell/Hopper GEMM + tenseurs supplémentaires — ARRÊT : Si le GPU cible est SM90+ (Hopper/Blackwell) et le GEMM nécessite des tenseurs supplémentaires au-delà de A, B, C dans l'épilogue (ex. vecteur bias, entrées activation), ne pas tenter. Ces exemples utilisent des descripteurs TMA pour tout mouvement de données — ajouter des tenseurs nécessite de modifier la setup des descripteurs TMA, ce qui est prohibitivement complexe. À la place, dites à l'utilisateur cette limitation et suggérez une approche deux kernels : exécutez le kernel GEMM tel quel, puis appliquez bias + activation dans un kernel element-wise séparé (Workflow 1). Plain GEMM (juste A×B→C avec alpha/beta scalaire) sur Hopper/Blackwell est OK.
-
Validez et benchmarkez en utilisant les scripts d'accompagnement :
python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3 python scripts/benchmark_kernel.py kernel.pyLe fichier kernel doit exporter
kernel_fn,reference_fn, etget_inputs().
Quand sauter les exemples : Les opérations purely element-wise (Workflow 1) ont des patterns complets dans references/patterns-elementwise.md — pas besoin de récupérer des exemples externes.
Kernels de réduction (softmax, layernorm, RMSNorm) : Utilisez references/patterns-reduction.md qui fournit des patterns complets et éprouvés pour les réductions float32 utilisant boucles scalaires + butterfly shuffle + mémoire partagée.
Workflow 1: Kernel Element-wise
Pour les opérations unaires/binaires/in-place qui mappent des entrées vers des sorties 1:1.
-
Déterminez la structure du kernel : nombre d'entrées/sorties, rang du tenseur, arch cible
-
Sélectionnez le pattern depuis
references/patterns-elementwise.md(Variations A–E) -
Écrivez le kernel en appliquant tous quatre principes invariants :
- P1 :
from_dlpack(tensor, assumed_align=16)pour les chargements vectorisés - P2 : Dérivez
vec_sizedeelement_type.width - P3 :
cute.zipped_divide(mA, tiler)pour l'accès coalescent - P4 :
cutlass.dynamic_expr(thread_idx < total)pour les bornes
- P1 :
-
Règles critiques : Pas de early return, pas de
a * 2(utiliseza + a), pas decute.math.sigmoid -
Pré-compilez avec
cute.compile(): Toujours pré-compiler le kernel une seule fois en utilisantcute.compile()pour quekernel_fnappelle l'objet compilé, pas@cute.jitdirectement. Sans pré-compilation, chaque appel recompile (~20-50ms overhead). Utilisez.mark_layout_dynamic()pour qu'un seul kernel compilé gère des formes d'entrée arbitraires sans recompilation :# Compilez une seule fois avec des layouts dynamiques — fonctionne pour toute forme fake_x = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"), assumed_align=16).mark_layout_dynamic() fake_out = from_dlpack(torch.empty(1, 1, dtype=torch.float16, device="cuda"), assumed_align=16).mark_layout_dynamic() compiled_kernel = cute.compile(host_fn, fake_x, fake_out) def kernel_fn(x): out = torch.empty_like(x) compiled_kernel(from_dlpack(x, assumed_align=16).mark_layout_dynamic(), from_dlpack(out, assumed_align=16).mark_layout_dynamic()) return out -
Vérifiez la correctness en utilisant le script d'accompagnement :
python scripts/verify_kernel.py kernel.py --rtol 1e-3 --atol 1e-3Le fichier kernel doit exporter
kernel_fn,reference_fn, etget_inputs(). -
Benchmarkez en utilisant le script d'accompagnement :
python scripts/benchmark_kernel.py kernel.py
Workflow 2: Kernel GEMM
Pour la multiplication matricielle avec tiling, mémoire partagée, et tensor cores.
- Définissez le problème : formes (M, N, K), data types, architecture cible
- Choisissez le tiling : CTA tile (bM, bN, bK), étapes pipeline, cluster shape
- Partitionnement trois niveaux (voir
references/patterns-gemm.md) :- Level 1 : Tiling CTA avec
local_tile() - Level 2 : Partitionnement copy (global → partagée) avec
TiledCopy - Level 3 : Partitionnement compute (partagée → registre) avec
TiledMMA
- Level 1 : Tiling CTA avec
- Mémoire partagée : Utilisez des layouts swizzled (
make_smem_layout_atom) pour éviter les conflits de bank - Mainloop : Boucle K-tile avec copy → sync → MMA → sync
- Pipeline : Utilisez
PipelineTmaAsync(Hopper) ouPipelineTmaUmma(Blackwell). ⚠️ Les pipelines basés sur TMA gèrent le mouvement de données via des descripteurs TMA — ajouter des tenseurs supplémentaires (bias, entrées activation) à l'épilogue nécessite de modifier la setup des descripteurs, ce qui est prohibitivement complexe. Voir la condition d'arrêt dans Workflow 0 étape 4. - Épilogue : Store prédicaté avec scaling alpha/beta
- Pré-compilez avec
cute.compile(): Toujours pré-compiler le kernel GEMM pour quekernel_fnappelle l'objet compilé, pas@cute.jitdirectement. Sans pré-compilation, chaque appel recompile (~20-50ms overhead). - Autotune : Cherchez sur les tile sizes, cluster shapes, profondeurs pipeline
Workflow 3: Intégration Framework
Pour envelopper les kernels CuTe DSL comme opérateurs personnalisés PyTorch/JAX.
- Écrivez un kernel en utilisant Workflow 1 ou 2
- Créez un wrapper : Acceptez
torch.Tensor, convertissez viafrom_dlpack, appelez la fonction host - Pour la production : Compilez avec TVM FFI pour un passage de tenseur zéro-overhead :
compiled = cute.compile(host_fn, *fake_tensors, options="--enable-tvm-ffi") compiled(torch_a, torch_b) # Direct torch.Tensor, pas de from_dlpack - Pour le déploiement : Utilisez la compilation AOT → export vers
.o→ charger à l'exécution
Workflow 4: Debugging & Profiling
- Définissez l'environnement :
CUTE_DSL_PRINT_IR=1,CUTE_DSL_KEEP_PTX=1 - Utilisez
cute.printf()pour les valeurs runtime (pas Pythonprint) - Inspectez le code généré :
compiled.__ptx__,compiled.__mlir__ - Profilez : Activez
CUTE_DSL_LINEINFO=1, utilisez Nsight Compute/Systems - Debuggez la mémoire : Lancez avec
compute-sanitizer python script.py
Output Formats
Un projet de kernel CuTe DSL typique :
kernel_dir/
kernel.py # @cute.kernel + @cute.jit functions
test_kernel.py # Test de correctness vs référence PyTorch
bench_kernel.py # Benchmark avec setup cute.compile()
Indicateurs de succès :
- Le test de correctness passe (
torch.testing.assert_close) - Nsight montre des chargements vectorisés (LDG.128/LDG.256), pas des chargements scalaires
- Pour GEMM : utilisation du tensor core > 80% dans Nsight Compute
Companion Script Contract
Les fichiers kernel utilisés avec scripts/verify_kernel.py et scripts/benchmark_kernel.py doivent exporter trois noms :
kernel_fn(*inputs)— le wrapper kernel CuTe DSL (appellecute.compile+ exécute le kernel)reference_fn(*inputs)— implémentation de référence PyTorch (même signature)get_inputs()— retourne une liste de tenseurs CUDA pour les tests
# Exemple kernel.py contract
import torch
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
def kernel_fn(x):
out = torch.empty_like(x)
# ... appel du kernel cute compilé ...
return out
def reference_fn(x):
return torch.nn.functional.gelu(x)
def get_inputs():
return [torch.randn(1024, 512, dtype=torch.float16, device="cuda")]
Examples
Example: 2D Unary Element-wise (ReLU)
import torch, cutlass, cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
@cute.kernel
def relu_kernel(gA: cute.Tensor, gC: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
bdim, _, _ = cute.arch.block_dim()
idx = bidx * bdim + tidx
m, n = gA.shape[1]
total = m * n
if cutlass.dynamic_expr(idx < total):
a = gA[(None, (idx // n, idx % n))].load()
gC[(None, (idx // n, idx % n))] = cute.where(a > 0, a, 0)
@cute.jit
def relu_host(mA: cute.Tensor, mC: cute.Tensor):
vec = 16 // (mA.element_type.width // 8)
gA = cute.zipped_divide(mA, (1, vec))
gC = cute.zipped_divide(mC, (1, vec))
T = 256
N = cute.size(gA.shape[1])
relu_kernel(gA, gC).launch(grid=((N+T-1)//T,1,1), block=(T,1,1))
x = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
relu_host(from_dlpack(x, assumed_align=16), from_dlpack(out, assumed_align=16))
Error Handling
| Error | Cause | Fix |
|---|---|---|
MLIR function requires a Context |
Appelé @kernel depuis Python | Lancez via la fonction host @cute.jit |
DSLAstPreprocessorError on return |
Early return dans @kernel | Utilisez if cutlass.dynamic_expr(cond): |
| Type mismatch on store | a * 2 promeut FP16→FP32 |
Utilisez a + a ou .to(cutlass.Float16) |
could not get source code |
Kernel dans le contexte exec() |
Écrivez dans un fichier et importez |
| Scalar loads dans Nsight | Hint d'alignement manquant | Ajoutez assumed_align=16 à from_dlpack |
Missing required argument |
Pas tous les params @jit passés | Passez TOUS les paramètres déclarés |
AttributeError: sigmoid |
Pas de cute.math.sigmoid |
Utilisez 1.0/(1.0+cute.math.exp(-x)) |
Voir references/troubleshooting.md pour la table d'erreurs complète et les limitations.
Règle de debugging : Ne supprimez jamais kernel.py pendant le debugging. Utilisez backup_file pour sauvegarder un checkpoint, puis edit_file pour itérer. Si vous êtes bloqué, revert_file pour restaurer la sauvegarde. Un kernel partiellement fonctionnant est toujours mieux que pas de kernel.
Finding More Information
Tier 1: Ce fichier (SKILL.md)
Les workflows ci-dessus couvrent les kernels element-wise, GEMM, intégration framework, et debugging. Cherchez d'abord dans ce fichier pour les questions de procédure.
Tier 2: Répertoire references/
Grep pour les mots-clés à travers references/. Les en-têtes sont grep-friendly.
| File | Content |
|---|---|
concepts-architecture.md |
Abstractions de base, terminologie, pipeline de compilation |
concepts-layouts.md |
Layout algebra : composition, complément, divide, swizzle |
concepts-tensors.md |
Types de tenseurs, partitionnement, tiling, prédication |
concepts-mma.md |
Atomes MMA, TiledMMA, ops de tensor core par architecture |
patterns-getting-started.md |
Installation, décorateurs, première walkthrough de kernel |
patterns-elementwise.md |
Principes invariants, variations de pattern, impl de référence |
patterns-gemm.md |
Tiling 3 niveaux, mémoire partagée, pipelining, autotuning |
patterns-memory.md |
from_dlpack, TMA, cp.async, TMEM, copy atoms |
patterns-compilation.md |
Control flow, JIT caching, TVM FFI, compilation AOT |
patterns-pipeline.md |
Producteur-consommateur, classes pipeline, barrières, warp specialization |
api-core.md |
cute module : layouts, tenseurs, math, copy, gemm, printing |
api-arch.md |
cute.arch : thread indexing, sync, atomics, memory ops |
api-nvgpu.md |
cute.nvgpu : warp/warpgroup/cpasync/tcgen05 MMA et copy |
api-runtime-utils.md |
Runtime : from_dlpack, tenseurs fake, utils, schedulers |
troubleshooting.md |
Debugging, env vars, erreurs courantes, limitations, FAQ |
Comment chercher : Grep pour votre mot-clé à travers references/. Lisez uniquement le fichier et la section pointée par Grep.
Tier 3: Documentation Originale
Si les Tiers 1–2 ne répondent pas, consultez la source :
- Web: https://docs.nvidia.com/cutlass/latest/
- GitHub: https://github.com/NVIDIA/cutlass
- Récupérez des pages de documentation spécifiques ou cherchez "CUTLASS CuTe DSL <topic>"
- Envisagez de distiller la réponse dans references/