Stable Baselines3
Aperçu
Stable Baselines3 (SB3) est une bibliothèque basée sur PyTorch fournissant des implémentations fiables d'algorithmes d'apprentissage par renforcement. Cette compétence fournit des conseils complets pour entraîner des agents RL, créer des environnements personnalisés, implémenter des callbacks et optimiser les workflows d'entraînement en utilisant l'API unifiée de SB3.
Capacités principales
1. Entraînement d'agents RL
Pattern d'entraînement basique:
import gymnasium as gym
from stable_baselines3 import PPO
# Create environment
env = gym.make("CartPole-v1")
# Initialize agent
model = PPO("MlpPolicy", env, verbose=1)
# Train the agent
model.learn(total_timesteps=10000)
# Save the model
model.save("ppo_cartpole")
# Load the model (without prior instantiation)
model = PPO.load("ppo_cartpole", env=env)
Notes importantes :
total_timestepsest une limite inférieure ; l'entraînement réel peut dépasser cette valeur en raison de la collecte de batch- Utilisez
model.load()comme méthode statique, pas sur une instance existante - Le replay buffer n'est PAS sauvegardé avec le modèle pour économiser de l'espace
Sélection d'algorithme :
Consultez references/algorithms.md pour des caractéristiques détaillées des algorithmes et des conseils de sélection. Référence rapide :
- PPO/A2C : Usage général, supporte tous les types d'espaces d'action, bon pour le multiprocessing
- SAC/TD3 : Contrôle continu, off-policy, efficace en échantillons
- DQN : Actions discrètes, off-policy
- HER : Tâches orientées vers un but
Consultez scripts/train_rl_agent.py pour un modèle d'entraînement complet avec les meilleures pratiques.
2. Environnements personnalisés
Exigences :
Les environnements personnalisés doivent hériter de gymnasium.Env et implémenter :
__init__(): Définir action_space et observation_spacereset(seed, options): Retourner l'observation initiale et un dictionnaire infostep(action): Retourner observation, reward, terminated, truncated, inforender(): Visualisation (optionnel)close(): Nettoyer les ressources
Contraintes clés :
- Les observations d'image doivent être
np.uint8dans la plage [0, 255] - Utilisez le format canaux en premier quand c'est possible (channels, height, width)
- SB3 normalise automatiquement les images en divisant par 255
- Définissez
normalize_images=Falsedans policy_kwargs si pré-normalisé - SB3 ne supporte PAS les espaces
DiscreteouMultiDiscreteavecstart!=0
Validation :
from stable_baselines3.common.env_checker import check_env
check_env(env, warn=True)
Consultez scripts/custom_env_template.py pour un modèle d'environnement personnalisé complet et references/custom_environments.md pour des conseils complets.
3. Environnements vectorisés
Objectif : Les environnements vectorisés exécutent plusieurs instances d'environnement en parallèle, accélérant l'entraînement et habilitant certains wrappers (frame-stacking, normalisation).
Types :
- DummyVecEnv : Exécution séquentielle dans le processus courant (pour les environnements légers)
- SubprocVecEnv : Exécution parallèle sur plusieurs processus (pour les environnements gourmands en calcul)
Configuration rapide :
from stable_baselines3.common.env_util import make_vec_env
# Create 4 parallel environments
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
Optimisation Off-Policy :
Lorsque vous utilisez plusieurs environnements avec des algorithmes off-policy (SAC, TD3, DQN), définissez gradient_steps=-1 pour effectuer une mise à jour de gradient par étape d'environnement, équilibrant le temps horloge et l'efficacité d'échantillonnage.
Différences d'API :
reset()retourne uniquement les observations (info disponible dansvec_env.reset_infos)step()retourne un 4-tuple :(obs, rewards, dones, infos)pas un 5-tuple- Les environnements se réinitialisent automatiquement après les épisodes
- Les observations terminales disponibles via
infos[env_idx]["terminal_observation"]
Consultez references/vectorized_envs.md pour des informations détaillées sur les wrappers et l'usage avancé.
4. Callbacks pour la surveillance et le contrôle
Objectif : Les callbacks permettent de surveiller les métriques, de sauvegarder des points de contrôle, d'implémenter l'arrêt anticipé et la logique d'entraînement personnalisée sans modifier les algorithmes principaux.
Callbacks courants :
- EvalCallback : Évaluer périodiquement et sauvegarder le meilleur modèle
- CheckpointCallback : Sauvegarder les points de contrôle du modèle à des intervalles
- StopTrainingOnRewardThreshold : Arrêter lorsque la récompense cible est atteinte
- ProgressBarCallback : Afficher la progression de l'entraînement avec le chronométrage
Structure de Callback personnalisé :
from stable_baselines3.common.callbacks import BaseCallback
class CustomCallback(BaseCallback):
def _on_training_start(self):
# Called before first rollout
pass
def _on_step(self):
# Called after each environment step
# Return False to stop training
return True
def _on_rollout_end(self):
# Called at end of rollout
pass
Attributs disponibles :
self.model: L'instance de l'algorithme RLself.num_timesteps: Total des étapes d'environnementself.training_env: L'environnement d'entraînement
Chaînage de Callbacks :
from stable_baselines3.common.callbacks import CallbackList
callback = CallbackList([eval_callback, checkpoint_callback, custom_callback])
model.learn(total_timesteps=10000, callback=callback)
Consultez references/callbacks.md pour une documentation complète des callbacks.
5. Persistance et inspection du modèle
Sauvegarde et chargement :
# Save model
model.save("model_name")
# Save normalization statistics (if using VecNormalize)
vec_env.save("vec_normalize.pkl")
# Load model
model = PPO.load("model_name", env=env)
# Load normalization statistics
vec_env = VecNormalize.load("vec_normalize.pkl", vec_env)
Accès aux paramètres :
# Get parameters
params = model.get_parameters()
# Set parameters
model.set_parameters(params)
# Access PyTorch state dict
state_dict = model.policy.state_dict()
6. Évaluation et enregistrement
Évaluation :
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(
model,
env,
n_eval_episodes=10,
deterministic=True
)
Enregistrement vidéo :
from stable_baselines3.common.vec_env import VecVideoRecorder
# Wrap environment with video recorder
env = VecVideoRecorder(
env,
"videos/",
record_video_trigger=lambda x: x % 2000 == 0,
video_length=200
)
Consultez scripts/evaluate_agent.py pour un modèle complet d'évaluation et d'enregistrement.
7. Fonctionnalités avancées
Calendriers de taux d'apprentissage :
def linear_schedule(initial_value):
def func(progress_remaining):
# progress_remaining goes from 1 to 0
return progress_remaining * initial_value
return func
model = PPO("MlpPolicy", env, learning_rate=linear_schedule(0.001))
Politiques multi-entrées (Observations Dict) :
model = PPO("MultiInputPolicy", env, verbose=1)
À utiliser quand les observations sont des dictionnaires (par ex., combinaison d'images avec des données de capteurs).
Hindsight Experience Replay :
from stable_baselines3 import SAC, HerReplayBuffer
model = SAC(
"MultiInputPolicy",
env,
replay_buffer_class=HerReplayBuffer,
replay_buffer_kwargs=dict(
n_sampled_goal=4,
goal_selection_strategy="future",
),
)
Intégration TensorBoard :
model = PPO("MlpPolicy", env, tensorboard_log="./tensorboard/")
model.learn(total_timesteps=10000)
Conseils de workflow
Démarrer un nouveau projet RL :
- Définir le problème : Identifier l'espace d'observation, l'espace d'action et la structure de récompense
- Choisir un algorithme : Utilisez
references/algorithms.mdpour des conseils de sélection - Créer/adapter l'environnement : Utilisez
scripts/custom_env_template.pysi nécessaire - Valider l'environnement : Toujours exécuter
check_env()avant l'entraînement - Configurer l'entraînement : Utilisez
scripts/train_rl_agent.pycomme modèle de démarrage - Ajouter la surveillance : Implémenter des callbacks pour l'évaluation et les points de contrôle
- Optimiser les performances : Considérer les environnements vectorisés pour la vitesse
- Évaluer et itérer : Utilisez
scripts/evaluate_agent.pypour l'évaluation
Problèmes courants :
- Erreurs mémoire : Réduisez
buffer_sizepour les algorithmes off-policy ou utilisez moins d'environnements parallèles - Entraînement lent : Considérez SubprocVecEnv pour les environnements parallèles
- Entraînement instable : Essayez différents algorithmes, ajustez les hyperparamètres ou vérifiez la mise à l'échelle des récompenses
- Erreurs d'importation : Assurez-vous que
stable_baselines3est installé :uv pip install stable-baselines3[extra]
Ressources
scripts/
train_rl_agent.py: Modèle complet de script d'entraînement avec les meilleures pratiquesevaluate_agent.py: Modèle d'évaluation d'agent et d'enregistrement vidéocustom_env_template.py: Modèle d'environnement Gym personnalisé
references/
algorithms.md: Guide détaillé de comparaison et de sélection d'algorithmescustom_environments.md: Guide complet de création d'environnements personnaliséscallbacks.md: Référence complète du système de callbacksvectorized_envs.md: Usage des environnements vectorisés et wrappers
Installation
# Basic installation
uv pip install stable-baselines3
# With extra dependencies (Tensorboard, etc.)
uv pip install stable-baselines3[extra]