pytorch-lightning

Par mkurman · zorai

Framework de deep learning (PyTorch Lightning). Organisez le code PyTorch en LightningModules, configurez des Trainers pour multi-GPU/TPU, implémentez des pipelines de données, callbacks, logging (W&B, TensorBoard), entraînement distribué (DDP, FSDP, DeepSpeed), pour un entraînement de réseaux de neurones à grande échelle.

npx skills add https://github.com/mkurman/zorai --skill pytorch-lightning

PyTorch Lightning

Aperçu

PyTorch Lightning est un framework de deep learning qui organise le code PyTorch pour éliminer le boilerplate tout en conservant une flexibilité totale. Automatisez les workflows d'entraînement, l'orchestration multi-dispositifs, et implémentez les bonnes pratiques pour l'entraînement et la mise à l'échelle de réseaux de neurones sur plusieurs GPUs/TPUs.

Quand utiliser cette compétence

Cette compétence doit être utilisée quand :

  • Construire, entraîner ou déployer des réseaux de neurones avec PyTorch Lightning
  • Organiser le code PyTorch en LightningModules
  • Configurer des Trainers pour un entraînement multi-GPU/TPU
  • Implémenter des pipelines de données avec LightningDataModules
  • Travailler avec des callbacks, logging et stratégies d'entraînement distribué (DDP, FSDP, DeepSpeed)
  • Structurer des projets de deep learning de manière professionnelle

Capacités principales

1. LightningModule - Définition du modèle

Organisez les modèles PyTorch en six sections logiques :

  1. Initialisation - __init__() et setup()
  2. Boucle d'entraînement - training_step(batch, batch_idx)
  3. Boucle de validation - validation_step(batch, batch_idx)
  4. Boucle de test - test_step(batch, batch_idx)
  5. Prédiction - predict_step(batch, batch_idx)
  6. Configuration de l'optimiseur - configure_optimizers()

Référence rapide de template : Voir scripts/template_lightning_module.py pour un boilerplate complet.

Documentation détaillée : Lire references/lightning_module.md pour la documentation complète des méthodes, hooks, propriétés et bonnes pratiques.

2. Trainer - Automatisation de l'entraînement

Le Trainer automatise la boucle d'entraînement, la gestion des dispositifs, les opérations de gradients et les callbacks. Fonctionnalités clés :

  • Support multi-GPU/TPU avec sélection de stratégie (DDP, FSDP, DeepSpeed)
  • Entraînement en précision mixte automatique
  • Accumulation et clipping de gradients
  • Checkpointing et early stopping
  • Barres de progression et logging

Référence de configuration rapide : Voir scripts/quick_trainer_setup.py pour les configurations courantes du Trainer.

Documentation détaillée : Lire references/trainer.md pour tous les paramètres, méthodes et options de configuration.

3. LightningDataModule - Organisation du pipeline de données

Encapsulez toutes les étapes de traitement des données dans une classe réutilisable :

  1. prepare_data() - Télécharger et traiter les données (processus unique)
  2. setup() - Créer les datasets et appliquer les transformations (par-GPU)
  3. train_dataloader() - Retourner le DataLoader d'entraînement
  4. val_dataloader() - Retourner le DataLoader de validation
  5. test_dataloader() - Retourner le DataLoader de test

Référence rapide de template : Voir scripts/template_datamodule.py pour un boilerplate complet.

Documentation détaillée : Lire references/data_module.md pour les détails des méthodes et les patterns d'utilisation.

4. Callbacks - Logique d'entraînement extensible

Ajoutez des fonctionnalités personnalisées à des hooks d'entraînement spécifiques sans modifier votre LightningModule. Les callbacks intégrés incluent :

  • ModelCheckpoint - Sauvegarder les meilleurs/derniers modèles
  • EarlyStopping - Arrêter quand les métriques stagnent
  • LearningRateMonitor - Suivre les changements du scheduler de learning rate
  • BatchSizeFinder - Déterminer automatiquement la taille de batch optimale

Documentation détaillée : Lire references/callbacks.md pour les callbacks intégrés et la création de callbacks personnalisés.

5. Logging - Suivi des expériences

Intégrez avec plusieurs plateformes de logging :

  • TensorBoard (par défaut)
  • Weights & Biases (WandbLogger)
  • MLflow (MLFlowLogger)
  • Neptune (NeptuneLogger)
  • Comet (CometLogger)
  • CSV (CSVLogger)

Loguez les métriques en utilisant self.log("metric_name", value) dans n'importe quelle méthode de LightningModule.

Documentation détaillée : Lire references/logging.md pour la configuration des loggers et les options de configuration.

6. Entraînement distribué - Mise à l'échelle sur plusieurs dispositifs

Choisissez la bonne stratégie selon la taille du modèle :

  • DDP - Pour les modèles <500M paramètres (ResNet, petits transformers)
  • FSDP - Pour les modèles 500M+ paramètres (grands transformers, recommandé pour les utilisateurs de Lightning)
  • DeepSpeed - Pour les fonctionnalités de pointe et le contrôle granulaire

Configurez avec : Trainer(strategy="ddp", accelerator="gpu", devices=4)

Documentation détaillée : Lire references/distributed_training.md pour la comparaison des stratégies et la configuration.

7. Bonnes pratiques

  • Code agnostique au dispositif - Utilisez self.device au lieu de .cuda()
  • Sauvegarde des hyperparamètres - Utilisez self.save_hyperparameters() dans __init__()
  • Logging de métriques - Utilisez self.log() pour l'agrégation automatique sur les dispositifs
  • Reproductibilité - Utilisez seed_everything() et Trainer(deterministic=True)
  • Débogage - Utilisez Trainer(fast_dev_run=True) pour tester avec 1 batch

Documentation détaillée : Lire references/best_practices.md pour les patterns courants et les pièges.

Flux de travail rapide

  1. Définir le modèle :

    class MyModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.save_hyperparameters()
            self.model = YourNetwork()
    
        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.model(x), y)
            self.log("train_loss", loss)
            return loss
    
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())
  2. Préparer les données :

    # Option 1 : DataLoaders directs
    train_loader = DataLoader(train_dataset, batch_size=32)
    
    # Option 2 : LightningDataModule (recommandé pour la réutilisabilité)
    dm = MyDataModule(batch_size=32)
  3. Entraîner :

    trainer = L.Trainer(max_epochs=10, accelerator="gpu", devices=2)
    trainer.fit(model, train_loader)  # ou trainer.fit(model, datamodule=dm)

Ressources

scripts/

Templates Python exécutables pour les patterns courants de PyTorch Lightning :

  • template_lightning_module.py - Boilerplate complet de LightningModule
  • template_datamodule.py - Boilerplate complet de LightningDataModule
  • quick_trainer_setup.py - Exemples de configuration courante du Trainer

references/

Documentation détaillée pour chaque composant de PyTorch Lightning :

  • lightning_module.md - Guide complet de LightningModule (méthodes, hooks, propriétés)
  • trainer.md - Configuration et paramètres du Trainer
  • data_module.md - Patterns et méthodes de LightningDataModule
  • callbacks.md - Callbacks intégrés et personnalisés
  • logging.md - Intégrations de loggers et utilisation
  • distributed_training.md - Comparaison et configuration de DDP, FSDP, DeepSpeed
  • best_practices.md - Patterns courants, conseils et pièges

Skills similaires