nemo-mbridge-resiliency

Par nvidia · skills

Fonctionnalités de résilience dans Megatron Bridge, notamment la tolérance aux pannes, la détection des stragglers, le redémarrage en cours de processus, la préemption et la machine d'états de relance.

npx skills add https://github.com/nvidia/skills --skill nemo-mbridge-resiliency

Résilience

Docs stables : @docs/training/resiliency.md, @docs/training/checkpointing.md Card : @skills/nemo-mbridge-resiliency/card.yaml

Activation

Tolérance aux pannes (Slurm uniquement)

Option 1 : plugin NeMo Run (recommandé)

from megatron.bridge.recipes.run_plugins import FaultTolerancePlugin
import nemo_run as run

task = run.Script(...)
run_plugins = [
    FaultTolerancePlugin(
        enable_ft_package=True,
        calc_ft_timeouts=True,
        num_in_job_restarts=3,
        num_job_retries_on_failure=2,
        initial_rank_heartbeat_timeout=1800,
        rank_heartbeat_timeout=300,
    )
]
run.run(task, plugins=run_plugins, executor=executor)
Paramètre du plugin Défaut Description
num_in_job_restarts 3 Max redémarrages au sein du même job
num_job_retries_on_failure 2 Max lancements de nouveaux jobs en cas d'échec
initial_rank_heartbeat_timeout 1800 Timeout du premier heartbeat (secondes)
rank_heartbeat_timeout 300 Timeout des heartbeats suivants (secondes)

Option 2 : config directe + ft_launcher

from megatron.bridge.training.config import FaultToleranceConfig

cfg.ft = FaultToleranceConfig(
    enable_ft_package=True,
    calc_ft_timeouts=True,
    simulate_fault=False,
    simulated_fault_type="random",
)

Lancez avec ft_launcher (pas torchrun) :

export GROUP_RANK=0  # requis pour non-Slurm
ft_launcher \
    --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
    --ft-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-rank_out_of_section_timeout=300 \
    your_training_script.py
Paramètre de config Défaut Description
enable_ft_package False Activer la tolérance aux pannes
calc_ft_timeouts False Calcul auto des timeouts optimaux
simulate_fault False Activer la simulation de pannes pour les tests
simulated_fault_type "random" "rank_hung", "rank_killed" ou "random"
simulated_fault_rank None Rank spécifique à mettre en panne (aléatoire si None)
simulated_fault_base_delay 0 Délai de base avant simulation de panne

Le suivi des timeouts basé sur des sections couvre setup, étapes d'entraînement, checkpointing et temps hors section indépendamment. Les timeouts sont sauvegardés dans ft_state.json pour les exécutions suivantes quand calc_ft_timeouts=True.

Détection de stragglers NVRx

from megatron.bridge.training.config import NVRxStragglerDetectionConfig

cfg.nvrx_straggler = NVRxStragglerDetectionConfig(
    enabled=True,
    report_time_interval=300.0,
    calc_relative_gpu_perf=True,
    calc_individual_gpu_perf=True,
    num_gpu_perf_scores_to_print=5,
    gpu_relative_perf_threshold=0.7,
    gpu_individual_perf_threshold=0.7,
    stop_if_detected=False,
    enable_logging=True,
)
Paramètre Défaut Description
enabled False Activer la détection de stragglers
report_time_interval 300.0 Secondes entre les vérifications de stragglers
calc_relative_gpu_perf True Comparer les ranks entre eux
calc_individual_gpu_perf True Suivre la dégradation par rank au fil du temps
gpu_relative_perf_threshold 0.7 Seuil de performance relative (0-1)
gpu_individual_perf_threshold 0.7 Seuil de performance individuelle (0-1)
stop_if_detected False Terminer l'entraînement en cas de straggler détecté
num_gpu_perf_scores_to_print 5 Nombre de meilleurs/pires scores à afficher
profiling_interval 1 Intervalle de profilage pour le détecteur

Préemption

Plugin (Slurm)

from megatron.bridge.recipes.run_plugins import PreemptionPlugin

