stable-baselines3

Par mkurman · zorai

Algorithmes d'apprentissage par renforcement prêts pour la production (PPO, SAC, DQN, TD3, DDPG, A2C) avec une API de type scikit-learn. À utiliser pour les expériences RL standard, le prototypage rapide et les implémentations d'algorithmes bien documentées. Idéal pour le RL mono-agent avec des environnements Gymnasium. Pour l'entraînement parallèle haute performance, les systèmes multi-agents ou les environnements vectorisés personnalisés, utilisez plutôt pufferlib.

npx skills add https://github.com/mkurman/zorai --skill stable-baselines3

Stable Baselines3

Aperçu

Stable Baselines3 (SB3) est une bibliothèque basée sur PyTorch fournissant des implémentations fiables d'algorithmes d'apprentissage par renforcement. Cette compétence fournit des conseils complets pour entraîner des agents RL, créer des environnements personnalisés, implémenter des callbacks et optimiser les workflows d'entraînement en utilisant l'API unifiée de SB3.

Capacités principales

1. Entraînement d'agents RL

Pattern d'entraînement basique:

import gymnasium as gym
from stable_baselines3 import PPO

# Create environment
env = gym.make("CartPole-v1")

# Initialize agent
model = PPO("MlpPolicy", env, verbose=1)

# Train the agent
model.learn(total_timesteps=10000)

# Save the model
model.save("ppo_cartpole")

# Load the model (without prior instantiation)
model = PPO.load("ppo_cartpole", env=env)

Notes importantes :

  • total_timesteps est une limite inférieure ; l'entraînement réel peut dépasser cette valeur en raison de la collecte de batch
  • Utilisez model.load() comme méthode statique, pas sur une instance existante
  • Le replay buffer n'est PAS sauvegardé avec le modèle pour économiser de l'espace

Sélection d'algorithme : Consultez references/algorithms.md pour des caractéristiques détaillées des algorithmes et des conseils de sélection. Référence rapide :

  • PPO/A2C : Usage général, supporte tous les types d'espaces d'action, bon pour le multiprocessing
  • SAC/TD3 : Contrôle continu, off-policy, efficace en échantillons
  • DQN : Actions discrètes, off-policy
  • HER : Tâches orientées vers un but

Consultez scripts/train_rl_agent.py pour un modèle d'entraînement complet avec les meilleures pratiques.

2. Environnements personnalisés

Exigences : Les environnements personnalisés doivent hériter de gymnasium.Env et implémenter :

  • __init__() : Définir action_space et observation_space
  • reset(seed, options) : Retourner l'observation initiale et un dictionnaire info
  • step(action) : Retourner observation, reward, terminated, truncated, info
  • render() : Visualisation (optionnel)
  • close() : Nettoyer les ressources

Contraintes clés :

  • Les observations d'image doivent être np.uint8 dans la plage [0, 255]
  • Utilisez le format canaux en premier quand c'est possible (channels, height, width)
  • SB3 normalise automatiquement les images en divisant par 255
  • Définissez normalize_images=False dans policy_kwargs si pré-normalisé
  • SB3 ne supporte PAS les espaces Discrete ou MultiDiscrete avec start!=0

Validation :

from stable_baselines3.common.env_checker import check_env

check_env(env, warn=True)

Consultez scripts/custom_env_template.py pour un modèle d'environnement personnalisé complet et references/custom_environments.md pour des conseils complets.

3. Environnements vectorisés

Objectif : Les environnements vectorisés exécutent plusieurs instances d'environnement en parallèle, accélérant l'entraînement et habilitant certains wrappers (frame-stacking, normalisation).

Types :

  • DummyVecEnv : Exécution séquentielle dans le processus courant (pour les environnements légers)
  • SubprocVecEnv : Exécution parallèle sur plusieurs processus (pour les environnements gourmands en calcul)

Configuration rapide :

from stable_baselines3.common.env_util import make_vec_env

# Create 4 parallel environments
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)

Optimisation Off-Policy : Lorsque vous utilisez plusieurs environnements avec des algorithmes off-policy (SAC, TD3, DQN), définissez gradient_steps=-1 pour effectuer une mise à jour de gradient par étape d'environnement, équilibrant le temps horloge et l'efficacité d'échantillonnage.

Différences d'API :

  • reset() retourne uniquement les observations (info disponible dans vec_env.reset_infos)
  • step() retourne un 4-tuple : (obs, rewards, dones, infos) pas un 5-tuple
  • Les environnements se réinitialisent automatiquement après les épisodes
  • Les observations terminales disponibles via infos[env_idx]["terminal_observation"]

Consultez references/vectorized_envs.md pour des informations détaillées sur les wrappers et l'usage avancé.

4. Callbacks pour la surveillance et le contrôle

Objectif : Les callbacks permettent de surveiller les métriques, de sauvegarder des points de contrôle, d'implémenter l'arrêt anticipé et la logique d'entraînement personnalisée sans modifier les algorithmes principaux.

