adding-cutile-kernel

Par nvidia · skills

Ajouter un nouvel opérateur de kernel GPU cuTile à TileGym. Couvre l'enregistrement du dispatch dans ops.py, l'implémentation du backend cuTile, les exports dans __init__.py, la création de tests et le benchmark dans tests/benchmark. À utiliser lors de l'ajout, de la création ou de l'implémentation d'un nouvel opérateur/kernel cuTile dans TileGym, ou pour savoir comment enregistrer un nouvel op cuTile.

npx skills add https://github.com/nvidia/skills --skill adding-cutile-kernel

Ajouter un noyau cuTile à TileGym

Workflow end-to-end pour ajouter un nouvel opérateur (ex. my_op) avec backend cuTile.

Règles d'exécution

DOIT suivre ces règles strictement :

  1. Utiliser TodoWrite pour créer la checklist ci-dessous AVANT d'écrire du code
  2. Exécuter les étapes dans l'ordre — NE PAS sauter ou combiner les étapes
  3. Marquer chaque todo comme completed après avoir fini, in_progress au démarrage
  4. Si une étape ne s'applique pas (ex. pas d'impl cuTile), la marquer completed avec une note, NE PAS sauter silencieusement
  5. Chaque étape DOIT résulter en une écriture de fichier ou une décision de saut explicite — pas d'omissions silencieuses

Instructions

DOIT copier cette checklist dans TodoWrite au démarrage :

- [ ] Step 1: Register dispatch interface in ops.py
- [ ] Step 2: Implement cuTile backend
- [ ] Step 3: Register in __init__.py (cutile)
- [ ] Step 4: Add tests
- [ ] Step 5: Add benchmark to tests/benchmark
- [ ] Step 6: Verify (run pytest + lint)

Étape 1 : Enregistrer l'interface de dispatch

Fichier : src/tilegym/ops/ops.py

Ajouter une fonction @dispatch — c'est le seul point d'entrée pour tous les backends.

@dispatch(
    "my_op",
)
def my_op(
    input: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    **kwargs: Any,
):
    """
    Description of my_op.

    Args:
        input: Input tensor
        out: Optional preallocated output tensor
        **kwargs: Additional arguments for backend-specific configurations

    Returns:
        torch.Tensor
    """
    raise NotImplementedError(f"my_op is not implemented for {get_current_backend()}")

Règles clés :

  • Le corps de la fonction lève uniquement NotImplementedError
  • Inclure **kwargs pour les paramètres spécifiques au backend

Référence : Voir les ops existantes dans src/tilegym/ops/ops.py (ex. silu_and_mul, softmax)

Étape 2 : Implémenter le backend cuTile

Fichier : src/tilegym/ops/cutile/my_op.py

La structure du fichier suit ce template :

import torch
import cuda.tile as ct

from tilegym.backend import register_impl


@ct.kernel
def my_op_kernel_ct(x, output, n_elements: ct.Constant[int], BLOCK_SIZE: ct.Constant[int]):
    bid = ct.bid(0)
    indices = bid * BLOCK_SIZE + ct.arange(0, BLOCK_SIZE)
    x_val = ct.gather(x, indices)
    # ... compute ...
    ct.scatter(output, indices, result)