plugins = [
    PreemptionPlugin(
        preempt_time=60,
        enable_exit_handler=True,
        enable_exit_handler_for_data_loader=False,
    )
]
Paramètre du plugin Défaut Description
preempt_time 60 Secondes avant la limite du job pour envoyer le signal
enable_exit_handler True Activer le gestionnaire de signal dans l'entraînement
enable_exit_handler_for_data_loader False Activer pour les workers du dataloader

Config directe

import signal
cfg.train.exit_signal_handler = True
cfg.train.exit_signal = signal.SIGTERM
cfg.train.exit_signal_handler_for_dataloader = False

Machine à états re-run (expérimental)

from megatron.bridge.training.config import RerunStateMachineConfig

cfg.rerun_state_machine = RerunStateMachineConfig(
    rerun_mode="validate_results",
    check_for_nan_in_loss=True,
    check_for_spiky_loss=False,
    spiky_loss_factor=10.0,
)
Paramètre Défaut Description
rerun_mode "disabled" "disabled", "validate_results", "report_determinism_stats"
check_for_nan_in_loss True Vérifier les NaN dans la loss
check_for_spiky_loss False Vérifier les pertes anormalement grandes
spiky_loss_factor 10.0 Loss signalée si > facteur * max observé (augmenter pour les grands modèles)

Codes de sortie : 16 = reprendre pour désambiguïser, 17 = validation échouée.

Redémarrage en processus (expérimental)

from megatron.bridge.training.config import InProcessRestartConfig

cfg.inprocess_restart = InProcessRestartConfig(
    enabled=True,
    granularity="node",
    soft_timeout=60.0,
    hard_timeout=90.0,
)
Paramètre Défaut Description
enabled False Activer le redémarrage en processus
active_world_size None Ranks exécutant la charge (les autres sont des réserves chaudes)
granularity "node" Granularité de redémarrage "node" ou "rank"
max_iterations None Max tentatives de redémarrage (None = illimité)
soft_timeout 60.0 Détecter les blocages avec GIL libéré (secondes)
hard_timeout 90.0 Forcer l'arrêt des ranks bloqués (secondes)
heartbeat_interval 30.0 Intervalle de heartbeat (secondes)
heartbeat_timeout 60.0 Timeout de heartbeat manquant (secondes)
barrier_timeout 120.0 Timeout de barrière distribuée (secondes)
completion_timeout 120.0 Timeout de barrière de fin (secondes)
empty_cuda_cache True Vider le cache CUDA lors du redémarrage
max_rank_faults None Max défaillances de rank avant arrêt
monitor_process_logdir None Répertoire pour les logs du monitor

Variables d'environnement requises :

export TORCH_CPP_LOG_LEVEL=error
export TORCH_NCCL_RETHROW_CUDA_ERRORS=0
export NCCL_NVLS_ENABLE=0

Le timeout du watchdog NCCL de PyTorch doit dépasser hard_timeout. L'Executor Slurm de NeMo-Run n'est pas supporté ; lancez directement avec srun --kill-on-bad-exit=0.

Sauvegarde de checkpoint async

cfg.checkpoint.async_save = True
cfg.checkpoint.ckpt_format = "torch_dist"

Checkpointing local (NVRx)

cfg.checkpoint.non_persistent_local_ckpt_dir = "/local/scratch/ckpt"
cfg.checkpoint.non_persistent_local_ckpt_algo = "fully_parallel"

Ancres de code

Tolérance aux pannes

  • Config : src/megatron/bridge/training/config.pyFaultToleranceConfig
  • Runtime : src/megatron/bridge/training/fault_tolerance.py
  • Plugin : src/megatron/bridge/recipes/run_plugins.pyFaultTolerancePlugin
  • Perf plugin : scripts/performance/nemo-mbridge-resiliency_plugins.py
  • Tests : tests/unit_tests/training/test_fault_tolerance.py
  • Exemple : examples/training_features/nemo-mbridge-resiliency/fault_tolerance/