Callbacks courants :

  • EvalCallback : Évaluer périodiquement et sauvegarder le meilleur modèle
  • CheckpointCallback : Sauvegarder les points de contrôle du modèle à des intervalles
  • StopTrainingOnRewardThreshold : Arrêter lorsque la récompense cible est atteinte
  • ProgressBarCallback : Afficher la progression de l'entraînement avec le chronométrage

Structure de Callback personnalisé :

from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    def _on_training_start(self):
        # Called before first rollout
        pass

    def _on_step(self):
        # Called after each environment step
        # Return False to stop training
        return True

    def _on_rollout_end(self):
        # Called at end of rollout
        pass

Attributs disponibles :

  • self.model : L'instance de l'algorithme RL
  • self.num_timesteps : Total des étapes d'environnement
  • self.training_env : L'environnement d'entraînement

Chaînage de Callbacks :

from stable_baselines3.common.callbacks import CallbackList

callback = CallbackList([eval_callback, checkpoint_callback, custom_callback])
model.learn(total_timesteps=10000, callback=callback)

Consultez references/callbacks.md pour une documentation complète des callbacks.

5. Persistance et inspection du modèle

Sauvegarde et chargement :

# Save model
model.save("model_name")

# Save normalization statistics (if using VecNormalize)
vec_env.save("vec_normalize.pkl")

# Load model
model = PPO.load("model_name", env=env)

# Load normalization statistics
vec_env = VecNormalize.load("vec_normalize.pkl", vec_env)

Accès aux paramètres :

# Get parameters
params = model.get_parameters()

# Set parameters
model.set_parameters(params)

# Access PyTorch state dict
state_dict = model.policy.state_dict()

6. Évaluation et enregistrement

Évaluation :

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=10,
    deterministic=True
)

Enregistrement vidéo :

from stable_baselines3.common.vec_env import VecVideoRecorder

# Wrap environment with video recorder
env = VecVideoRecorder(
    env,
    "videos/",
    record_video_trigger=lambda x: x % 2000 == 0,
    video_length=200
)

Consultez scripts/evaluate_agent.py pour un modèle complet d'évaluation et d'enregistrement.

7. Fonctionnalités avancées

Calendriers de taux d'apprentissage :

def linear_schedule(initial_value):
    def func(progress_remaining):
        # progress_remaining goes from 1 to 0
        return progress_remaining * initial_value
    return func

model = PPO("MlpPolicy", env, learning_rate=linear_schedule(0.001))

Politiques multi-entrées (Observations Dict) :

model = PPO("MultiInputPolicy", env, verbose=1)

À utiliser quand les observations sont des dictionnaires (par ex., combinaison d'images avec des données de capteurs).

Hindsight Experience Replay :

from stable_baselines3 import SAC, HerReplayBuffer

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
)

Intégration TensorBoard :

model = PPO("MlpPolicy", env, tensorboard_log="./tensorboard/")
model.learn(total_timesteps=10000)

Conseils de workflow

Démarrer un nouveau projet RL :

  1. Définir le problème : Identifier l'espace d'observation, l'espace d'action et la structure de récompense
  2. Choisir un algorithme : Utilisez references/algorithms.md pour des conseils de sélection
  3. Créer/adapter l'environnement : Utilisez scripts/custom_env_template.py si nécessaire
  4. Valider l'environnement : Toujours exécuter check_env() avant l'entraînement
  5. Configurer l'entraînement : Utilisez scripts/train_rl_agent.py comme modèle de démarrage
  6. Ajouter la surveillance : Implémenter des callbacks pour l'évaluation et les points de contrôle
  7. Optimiser les performances : Considérer les environnements vectorisés pour la vitesse
  8. Évaluer et itérer : Utilisez scripts/evaluate_agent.py pour l'évaluation

Problèmes courants :

  • Erreurs mémoire : Réduisez buffer_size pour les algorithmes off-policy ou utilisez moins d'environnements parallèles
  • Entraînement lent : Considérez SubprocVecEnv pour les environnements parallèles
  • Entraînement instable : Essayez différents algorithmes, ajustez les hyperparamètres ou vérifiez la mise à l'échelle des récompenses
  • Erreurs d'importation : Assurez-vous que stable_baselines3 est installé : uv pip install stable-baselines3[extra]

Ressources

scripts/

  • train_rl_agent.py : Modèle complet de script d'entraînement avec les meilleures pratiques
  • evaluate_agent.py : Modèle d'évaluation d'agent et d'enregistrement vidéo
  • custom_env_template.py : Modèle d'environnement Gym personnalisé

references/

  • algorithms.md : Guide détaillé de comparaison et de sélection d'algorithmes
  • custom_environments.md : Guide complet de création d'environnements personnalisés
  • callbacks.md : Référence complète du système de callbacks
  • vectorized_envs.md : Usage des environnements vectorisés et wrappers

Installation

# Basic installation
uv pip install stable-baselines3

# With extra dependencies (Tensorboard, etc.)
uv pip install stable-baselines3[extra]

Skills similaires