Guide d'exploration de la base de code TensorRT-LLM
Pourquoi c'est important
TRT-LLM est une vaste base de code (~500K lignes) avec de nombreuses abstractions réutilisables. La source la plus courante de travail gaspillé est de réimplémenter quelque chose qui existe déjà. Sur la branche short-seq MHA, ~250 lignes ont été écrites en 4 itérations avant de découvrir qu'une dispatch de 10 lignes vers une méthode existante (forward_context_default) était la bonne solution.
Règle empirique : Passez 30 minutes à lire le code existant avant d'écrire 1 ligne de nouveau code.
OBLIGATOIRE : Ignorez le backend TensorRT, concentrez-vous sur le backend PyTorch
Flux de travail d'exploration étape par étape
Étape 1 : Mapper la classe que vous modifiez
Avant d'ajouter du code à une classe, comprenez sa structure complète :
# Listez toutes les méthodes (pas seulement forward*)
grep -n "def " tensorrt_llm/_torch/modules/attention.py | head -50
# Listez tous les attributs définis dans __init__
grep -n "self\." tensorrt_llm/_torch/modules/attention.py | grep "__init__" -A 200 | head -80
# Trouvez la hiérarchie de classe
grep -n "class MLA\|class Attention\|class TrtllmAttention" tensorrt_llm/_torch/modules/attention.py
Étape 2 : Tracer les méthodes forward existantes
Lisez TOUTES les méthodes forward de la classe. Comprenez ce que chacune fait, quelles entrées elle attend et quels backends elle utilise.
# Trouvez toutes les méthodes forward
grep -n "def forward" tensorrt_llm/_torch/modules/attention.py
# Pour chacune, lisez l'implémentation complète (pas seulement la signature)
Posez-vous ces questions :
- Y a-t-il déjà une méthode forward existante qui calcule ce dont j'ai besoin ?
- Puis-je dispatcher vers une méthode existante en configurant le bon état ?
- Qu'aurais-je besoin de changer (attributs, gardes, assertions) pour la réutiliser ?
Étape 3 : Rechercher les backends et utilitaires existants
| Ce dont vous avez besoin | Recherchez | Résultats courants |
|---|---|---|
| Calcul d'attention | TrtllmAttention, create_attention, FlashInferAttention |
Gère les séquences compactées, longueurs variables, cache KV nativement |
| Fusion compilée | maybe_compile, maybe_compiled_cat, maybe_compiled_copy_ |
Déjà dans tensorrt_llm/_torch/utils.py |
| Application RoPE | RotaryEmbedding, apply_rotary_pos_emb, rope_fusion |
Plusieurs implémentations existent ; vérifiez celle utilisée par le chemin de code actuel |
| Gestion du cache KV | mla_rope_append_paged_kv, append_paged_kv, latent_cache |
Opérations RoPE + cache fusionnées dans les kernels C++ |
| Attention sparse | DSATrtllmAttention, indexer, topk_indices |
Backend spécifique à DSA avec routage sparse |
# Motif de recherche générique
grep -rn "KEYWORD" tensorrt_llm/_torch/ --include="*.py" | head -20
Étape 4 : Vérifier ce que gèrent les kernels fusionnés
De nombreuses opérations que vous pourriez implémenter manuellement sont déjà gérées par des kernels C++ fusionnés :
# Trouvez ce que le kernel d'attention gère en interne
grep -rn "latent_cache\|rope.*fuse\|rope_fusion" tensorrt_llm/_torch/attention_backend/
Surprise courante : Quand rope_fusion=True (apply_rotary_emb=False), le kernel d'attention fusionné gère RoPE en interne via latent_cache. Écrire du code RoPE personnalisé en Python est inutile et appliquera RoPE deux fois.
Étape 5 : Vérifier les assertions et invariants
Les assertions existantes peuvent avoir besoin d'être mises à jour quand vous ajoutez un nouveau chemin de code. Ne les contournez pas — changez-les si votre nouveau chemin les rend invalides :
# Trouvez les assertions dans la classe
grep -n "assert " tensorrt_llm/_torch/modules/attention.py
Exemple : Les modèles DSA avaient assert self.mha is None. En ajoutant short-seq MHA (qui crée self.mha pour les modèles DSA), l'assertion a été changée en assert self.mqa is not None — l'invariant réel testé.
Étape 6 : Comprendre les layouts de poids
Les layouts de poids diffèrent souvent entre les checkpoints HuggingFace et le format chargé de TRT-LLM :
# Trouvez le code de chargement/transformation de poids
grep -rn "load_.*weight\|weight.*transform\|load_kv_b_proj" tensorrt_llm/_torch/models/
# Vérifiez comment les poids sont organisés après chargement
grep -n "def load_" tensorrt_llm/_torch/models/modeling_deepseekv3.py
Critique pour les tests : Initialisez toujours les poids de test dans le layout chargé, pas le layout du checkpoint HF.
Étape 7 : Tracer les limitations des méthodes
Après avoir identifié une méthode à réutiliser, comprenez ce qu'elle ne fait PAS :
# Trouvez tous les appelants de la méthode pour voir son contexte de dispatch
grep -rn "forward_context_default\|forward_context(" tensorrt_llm/_torch/modules/attention.py
# Cherchez le dispatcher qui route vers cette méthode
# Souvent nommé similairement mais sans suffixe (p. ex., forward_context dispatch vers forward_context_default)
Posez-vous ces questions :
- Quels scénarios cette méthode gère-t-elle ? (prefill frais ? KV en cache ? contexte partitionné ?)
- Quels scénarios ne gère-t-elle PAS ?
- Y a-t-il un dispatcher de haut niveau qui route vers cette méthode pour le bon sous-ensemble de cas ?
- Si j'appelle cette méthode directement, quels scénarios vais-je traiter silencieusement mal ?
Exemple : forward_context_default() gère le prefill frais mais n'assiste PAS sur les tokens KV en cache. forward_context() est le dispatcher qui route vers forward_context_default, forward_context_with_cached_kv ou forward_context_with_chunked_prefill selon l'état du contexte et la version SM. Appeler forward_context_default directement pendant un contexte partitionné omet silencieusement les tokens en cache.
Motifs de découverte clés
Motif : « Puis-je réutiliser une méthode forward existante ? »
- Lisez la méthode forward cible (p. ex.,
forward_context_default) - Comparez-la à ce que votre nouveau chemin de code doit faire
- Si >70% de chevauchement, dispatchez vers la méthode existante au lieu d'en écrire une nouvelle
- Ajustez les attributs/état dans
__init__pour que la dispatch fonctionne
Motif : « Est-ce déjà géré par un kernel fusionné ? »
- Vérifiez si l'opération relève du scope du backend d'attention
- Vérifiez le flag
apply_rotary_emb/rope_fusion - Vérifiez la gestion de
latent_cache - Si le kernel fusionné le gère, ne le réimplémentez PAS en Python
Motif : « Suis-je en train d'appeler le bon niveau d'abstraction ? »
- Identifiez la méthode que vous prévoyez d'appeler
- Recherchez les méthodes qui APPELLENT cette méthode — il peut y avoir un dispatcher au-dessus
- Vérifiez si le dispatcher gère les cas limites que votre appel direct manquerait
- Préférez appeler le dispatcher plutôt que le gestionnaire spécifique
# Trouvez ce qui appelle forward_context_default pour découvrir la chaîne de dispatch
grep -n "forward_context_default" tensorrt_llm/_torch/modules/attention.py
Motif : « Un utilitaire existe-t-il déjà ? »
- Recherchez dans
tensorrt_llm/_torch/utils.pyles aides compilées - Recherchez dans
tensorrt_llm/_torch/modules/les utilitaires au niveau du module - Recherchez dans les fixtures de test dans
tests/unittest/_torch/les motifs de configuration de test
Erreurs courantes d'exploration
| Erreur | Conséquence | Prévention |
|---|---|---|
| Lire seulement la méthode que vous modifiez | Manquez qu'une autre méthode fait ce dont vous avez besoin | Lisez TOUTES les méthodes de la classe |
| Rechercher seulement le nom de fonction exact | Manquez les implémentations équivalentes | Recherchez le concept (p. ex., « attention », « rope », « expand kv ») |
| Supposer que les assertions sont immuables | Contourner avec des hacks (attributs séparés) | Questionner si l'intention de l'assertion s'applique toujours |
| Ne pas lire les capacités du kernel fusionné | Réimplémenter ce qu'il fait déjà | Vérifiez ce que latent_cache, rope_fusion etc. contrôlent |
| Lire seulement le code Python | Manquez les implémentations C++ appelées via bindings | Vérifiez tensorrt_llm/_torch/attention_backend/ pour les kernels natifs |
| Appeler une méthode directement au lieu de via son dispatcher | Manquez les cas limites (KV en cache, contexte partitionné, gating de version SM) | Recherchez les appelants de la méthode pour trouver la chaîne de dispatch |
| Supposer un comportement numérique uniforme sur le matériel | Dégradation silencieuse de la précision sur certaines versions SM | Vérifiez les gardes get_sm_version() près du site d'appel ; testez sur plusieurs matériels |
Référence de fichiers pour l'exploration
| Domaine | Fichiers clés à lire |
|---|---|
| Modules d'attention | tensorrt_llm/_torch/modules/attention.py |
| Backends d'attention | tensorrt_llm/_torch/attention_backend/ (trtllm_attention.py, sparse/) |
| Définitions de modèles | tensorrt_llm/_torch/models/modeling_*.py |
| Utilitaires | tensorrt_llm/_torch/utils.py |
| RoPE | tensorrt_llm/_torch/modules/rotary_embedding.py |
| Fixtures de test | tests/unittest/_torch/attention/ |
| Chargement de poids | tensorrt_llm/_torch/models/modeling_deepseekv3.py (recherchez load_) |