Skill Megatron FSDP
Pour le contexte stable et le niveau de recommandation, voir :
- @docs/training/megatron-fsdp.md
- @skills/perf-megatron-fsdp/card.yaml
Activation
Surcharge minimale Megatron FSDP dans Bridge :
cfg.dist.use_megatron_fsdp = True
cfg.ddp.use_megatron_fsdp = True
cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
cfg.ddp.average_in_collective = False
cfg.checkpoint.ckpt_format = "fsdp_dtensor"
Exemple de correction de recette :
cfg = llama3_8b_pretrain_config()
cfg.dist.use_megatron_fsdp = True
cfg.ddp.use_megatron_fsdp = True
cfg.ddp.data_parallel_sharding_strategy = "optim_grads_params"
cfg.ddp.average_in_collective = False
cfg.checkpoint.ckpt_format = "fsdp_dtensor"
cfg.checkpoint.save = "/tmp/fsdp_ckpts"
cfg.checkpoint.load = None
Note sur le banc de performance :
python scripts/performance/launch.py --use_megatron_fsdp true
Ancres de code
Définition de la configuration Bridge :
use_megatron_fsdp: bool = False
"""Use Megatron's Fully Sharded Data Parallel. Cannot be used together with use_torch_fsdp2."""
use_torch_fsdp2: bool = False
"""Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel.
It is still not in a stable release stage, and may therefore contain bugs or other
potential issues."""
Validation de Bridge :
if self.dist.use_megatron_fsdp and self.dist.use_torch_fsdp2:
raise ValueError(...)
...
assert not self.dist.use_tp_pp_dp_mapping, "use_tp_pp_dp_mapping is not supported with Megatron FSDP"
...
assert self.checkpoint.ckpt_format == "fsdp_dtensor", (
"Megatron FSDP only supports fsdp_dtensor checkpoint format"
)
Sélection du wrapper runtime :
if use_megatron_fsdp:
DP = FullyShardedDataParallel
elif use_torch_fsdp2:
DP = TorchFullyShardedDataParallel
else:
DP = DistributedDataParallel
...
DP(
config=get_model_config(model_chunk),
ddp_config=ddp_config,
module=model_chunk,
...
pg_collection=pg_collection,
)
Surcharges du banc de performance :
recipe.ddp.use_megatron_fsdp = True
recipe.ddp.data_parallel_sharding_strategy = "optim_grads_params"
recipe.ddp.keep_fp8_transpose_cache = False
recipe.ddp.average_in_collective = False
...
recipe.checkpoint.load = None
Pièges
- Les recettes publiques exposent souvent
use_megatron_fsdpmais conservent par défautckpt_format="torch_dist". Si la sauvegarde/chargement est activé, basculez versfsdp_dtensor. use_torch_fsdp2existe, mais sur la branche validée Bridge échoue avant l'entraînement car_ddp_wrappassepg_collection.- Le déchargement CPU n'est valide que quand
pipeline_model_parallel_size == 1et la récompilation d'activations est désactivée. - L'upstream avertit que FSDP et TP/CP peuvent vouloir des paramètres
CUDA_DEVICE_MAX_CONNECTIONSdifférents sur Hopper et antérieurs. - Megatron FSDP et FSDP2 s'excluent mutuellement.
Vérification
Utilisez le test de fumée fonctionnel existant sur 2 GPU :
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.run --nproc_per_node=2 \
-m pytest tests/functional_tests/training/test_megatron_fsdp.py::TestMegatronFSDP::test_fsdp_pretrain_basic -v -s
Critères de succès :
- Pytest signale
1 passed - Le journal affiche une perte finie à la dernière itération
- L'exécution se termine sans erreur d'assertion de format de checkpoint