Détection de stragglers

  • Config : src/megatron/bridge/training/config.pyNVRxStragglerDetectionConfig
  • Runtime : src/megatron/bridge/training/nvrx_straggler.py
  • Boucle d'entraînement : src/megatron/bridge/training/train.pycheck_nvrx_straggler_detection
  • Tests : tests/unit_tests/training/test_nvrx_straggler.py, tests/functional_tests/training/test_nvrx_straggler.py
  • Exemple : examples/training_features/nemo-mbridge-resiliency/straggler_detection/

Redémarrage en processus

  • Config : src/megatron/bridge/training/config.pyInProcessRestartConfig
  • Runtime : src/megatron/bridge/training/inprocess_restart.py
  • Point d'entrée : src/megatron/bridge/training/pretrain.pymaybe_wrap_for_inprocess_restart
  • Tests : tests/unit_tests/training/test_inprocess_restart.py, tests/functional_tests/training/test_inprocess_restart.py

Préemption

  • Plugin : src/megatron/bridge/recipes/run_plugins.pyPreemptionPlugin
  • Gestionnaire de signal : src/megatron/bridge/training/utils/sig_utils.py
  • Tests : tests/unit_tests/recipes/test_run_plugins.py

Machine à états re-run

  • Config : src/megatron/bridge/training/config.pyRerunStateMachineConfig
  • Init : src/megatron/bridge/training/initialize.pyinit_rerun_state

Checkpointing

  • Sauvegarde async : src/megatron/bridge/training/checkpointing.pyschedule_async_save
  • Local ckpt : src/megatron/bridge/training/checkpointing.pyLocalCheckpointManager
  • Tests : tests/functional_tests/training/test_local_checkpointing.py

Pièges

  1. ft_launcher, pas torchrun : FaultToleranceConfig directe nécessite ft_launcher. Utiliser torchrun désactive silencieusement FT. Pour non-Slurm, définir GROUP_RANK=0.

  2. Async save nécessite torch_dist : async_save=True ne fonctionne qu'avec ckpt_format="torch_dist". Les autres formats échouent silencieusement ou génèrent une erreur.

  3. IPR + NeMo-Run : Le redémarrage en processus n'est pas compatible avec NeMo-Run ou les plugins de préemption Slurm. Nécessite des versions spécifiques de PyTorch/NCCL et des variables d'environnement.

  4. NVRx vs straggler legacy : Deux détecteurs existent. Utiliser NVRx (nvrx_straggler) ; ne pas activer les deux.

  5. stop_if_detected par défaut : NVRx enregistre mais n'arrête pas l'entraînement par défaut. Définir stop_if_detected=True pour l'arrêt automatique.

  6. Watchdog NCCL vs hard_timeout : Pour IPR, le timeout du watchdog NCCL doit dépasser hard_timeout ou PyTorch tue le processus avant la récupération.

  7. La machine à états rerun est alpha : Utiliser check_for_nan_in_loss=True pour la détection de NaN, mais ne pas se fier aux workflows complets de rerun pour le moment.

Vérification

Tolérance aux pannes

./examples/training_features/nemo-mbridge-resiliency/fault_tolerance/run_fault_tolerance.sh
./examples/training_features/nemo-mbridge-resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault

Cherchez les lignes de log [FaultTolerance] / [RankMonitorServer] avec les timeouts des sections. La panne simulée doit déclencher un redémarrage à partir du checkpoint.

Détection de stragglers

uv run python -m torch.distributed.run --nproc_per_node=2 \
    examples/training_features/nemo-mbridge-resiliency/straggler_detection/straggler_detection_example.py

Cherchez les rapports GPU relative performance et GPU individual performance avec les scores par rank.

Checkpoint async

Cherchez Scheduling async checkpoint save dans les logs. Les itérations d'entraînement doivent continuer pendant que les fichiers de checkpoint sont écrits.

Redémarrage en processus

pytest tests/functional_tests/training/test_inprocess_restart.py -v

Nécessite des versions PyTorch/NCCL compatibles.

Skills similaires