nemo-mbridge-perf-megatron-fsdp

Par nvidia · skills

Guide opérationnel pour activer Megatron FSDP dans Megatron-Bridge, incluant les paramètres de configuration, les points d'ancrage dans le code, les pièges courants et la procédure de vérification.

npx skills add https://github.com/nvidia/skills --skill nemo-mbridge-perf-megatron-fsdp

Skill Megatron FSDP

Pour le contexte stable et le niveau de recommandation, voir :

  • @docs/training/megatron-fsdp.md
  • @skills/nemo-mbridge-perf-megatron-fsdp/card.yaml

Activation

Override minimal Megatron FSDP dans Bridge :

cfg.dist.use_megatron_fsdp = True
cfg.ddp.use_megatron_fsdp = True
cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
cfg.ddp.average_in_collective = False
cfg.checkpoint.ckpt_format = "fsdp_dtensor"

Exemple de correction de recette :

cfg = llama3_8b_pretrain_config()
cfg.dist.use_megatron_fsdp = True
cfg.ddp.use_megatron_fsdp = True
cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
cfg.ddp.average_in_collective = False
cfg.checkpoint.ckpt_format = "fsdp_dtensor"
cfg.checkpoint.save = "/tmp/fsdp_ckpts"
cfg.checkpoint.load = None

Note sur le harnais de performance :

python scripts/performance/launch.py --use_megatron_fsdp true

Ancres de code

Définition de la config Bridge :

use_megatron_fsdp: bool = False
"""Use Megatron's Fully Sharded Data Parallel. Cannot be used together with use_torch_fsdp2."""

use_torch_fsdp2: bool = False
"""Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel.
It is still not in a stable release stage, and may therefore contain bugs or other
potential issues."""

Validation Bridge :

if self.dist.use_megatron_fsdp and self.dist.use_torch_fsdp2:
    raise ValueError(...)
...
assert not self.dist.use_tp_pp_dp_mapping, "use_tp_pp_dp_mapping is not supported with Megatron FSDP"
...
assert self.checkpoint.ckpt_format == "fsdp_dtensor", (
    "Megatron FSDP only supports fsdp_dtensor checkpoint format"
)

Sélection du wrapper runtime :

if use_megatron_fsdp:
    DP = FullyShardedDataParallel
elif use_torch_fsdp2:
    DP = TorchFullyShardedDataParallel
else:
    DP = DistributedDataParallel
...
DP(
    config=get_model_config(model_chunk),
    ddp_config=ddp_config,
    module=model_chunk,
    ...
    pg_collection=pg_collection,
)

Overrides du harnais de performance :

recipe.ddp.use_megatron_fsdp = True
recipe.ddp.data_parallel_sharding_strategy = "optim_grads_params"
recipe.ddp.keep_fp8_transpose_cache = False
recipe.ddp.average_in_collective = False
...
recipe.checkpoint.load = None

Pièges

  1. Les recettes publiques exposent souvent use_megatron_fsdp mais conservent par défaut ckpt_format="torch_dist". Si la sauvegarde/restauration est activée, basculez vers fsdp_dtensor.
  2. use_torch_fsdp2 existe, mais sur la branche validée Bridge échoue avant l'entraînement car _ddp_wrap passe pg_collection.
  3. Le CPU offloading n'est valide que lorsque pipeline_model_parallel_size == 1 et la réaccumulation d'activation est désactivée.
  4. Upstream avertit que FSDP et TP/CP peuvent demander des réglages CUDA_DEVICE_MAX_CONNECTIONS différents sur Hopper et antérieurs.
  5. Megatron FSDP et FSDP2 s'excluent mutuellement.

Vérification

Utilisez le test fonctionnel de smoke existant sur 2 GPU :

CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.run --nproc_per_node=2 \
  -m pytest tests/functional_tests/training/test_megatron_fsdp.py::TestMegatronFSDP::test_fsdp_pretrain_basic -v -s

Critères de succès :

  • Pytest rapporte 1 passed
  • Le log montre une perte finie à la dernière itération
  • L'exécution se termine sans assertion de format de checkpoint

Skills similaires