PyTorch Lightning
Aperçu
PyTorch Lightning est un framework de deep learning qui organise le code PyTorch pour éliminer le boilerplate tout en conservant une flexibilité totale. Automatisez les workflows d'entraînement, l'orchestration multi-dispositifs, et implémentez les bonnes pratiques pour l'entraînement et la mise à l'échelle de réseaux de neurones sur plusieurs GPUs/TPUs.
Quand utiliser cette compétence
Cette compétence doit être utilisée quand :
- Construire, entraîner ou déployer des réseaux de neurones avec PyTorch Lightning
- Organiser le code PyTorch en LightningModules
- Configurer des Trainers pour un entraînement multi-GPU/TPU
- Implémenter des pipelines de données avec LightningDataModules
- Travailler avec des callbacks, logging et stratégies d'entraînement distribué (DDP, FSDP, DeepSpeed)
- Structurer des projets de deep learning de manière professionnelle
Capacités principales
1. LightningModule - Définition du modèle
Organisez les modèles PyTorch en six sections logiques :
- Initialisation -
__init__()etsetup() - Boucle d'entraînement -
training_step(batch, batch_idx) - Boucle de validation -
validation_step(batch, batch_idx) - Boucle de test -
test_step(batch, batch_idx) - Prédiction -
predict_step(batch, batch_idx) - Configuration de l'optimiseur -
configure_optimizers()
Référence rapide de template : Voir scripts/template_lightning_module.py pour un boilerplate complet.
Documentation détaillée : Lire references/lightning_module.md pour la documentation complète des méthodes, hooks, propriétés et bonnes pratiques.
2. Trainer - Automatisation de l'entraînement
Le Trainer automatise la boucle d'entraînement, la gestion des dispositifs, les opérations de gradients et les callbacks. Fonctionnalités clés :
- Support multi-GPU/TPU avec sélection de stratégie (DDP, FSDP, DeepSpeed)
- Entraînement en précision mixte automatique
- Accumulation et clipping de gradients
- Checkpointing et early stopping
- Barres de progression et logging
Référence de configuration rapide : Voir scripts/quick_trainer_setup.py pour les configurations courantes du Trainer.
Documentation détaillée : Lire references/trainer.md pour tous les paramètres, méthodes et options de configuration.
3. LightningDataModule - Organisation du pipeline de données
Encapsulez toutes les étapes de traitement des données dans une classe réutilisable :
prepare_data()- Télécharger et traiter les données (processus unique)setup()- Créer les datasets et appliquer les transformations (par-GPU)train_dataloader()- Retourner le DataLoader d'entraînementval_dataloader()- Retourner le DataLoader de validationtest_dataloader()- Retourner le DataLoader de test
Référence rapide de template : Voir scripts/template_datamodule.py pour un boilerplate complet.
Documentation détaillée : Lire references/data_module.md pour les détails des méthodes et les patterns d'utilisation.
4. Callbacks - Logique d'entraînement extensible
Ajoutez des fonctionnalités personnalisées à des hooks d'entraînement spécifiques sans modifier votre LightningModule. Les callbacks intégrés incluent :
- ModelCheckpoint - Sauvegarder les meilleurs/derniers modèles
- EarlyStopping - Arrêter quand les métriques stagnent
- LearningRateMonitor - Suivre les changements du scheduler de learning rate
- BatchSizeFinder - Déterminer automatiquement la taille de batch optimale
Documentation détaillée : Lire references/callbacks.md pour les callbacks intégrés et la création de callbacks personnalisés.
5. Logging - Suivi des expériences
Intégrez avec plusieurs plateformes de logging :
- TensorBoard (par défaut)
- Weights & Biases (WandbLogger)
- MLflow (MLFlowLogger)
- Neptune (NeptuneLogger)
- Comet (CometLogger)
- CSV (CSVLogger)
Loguez les métriques en utilisant self.log("metric_name", value) dans n'importe quelle méthode de LightningModule.
Documentation détaillée : Lire references/logging.md pour la configuration des loggers et les options de configuration.
6. Entraînement distribué - Mise à l'échelle sur plusieurs dispositifs
Choisissez la bonne stratégie selon la taille du modèle :
- DDP - Pour les modèles <500M paramètres (ResNet, petits transformers)
- FSDP - Pour les modèles 500M+ paramètres (grands transformers, recommandé pour les utilisateurs de Lightning)
- DeepSpeed - Pour les fonctionnalités de pointe et le contrôle granulaire
Configurez avec : Trainer(strategy="ddp", accelerator="gpu", devices=4)
Documentation détaillée : Lire references/distributed_training.md pour la comparaison des stratégies et la configuration.
7. Bonnes pratiques
- Code agnostique au dispositif - Utilisez
self.deviceau lieu de.cuda() - Sauvegarde des hyperparamètres - Utilisez
self.save_hyperparameters()dans__init__() - Logging de métriques - Utilisez
self.log()pour l'agrégation automatique sur les dispositifs - Reproductibilité - Utilisez
seed_everything()etTrainer(deterministic=True) - Débogage - Utilisez
Trainer(fast_dev_run=True)pour tester avec 1 batch
Documentation détaillée : Lire references/best_practices.md pour les patterns courants et les pièges.
Flux de travail rapide
-
Définir le modèle :
class MyModel(L.LightningModule): def __init__(self): super().__init__() self.save_hyperparameters() self.model = YourNetwork() def training_step(self, batch, batch_idx): x, y = batch loss = F.cross_entropy(self.model(x), y) self.log("train_loss", loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters()) -
Préparer les données :
# Option 1 : DataLoaders directs train_loader = DataLoader(train_dataset, batch_size=32) # Option 2 : LightningDataModule (recommandé pour la réutilisabilité) dm = MyDataModule(batch_size=32) -
Entraîner :
trainer = L.Trainer(max_epochs=10, accelerator="gpu", devices=2) trainer.fit(model, train_loader) # ou trainer.fit(model, datamodule=dm)
Ressources
scripts/
Templates Python exécutables pour les patterns courants de PyTorch Lightning :
template_lightning_module.py- Boilerplate complet de LightningModuletemplate_datamodule.py- Boilerplate complet de LightningDataModulequick_trainer_setup.py- Exemples de configuration courante du Trainer
references/
Documentation détaillée pour chaque composant de PyTorch Lightning :
lightning_module.md- Guide complet de LightningModule (méthodes, hooks, propriétés)trainer.md- Configuration et paramètres du Trainerdata_module.md- Patterns et méthodes de LightningDataModulecallbacks.md- Callbacks intégrés et personnaliséslogging.md- Intégrations de loggers et utilisationdistributed_training.md- Comparaison et configuration de DDP, FSDP, DeepSpeedbest_practices.md- Patterns courants, conseils et pièges