@register_impl("my_op", backend="cutile")
def my_op(input: torch.Tensor, out: torch.Tensor = None, **kwargs) -> torch.Tensor:
    n = input.numel()
    if out is None:
        out = torch.empty_like(input)
    grid = ((n + 1023) // 1024,)
    ct.launch(stream, grid, kernel, (some args, ...))
    return out

Référence : src/tilegym/ops/cutile/silu_and_mul.py

Étape 3 : Enregistrer dans __init__.py (CRITIQUE)

Oublier cette étape signifie que l'implémentation du backend cuTile ne sera jamais chargée.

Fichier : src/tilegym/ops/cutile/__init__.py

Ajouter à l'intérieur du bloc if is_backend_available("cutile"): (par ordre alphabétique) :

from . import my_op

Et dans la section d'importation de fonction :

from .my_op import my_op

Et ajouter "my_op" à __all__.

Étape 4 : Ajouter des tests

Fichier : tests/ops/test_my_op.py

CRITIQUE : Toujours importer depuis tilegym.ops, JAMAIS depuis tilegym.ops.cutile.my_op.

import pytest
import torch

from tilegym.backend import is_backend_available, set_backend
from .. import common

_backends = ["cutile"]


class Test_MY_OP(common.PyTestCase):
    @staticmethod
    def reference(input):
        """Reference implementation using PyTorch."""
        return torch.some_reference(input)

    @pytest.mark.parametrize("shape, dtype", [
        ((1024,), torch.float16),
        ((1024, 512), torch.float32),
        ((64, 64, 64), torch.bfloat16),
    ])
    @pytest.mark.parametrize("backend", _backends)
    def test_op(self, shape, dtype, backend, arch):
        if backend == "cutile" and not is_backend_available("cutile"):
            pytest.skip("Cutile backend not available")
        try:
            set_backend(backend)
        except Exception as e:
            pytest.skip(f"Backend is not supported: {e}")

        self.setUp()

        from tilegym.ops import my_op

        A = torch.randn(*shape, dtype=dtype, device="cuda")
        self.assertCorrectness(
            my_op, self.reference, {"input": A},
            atol=1e-3, rtol=1e-3,
        )

Motifs clés :

  • _backends = ["cutile"]
  • test_op : utiliser set_backend(backend) avec try-except, appeler self.setUp()

Référence : tests/ops/test_silu_and_mul.py

Voici les erreurs courantes.

1. Missing _backends list (inside class)
2. test_op / test_op_xxx — missing @pytest.mark.parametrize("backend", _backends), backend parameter, and tilegym.is_backend_available / tilegym.set_backend pattern

Étape 5 : Ajouter un benchmark à tests/benchmark

Fichier : tests/benchmark/bench_my_op.py

Règles clés de benchmark_rules.md :

  • Appeler l'op via tilegym.ops.my_op(a, b, ..., backend=backend) — ne PAS utiliser set_backend.
  • Définir ALL_BACKENDS (inclure au minimum cutile et torch), filtrer avec get_supported_backends().
  • Implémenter reference_my_op(...) et l'enregistrer : register_impl("my_op", "torch")(reference_my_op).
  • Utiliser create_benchmark_config() pour construire des configs triton.testing.Benchmark (ex. par shape/dtype).
  • Utiliser @triton.testing.perf_report([...]) sur bench_my_op(...); à l'intérieur de la fonction bench : vérification de correction avec torch.testing.assert_close(fn(), ref(), ...), puis ms = triton.testing.do_bench(fn) (ou do_bench_cudagraph), calculer GB/s ou TFLOPS, et retourner la métrique.
  • Point d'entrée : if __name__ == "__main__": bench_my_op.run(print_data=True).

Structure du template :

import torch
import triton
import triton.testing

import tilegym
from tilegym.backend import is_backend_available, register_impl

ALL_BACKENDS = [
    ("cutile", "cuTile", ("orange", "-")) if is_backend_available("cutile") else None,
    ("torch", "PyTorch", ("green", "-")),
]

def get_supported_backends():
    return [p for p in ALL_BACKENDS if p is not None]

def reference_my_op(input: torch.Tensor, out: torch.Tensor = None, **kwargs):
    """Reference implementation using PyTorch."""
    ...

register_impl("my_op", "torch")(reference_my_op)

def create_benchmark_config(datatype, ...):
    available_backends = get_supported_backends()
    if not available_backends:
        return None
    backends, names, styles = zip(*available_backends)
    return triton.testing.Benchmark(
        x_names=["M"],  # or other dimension names
        x_vals=[...],
        line_arg="backend",
        line_vals=list(backends),
        line_names=list(names),
        styles=list(styles),
        ylabel="GB/s",  # or TFLOPS
        plot_name="my-op-...",
        args={"datatype": datatype, ...},
    )

@triton.testing.perf_report([
    create_benchmark_config(datatype, ...)
    for datatype in [torch.float16, torch.float32]
    for ... in [...]
])
def bench_my_op(M, backend, datatype, ..., device="cuda"):
    x = torch.randn(..., dtype=datatype, device=device)

    fn = lambda: tilegym.ops.my_op(x, backend=backend)
    ref = lambda: reference_my_op(x)
    torch.testing.assert_close(fn(), ref(), rtol=1e-2, atol=1e-2)

    ms = triton.testing.do_bench(fn)  # or do_bench_cudagraph(fn)
    # Compute metric (e.g. GB/s or TFLOPS) from ms and problem size
    return metric

if __name__ == "__main__":
    bench_my_op.run(print_data=True)

Noms de graphiques Benchmark : Doivent inclure le suffixe -TFLOPS ou -GBps

  • Exemple : plot_name=f"persistent-layer-norm-M{num_rows}-{dtype_name}-GBps"

Étape 6 : Vérifier

# Run tests
pytest tests/ops/test_my_op.py -v

# Run benchmark (optional)
python tests/benchmark/bench_my_op.py

# Lint
pre-commit run -a

Skills similaires