Tests de parité pour Megatron Bridge
Cette skill fournit le cadre décisionnel pour choisir le bon outil de vérification et interpréter les résultats. Pour le workflow complet d'intégration de modèle (qui inclut les tests de parité aux jalons 1 et 2), voir la skill add-model-support.
Décision rapide : quel outil exécuter
| Ce que vous voulez vérifier | Outil | GPU ? | Quand l'utiliser |
|---|---|---|---|
| Tous les poids font un aller-retour exact (GPU unique) | hf_megatron_roundtrip.py |
Non | Premier test après écriture d'un bridge |
| Les poids font un aller-retour avec TP/PP/EP | hf_megatron_roundtrip_multi_gpu.py |
Oui | Après réussite sur GPU unique |
| Équivalence des logits en forward-pass | compare_hf_and_megatron/compare.py |
Oui | Après réussite de l'aller-retour |
| Santé du texte généré | hf_to_megatron_generate_text.py |
Oui | Grands modèles qui OOM compare.py |
| Vérification programmatique des poids | weights_verification_table() |
Oui | À l'intérieur de scripts Python |
| Santé de génération VLM | hf_to_megatron_generate_vlm.py |
Oui | Modèles VLM |
Tous les outils se trouvent dans examples/conversion/.
Stratégie de test en 3 niveaux
Niveau 1 : Aller-retour du dictionnaire d'état (correspondance exacte)
La vérification la plus rapide et la plus fondamentale. Si les mappages ne peuvent pas faire un aller-retour parfait des poids, rien d'autre ne fonctionnera.
# Aller-retour sur GPU unique
uv run python examples/conversion/hf_megatron_roundtrip.py \
--hf-model-id <org>/<model>
# Multi-GPU avec TP=2
uv run python -m torch.distributed.run --nproc_per_node=2 \
examples/conversion/hf_megatron_roundtrip_multi_gpu.py \
--hf-model-id <org>/<model> --tp 2
# Multi-GPU avec PP=2
uv run python -m torch.distributed.run --nproc_per_node=2 \
examples/conversion/hf_megatron_roundtrip_multi_gpu.py \
--hf-model-id <org>/<model> --pp 2
Attendu : Chaque poids affiche « Matches Original: ✓ ». Tout « ✗ » indique une erreur dans le mappage des paramètres.
Tolérance : Correspondance exacte (max_diff == 0.0). Les conversions aller-retour sont des simple restructurations de tenseurs — aucune arithmétique en virgule flottante n'est impliquée.
Pour une vérification programmatique à l'intérieur des scripts, utilisez le vérificateur intégré :
from megatron.bridge.models.conversion.utils import weights_verification_table
weights_verification_table(bridge, hf_pretrained, megatron_model)
Niveau 2 : Parité en forward-pass (GPU / bfloat16)
Après réussite de l'aller-retour, vérifiez que les poids convertis produisent une sortie en forward-pass identique.
# Comparer les logits (charge les deux modèles HF et Megatron)
uv run python -m torch.distributed.run --nproc_per_node=2 \
examples/conversion/compare_hf_and_megatron/compare.py \
--hf_model_path <org>/<model> --tp 2 \
--prompt "The capital of France is"
Attendu : Similarité cosinus > 99,99 %, correspondance des prédictions du token suivant.
Pour les grands modèles qui OOM compare.py (qui charge les deux modèles), utilisez plutôt la génération de texte :
uv run python -m torch.distributed.run --nproc_per_node=2 \
examples/conversion/hf_to_megatron_generate_text.py \
--hf_model_path <org>/<model> --tp 2 \
--prompt "The capital of France is" --max_new_tokens 50
Niveau 3 : Parité à l'entraînement (optionnel)
Vérifiez que quelques étapes d'entraînement produisent une perte décroissante. Cela détecte les problèmes de calcul de gradient que les tests en forward-pass manquent. Utilisez un modèle jouet avec 2 couches et petites dimensions. Voir le pattern de test fonctionnel dans la skill add-model-support (Jalon 3, Phase 6).
Tableau des tolérances
| Niveau de test | Dtype | Device | Max Diff | Similarité cosinus |
|---|---|---|---|---|
| Aller-retour | float32 | CPU | 0.0 (exact) | 1.0 (exact) |
| Forward pass | bfloat16 | GPU | < 1e-2 | > 0,9999 |
| Forward pass | float16 | GPU | < 1e-3 | > 0,99999 |
Utilitaires de comparaison
Ces fonctions sont utiles lors de l'écriture de scripts de vérification personnalisés ou du débogage de défaillances. Elles ne font pas partie de la bibliothèque Bridge — copiez-les dans votre script selon vos besoins.
import torch
def compare_tensors(a, b, name=""):
"""Compare two tensors and report similarity metrics."""
max_diff = (a - b).abs().max().item()
mean_diff = (a - b).abs().mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
a.flatten().float(), b.flatten().float(), dim=0,
).item()
print(f"{name}: max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}, cosine_sim={cos_sim:.8f}")
return max_diff, mean_diff, cos_sim
def compare_state_dicts(sd_a, sd_b, prefix=""):
"""Compare two state dicts key-by-key, reporting per-parameter differences."""
keys_a, keys_b = set(sd_a.keys()), set(sd_b.keys())
missing, extra = keys_a - keys_b, keys_b - keys_a
if missing:
print(f"{prefix}Missing keys: {sorted(missing)}")
if extra:
print(f"{prefix}Extra keys: {sorted(extra)}")
max_diffs = {}
for key in sorted(keys_a & keys_b):
diff = (sd_a[key].float() - sd_b[key].float()).abs().max().item()
if diff > 0:
max_diffs[key] = diff
print(f"{prefix}{key}: max_diff={diff:.6e}")
if not max_diffs and not missing and not extra:
print(f"{prefix}All {len(keys_a & keys_b)} parameters match exactly.")
return missing, extra, max_diffs
Workflow de débogage
Quand un test de parité échoue, suivez cette séquence :
-
Exécutez l'aller-retour sur GPU unique — s'il échoue, le mappage lui-même est erroné. Vérifiez
mapping_registry()dans le fichier bridge. -
Si GPU unique réussit mais multi-GPU échoue — la scatter/gather TP/PP est erronée. Comparez le résultat TP=1 par rapport à chaque shard TP. Voir la skill
nccl-contiguous-tensorspour les problèmes spécifiques à NCCL. -
Si l'aller-retour réussit mais le forward-pass échoue — les poids ont été chargés correctement mais l'architecture du modèle diffère. Vérifiez le mappage de config
provider_bridge()(normalisation, activation, RoPE, etc.). -
Utilisez le template de script de débogage de la skill
add-model-supportpour inspecter le nommage des clés runtime vs safetensors et le mappage de config du bridge.
Pour le catalogue complet des pièges (interleaving QKV, exports MoE fusionnés, embeddings liés, dequantization FP8, alias TE LayerNorm, etc.), voir la section Pièges de la skill add-model-support.
Ancres de code
| Composant | Chemin |
|---|---|
| Aller-retour sur GPU unique | examples/conversion/hf_megatron_roundtrip.py |
| Aller-retour multi-GPU | examples/conversion/hf_megatron_roundtrip_multi_gpu.py |
| Comparaison forward-pass | examples/conversion/compare_hf_and_megatron/compare.py |
| Génération de texte | examples/conversion/hf_to_megatron_generate_text.py |
| Génération VLM | examples/conversion/hf_to_megatron_generate_vlm.py |
| CLI checkpoint | examples/conversion/convert_checkpoints.py |
| Créateur de modèle jouet | examples/conversion/create_hf_toy_model.py |
| Utilitaire de vérification | src/megatron/bridge/models/conversion/utils.py |
| Vérification d'adaptateur | examples/conversion/adapter/verify_adapter.py |