Ajouter un noyau cuTile à TileGym
Workflow end-to-end pour ajouter un nouvel opérateur (ex. my_op) avec backend cuTile.
Règles d'exécution
DOIT suivre ces règles strictement :
- Utiliser TodoWrite pour créer la checklist ci-dessous AVANT d'écrire du code
- Exécuter les étapes dans l'ordre — NE PAS sauter ou combiner les étapes
- Marquer chaque todo comme
completedaprès avoir fini,in_progressau démarrage - Si une étape ne s'applique pas (ex. pas d'impl cuTile), la marquer
completedavec une note, NE PAS sauter silencieusement - Chaque étape DOIT résulter en une écriture de fichier ou une décision de saut explicite — pas d'omissions silencieuses
Instructions
DOIT copier cette checklist dans TodoWrite au démarrage :
- [ ] Step 1: Register dispatch interface in ops.py
- [ ] Step 2: Implement cuTile backend
- [ ] Step 3: Register in __init__.py (cutile)
- [ ] Step 4: Add tests
- [ ] Step 5: Add benchmark to tests/benchmark
- [ ] Step 6: Verify (run pytest + lint)
Étape 1 : Enregistrer l'interface de dispatch
Fichier : src/tilegym/ops/ops.py
Ajouter une fonction @dispatch — c'est le seul point d'entrée pour tous les backends.
@dispatch(
"my_op",
)
def my_op(
input: torch.Tensor,
out: Optional[torch.Tensor] = None,
**kwargs: Any,
):
"""
Description of my_op.
Args:
input: Input tensor
out: Optional preallocated output tensor
**kwargs: Additional arguments for backend-specific configurations
Returns:
torch.Tensor
"""
raise NotImplementedError(f"my_op is not implemented for {get_current_backend()}")
Règles clés :
- Le corps de la fonction lève uniquement
NotImplementedError - Inclure
**kwargspour les paramètres spécifiques au backend
Référence : Voir les ops existantes dans src/tilegym/ops/ops.py (ex. silu_and_mul, softmax)
Étape 2 : Implémenter le backend cuTile
Fichier : src/tilegym/ops/cutile/my_op.py
La structure du fichier suit ce template :
import torch
import cuda.tile as ct
from tilegym.backend import register_impl
@ct.kernel
def my_op_kernel_ct(x, output, n_elements: ct.Constant[int], BLOCK_SIZE: ct.Constant[int]):
bid = ct.bid(0)
indices = bid * BLOCK_SIZE + ct.arange(0, BLOCK_SIZE)
x_val = ct.gather(x, indices)
# ... compute ...
ct.scatter(output, indices, result)
@register_impl("my_op", backend="cutile")
def my_op(input: torch.Tensor, out: torch.Tensor = None, **kwargs) -> torch.Tensor:
n = input.numel()
if out is None:
out = torch.empty_like(input)
grid = ((n + 1023) // 1024,)
ct.launch(stream, grid, kernel, (some args, ...))
return out
Référence : src/tilegym/ops/cutile/silu_and_mul.py
Étape 3 : Enregistrer dans __init__.py (CRITIQUE)
Oublier cette étape signifie que l'implémentation du backend cuTile ne sera jamais chargée.
Fichier : src/tilegym/ops/cutile/__init__.py
Ajouter à l'intérieur du bloc if is_backend_available("cutile"): (par ordre alphabétique) :
from . import my_op
Et dans la section d'importation de fonction :
from .my_op import my_op
Et ajouter "my_op" à __all__.
Étape 4 : Ajouter des tests
Fichier : tests/ops/test_my_op.py
CRITIQUE : Toujours importer depuis tilegym.ops, JAMAIS depuis tilegym.ops.cutile.my_op.
import pytest
import torch
from tilegym.backend import is_backend_available, set_backend
from .. import common
_backends = ["cutile"]
class Test_MY_OP(common.PyTestCase):
@staticmethod
def reference(input):
"""Reference implementation using PyTorch."""
return torch.some_reference(input)
@pytest.mark.parametrize("shape, dtype", [
((1024,), torch.float16),
((1024, 512), torch.float32),
((64, 64, 64), torch.bfloat16),
])
@pytest.mark.parametrize("backend", _backends)
def test_op(self, shape, dtype, backend, arch):
if backend == "cutile" and not is_backend_available("cutile"):
pytest.skip("Cutile backend not available")
try:
set_backend(backend)
except Exception as e:
pytest.skip(f"Backend is not supported: {e}")
self.setUp()
from tilegym.ops import my_op
A = torch.randn(*shape, dtype=dtype, device="cuda")
self.assertCorrectness(
my_op, self.reference, {"input": A},
atol=1e-3, rtol=1e-3,
)
Motifs clés :
_backends = ["cutile"]test_op: utiliserset_backend(backend)avec try-except, appelerself.setUp()
Référence : tests/ops/test_silu_and_mul.py
Voici les erreurs courantes.
1. Missing _backends list (inside class)
2. test_op / test_op_xxx — missing @pytest.mark.parametrize("backend", _backends), backend parameter, and tilegym.is_backend_available / tilegym.set_backend pattern
Étape 5 : Ajouter un benchmark à tests/benchmark
Fichier : tests/benchmark/bench_my_op.py
Règles clés de benchmark_rules.md :
- Appeler l'op via
tilegym.ops.my_op(a, b, ..., backend=backend)— ne PAS utiliserset_backend. - Définir
ALL_BACKENDS(inclure au minimumcutileettorch), filtrer avecget_supported_backends(). - Implémenter
reference_my_op(...)et l'enregistrer :register_impl("my_op", "torch")(reference_my_op). - Utiliser
create_benchmark_config()pour construire des configstriton.testing.Benchmark(ex. par shape/dtype). - Utiliser
@triton.testing.perf_report([...])surbench_my_op(...); à l'intérieur de la fonction bench : vérification de correction avectorch.testing.assert_close(fn(), ref(), ...), puisms = triton.testing.do_bench(fn)(oudo_bench_cudagraph), calculer GB/s ou TFLOPS, et retourner la métrique. - Point d'entrée :
if __name__ == "__main__": bench_my_op.run(print_data=True).
Structure du template :
import torch
import triton
import triton.testing
import tilegym
from tilegym.backend import is_backend_available, register_impl
ALL_BACKENDS = [
("cutile", "cuTile", ("orange", "-")) if is_backend_available("cutile") else None,
("torch", "PyTorch", ("green", "-")),
]
def get_supported_backends():
return [p for p in ALL_BACKENDS if p is not None]
def reference_my_op(input: torch.Tensor, out: torch.Tensor = None, **kwargs):
"""Reference implementation using PyTorch."""
...
register_impl("my_op", "torch")(reference_my_op)
def create_benchmark_config(datatype, ...):
available_backends = get_supported_backends()
if not available_backends:
return None
backends, names, styles = zip(*available_backends)
return triton.testing.Benchmark(
x_names=["M"], # or other dimension names
x_vals=[...],
line_arg="backend",
line_vals=list(backends),
line_names=list(names),
styles=list(styles),
ylabel="GB/s", # or TFLOPS
plot_name="my-op-...",
args={"datatype": datatype, ...},
)
@triton.testing.perf_report([
create_benchmark_config(datatype, ...)
for datatype in [torch.float16, torch.float32]
for ... in [...]
])
def bench_my_op(M, backend, datatype, ..., device="cuda"):
x = torch.randn(..., dtype=datatype, device=device)
fn = lambda: tilegym.ops.my_op(x, backend=backend)
ref = lambda: reference_my_op(x)
torch.testing.assert_close(fn(), ref(), rtol=1e-2, atol=1e-2)
ms = triton.testing.do_bench(fn) # or do_bench_cudagraph(fn)
# Compute metric (e.g. GB/s or TFLOPS) from ms and problem size
return metric
if __name__ == "__main__":
bench_my_op.run(print_data=True)
Noms de graphiques Benchmark : Doivent inclure le suffixe -TFLOPS ou -GBps
- Exemple :
plot_name=f"persistent-layer-norm-M{num_rows}-{dtype_name}-GBps"
Étape 6 : Vérifier
# Run tests
pytest tests/ops/test_my_op.py -v
# Run benchmark (optional)
python tests/benchmark/bench_my_op.py
# Lint
pre-commit run -a