pymc

Par mkurman · zorai

Modélisation bayésienne avec PyMC. Construisez des modèles hiérarchiques, MCMC (NUTS), inférence variationnelle, comparaison LOO/WAIC, vérifications a posteriori, pour la programmation probabiliste et l'inférence.

npx skills add https://github.com/mkurman/zorai --skill pymc

Modélisation Bayésienne avec PyMC

Aperçu

PyMC est une bibliothèque Python pour la modélisation Bayésienne et la programmation probabiliste. Construisez, ajustez, validez et comparez des modèles Bayésiens en utilisant l'API moderne de PyMC (version 5.x+), incluant les modèles hiérarchiques, l'échantillonnage MCMC (NUTS), l'inférence variationnelle et la comparaison de modèles (LOO, WAIC).

Quand utiliser cette compétence

Cette compétence doit être utilisée quand :

  • Construire des modèles Bayésiens (régression linéaire/logistique, modèles hiérarchiques, séries temporelles, etc.)
  • Effectuer l'échantillonnage MCMC ou l'inférence variationnelle
  • Conduire des vérifications prédictives a priori/a posteriori
  • Diagnostiquer les problèmes d'échantillonnage (divergences, convergence, ESS)
  • Comparer plusieurs modèles en utilisant les critères d'information (LOO, WAIC)
  • Implémenter la quantification de l'incertitude par des méthodes Bayésienne
  • Travailler avec des structures de données hiérarchiques/multiniveaux
  • Gérer les données manquantes ou l'erreur de mesure de manière rigoureuse

Flux de travail Bayésien standard

Suivez ce flux pour construire et valider des modèles Bayésiens :

1. Préparation des données

import pymc as pm
import arviz as az
import numpy as np

# Charger et préparer les données
X = ...  # Prédicteurs
y = ...  # Résultats

# Standardiser les prédicteurs pour un meilleur échantillonnage
X_mean = X.mean(axis=0)
X_std = X.std(axis=0)
X_scaled = (X - X_mean) / X_std

Bonnes pratiques clés :

  • Standardiser les prédicteurs continus (améliore l'efficacité d'échantillonnage)
  • Centrer les résultats quand possible
  • Gérer explicitement les données manquantes (traiter comme paramètres)
  • Utiliser des dimensions nommées avec coords pour plus de clarté

2. Construction du modèle

coords = {
    'predictors': ['var1', 'var2', 'var3'],
    'obs_id': np.arange(len(y))
}

with pm.Model(coords=coords) as model:
    # Priors
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    beta = pm.Normal('beta', mu=0, sigma=1, dims='predictors')
    sigma = pm.HalfNormal('sigma', sigma=1)

    # Prédicteur linéaire
    mu = alpha + pm.math.dot(X_scaled, beta)

    # Vraisemblance
    y_obs = pm.Normal('y_obs', mu=mu, sigma=sigma, observed=y, dims='obs_id')

Bonnes pratiques clés :

  • Utiliser des priors faiblement informatifs (pas de priors plats)
  • Utiliser HalfNormal ou Exponential pour les paramètres d'échelle
  • Utiliser des dimensions nommées (dims) plutôt que shape quand possible
  • Utiliser pm.Data() pour les valeurs qui seront mises à jour pour les prédictions

3. Vérification prédictive a priori

Validez toujours les priors avant d'ajuster :

with model:
    prior_pred = pm.sample_prior_predictive(samples=1000, random_seed=42)

# Visualiser
az.plot_ppc(prior_pred, group='prior')

Vérifier :

  • Les prédictions a priori couvrent-elles des valeurs raisonnables ?
  • Les valeurs extrêmes sont-elles plausibles compte tenu des connaissances du domaine ?
  • Si les priors génèrent des données implausibles, ajuster et re-vérifier

4. Ajustement du modèle

with model:
    # Optionnel : exploration rapide avec ADVI
    # approx = pm.fit(n=20000)

    # Inférence MCMC complète
    idata = pm.sample(
        draws=2000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=42,
        idata_kwargs={'log_likelihood': True}  # Pour la comparaison de modèles
    )

Paramètres clés :

  • draws=2000 : Nombre d'échantillons par chaîne
  • tune=1000 : Échantillons de préchauffage (jetés)
  • chains=4 : Exécuter 4 chaînes pour vérifier la convergence
  • target_accept=0.9 : Plus élevé pour les posteriors difficiles (0,95-0,99)
  • Inclure log_likelihood=True pour la comparaison de modèles

5. Vérifier les diagnostics

Utiliser le script de diagnostic :

from scripts.model_diagnostics import check_diagnostics

results = check_diagnostics(idata, var_names=['alpha', 'beta', 'sigma'])

Vérifier :

  • R-hat < 1,01 : Les chaînes ont convergé
  • ESS > 400 : Suffisamment d'échantillons efficaces
  • Pas de divergences : NUTS a échantillonné avec succès
  • Graphiques de trace : Les chaînes doivent bien se mélanger (chenille floue)

Si des problèmes surviennent :

  • Divergences → Augmenter target_accept=0.95, utiliser une paramétrisation non-centrée
  • ESS faible → Échantillonner plus, reparamétrer pour réduire la corrélation
  • R-hat élevé → Exécuter plus longtemps, vérifier la multimodalité

6. Vérification prédictive a posteriori

Validez l'ajustement du modèle :

with model:
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=42)

# Visualiser
az.plot_ppc(idata)

Vérifier :

  • Les prédictions a posteriori capturent-elles les motifs des données observées ?
  • Y a-t-il des écarts systématiques évidents (mauvaise spécification du modèle) ?
  • Considérer des modèles alternatifs si l'ajustement est mauvais

7. Analyser les résultats

# Statistiques récapitulatives
print(az.summary(idata, var_names=['alpha', 'beta', 'sigma']))

# Distributions a posteriori
az.plot_posterior(idata, var_names=['alpha', 'beta', 'sigma'])

# Estimations des coefficients
az.plot_forest(idata, var_names=['beta'], combined=True)

8. Faire des prédictions

X_new = ...  # Nouvelles valeurs de prédicteurs
X_new_scaled = (X_new - X_mean) / X_std

with model:
    pm.set_data({'X_scaled': X_new_scaled})
    post_pred = pm.sample_posterior_predictive(
        idata.posterior,
        var_names=['y_obs'],
        random_seed=42
    )

# Extraire les intervalles de prédiction
y_pred_mean = post_pred.posterior_predictive['y_obs'].mean(dim=['chain', 'draw'])
y_pred_hdi = az.hdi(post_pred.posterior_predictive, var_names=['y_obs'])

Motifs de modèles courants

Régression linéaire

Pour les résultats continus avec des relations linéaires :

with pm.Model() as linear_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)
    sigma = pm.HalfNormal('sigma', sigma=1)

    mu = alpha + pm.math.dot(X, beta)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)

Utiliser le modèle : assets/linear_regression_template.py

Régression logistique

Pour les résultats binaires :

with pm.Model() as logistic_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    logit_p = alpha + pm.math.dot(X, beta)
    y = pm.Bernoulli('y', logit_p=logit_p, observed=y_obs)

Modèles hiérarchiques

Pour les données groupées (utiliser la paramétrisation non-centrée) :

with pm.Model(coords={'groups': group_names}) as hierarchical_model:
    # Hyperpriors
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=10)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=1)

    # Niveau groupe (non-centré)
    alpha_offset = pm.Normal('alpha_offset', mu=0, sigma=1, dims='groups')
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_offset, dims='groups')

    # Niveau observation
    mu = alpha[group_idx]
    sigma = pm.HalfNormal('sigma', sigma=1)
    y = pm.Normal('y', mu=mu, sigma=sigma, observed=y_obs)

Utiliser le modèle : assets/hierarchical_model_template.py

Critique : Utilisez toujours une paramétrisation non-centrée pour les modèles hiérarchiques pour éviter les divergences.

Régression Poisson

Pour les données de comptage :

with pm.Model() as poisson_model:
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10, shape=n_predictors)

    log_lambda = alpha + pm.math.dot(X, beta)
    y = pm.Poisson('y', mu=pm.math.exp(log_lambda), observed=y_obs)

Pour les comptages surdispersés, utilisez NegativeBinomial à la place.

Séries temporelles

Pour les processus autorégressifs :

with pm.Model() as ar_model:
    sigma = pm.HalfNormal('sigma', sigma=1)
    rho = pm.Normal('rho', mu=0, sigma=0.5, shape=ar_order)
    init_dist = pm.Normal.dist(mu=0, sigma=sigma)

    y = pm.AR('y', rho=rho, sigma=sigma, init_dist=init_dist, observed=y_obs)

Comparaison de modèles

Comparer les modèles

Utiliser LOO ou WAIC pour la comparaison de modèles :

from scripts.model_comparison import compare_models, check_loo_reliability

# Ajuster les modèles avec log_likelihood
models = {
    'Model1': idata1,
    'Model2': idata2,
    'Model3': idata3
}

# Comparer avec LOO
comparison = compare_models(models, ic='loo')

# Vérifier la fiabilité
check_loo_reliability(models)

Interprétation :

  • Δloo < 2 : Les modèles sont similaires, choisir le modèle plus simple
  • 2 < Δloo < 4 : Faible preuve d'un meilleur modèle
  • 4 < Δloo < 10 : Preuve modérée
  • Δloo > 10 : Forte preuve pour un meilleur modèle

Vérifier les valeurs Pareto-k :

  • k < 0,7 : LOO fiable
  • k > 0,7 : Considérer WAIC ou k-fold CV

Moyenne de modèles

Quand les modèles sont similaires, faire la moyenne des prédictions :

from scripts.model_comparison import model_averaging

averaged_pred, weights = model_averaging(models, var_name='y_obs')

Guide de sélection des distributions

Pour les priors

Paramètres d'échelle (σ, τ) :

  • pm.HalfNormal('sigma', sigma=1) - Choix par défaut
  • pm.Exponential('sigma', lam=1) - Alternative
  • pm.Gamma('sigma', alpha=2, beta=1) - Plus informatif

Paramètres non-bornés :

  • pm.Normal('theta', mu=0, sigma=1) - Pour les données standardisées
  • pm.StudentT('theta', nu=3, mu=0, sigma=1) - Robuste aux valeurs aberrantes

Paramètres positifs :

  • pm.LogNormal('theta', mu=0, sigma=1)
  • pm.Gamma('theta', alpha=2, beta=1)

Probabilités :

  • pm.Beta('p', alpha=2, beta=2) - Faiblement informatif
  • pm.Uniform('p', lower=0, upper=1) - Non-informatif (utiliser avec parcimonie)

Matrices de corrélation :

  • pm.LKJCorr('corr', n=n_vars, eta=2) - eta=1 uniforme, eta>1 préfère l'identité

Pour les vraisemblances

Résultats continus :

  • pm.Normal('y', mu=mu, sigma=sigma) - Par défaut pour les données continues
  • pm.StudentT('y', nu=nu, mu=mu, sigma=sigma) - Robuste aux valeurs aberrantes

Données de comptage :

  • pm.Poisson('y', mu=lambda) - Comptages équidispersés
  • pm.NegativeBinomial('y', mu=mu, alpha=alpha) - Comptages surdispersés
  • pm.ZeroInflatedPoisson('y', psi=psi, mu=mu) - Zéros en excès

Résultats binaires :

  • pm.Bernoulli('y', p=p) ou pm.Bernoulli('y', logit_p=logit_p)

Résultats catégoriques :

  • pm.Categorical('y', p=probs)

Voir : references/distributions.md pour une référence complète des distributions

Échantillonnage et inférence

MCMC avec NUTS

Par défaut et recommandé pour la plupart des modèles :

idata = pm.sample(
    draws=2000,
    tune=1000,
    chains=4,
    target_accept=0.9,
    random_seed=42
)

Ajuster si nécessaire :

  • Divergences → target_accept=0.95 ou plus
  • Échantillonnage lent → Utiliser ADVI pour l'initialisation
  • Paramètres discrets → Utiliser pm.Metropolis() pour les variables discrètes

Inférence variationnelle

Approximation rapide pour l'exploration ou l'initialisation :

with model:
    approx = pm.fit(n=20000, method='advi')

    # Utiliser pour l'initialisation
    start = approx.sample(return_inferencedata=False)[0]
    idata = pm.sample(start=start)

Compromis :

  • Beaucoup plus rapide que MCMC
  • Approximatif (peut sous-estimer l'incertitude)
  • Bon pour les grands modèles ou l'exploration rapide

Voir : references/sampling_inference.md pour un guide détaillé de l'échantillonnage

Scripts de diagnostic

Diagnostics complets

from scripts.model_diagnostics import create_diagnostic_report

create_diagnostic_report(
    idata,
    var_names=['alpha', 'beta', 'sigma'],
    output_dir='diagnostics/'
)

Crée :

  • Graphiques de trace
  • Graphiques de rang (vérification de mélange)
  • Graphiques d'autocorrélation
  • Graphiques d'énergie
  • Évolution de l'ESS
  • Statistiques récapitulatives en CSV

Vérification diagnostic rapide

from scripts.model_diagnostics import check_diagnostics

results = check_diagnostics(idata)

Vérifie R-hat, ESS, divergences et profondeur d'arbre.

Problèmes courants et solutions

Divergences

Symptôme : idata.sample_stats.diverging.sum() > 0

Solutions :

  1. Augmenter target_accept=0.95 ou 0.99
  2. Utiliser la paramétrisation non-centrée (modèles hiérarchiques)
  3. Ajouter des priors plus forts pour contraindre les paramètres
  4. Vérifier la mauvaise spécification du modèle

Faible taille d'échantillon efficace

Symptôme : ESS < 400

Solutions :

  1. Échantillonner plus : draws=5000
  2. Reparamétrer pour réduire la corrélation a posteriori
  3. Utiliser la décomposition QR pour la régression avec prédicteurs corrélés

R-hat élevé

Symptôme : R-hat > 1,01

Solutions :

  1. Exécuter des chaînes plus longues : tune=2000, draws=5000
  2. Vérifier la multimodalité
  3. Améliorer l'initialisation avec ADVI

Échantillonnage lent

Solutions :

  1. Utiliser l'initialisation ADVI
  2. Réduire la complexité du modèle
  3. Augmenter la parallélisation : cores=8, chains=8
  4. Utiliser l'inférence variationnelle si appropriée

Bonnes pratiques

Construction du modèle

  1. Toujours standardiser les prédicteurs pour un meilleur échantillonnage
  2. Utiliser des priors faiblement informatifs (pas plats)
  3. Utiliser des dimensions nommées (dims) pour la clarté
  4. Paramétrisation non-centrée pour les modèles hiérarchiques
  5. Vérifier les prédictions a priori avant d'ajuster

Échantillonnage

  1. Exécuter plusieurs chaînes (au minimum 4) pour la convergence
  2. Utiliser target_accept=0.9 comme ligne de base (plus élevé si nécessaire)
  3. Inclure log_likelihood=True pour la comparaison de modèles
  4. Définir la graine aléatoire pour la reproductibilité

Validation

  1. Vérifier les diagnostics avant l'interprétation (R-hat, ESS, divergences)
  2. Vérification prédictive a posteriori pour la validation du modèle
  3. Comparer plusieurs modèles quand approprié
  4. Signaler l'incertitude (intervalles HDI, pas seulement les estimations ponctuelles)

Flux de travail

  1. Commencer simple, ajouter la complexité graduellement
  2. Vérification prédictive a priori → Ajustement → Diagnostics → Vérification prédictive a posteriori
  3. Itérer sur la spécification du modèle en fonction des vérifications
  4. Documenter les hypothèses et les choix de priors

Ressources

Cette compétence inclut :

Références (references/)

  • distributions.md : Catalogue complet des distributions PyMC organisé par catégorie (continues, discrètes, multivariées, mélanges, séries temporelles). Utiliser pour sélectionner les priors ou vraisemblances.

  • sampling_inference.md : Guide détaillé des algorithmes d'échantillonnage (NUTS, Metropolis, SMC), inférence variationnelle (ADVI, SVGD) et gestion des problèmes d'échantillonnage. Utiliser en cas de problèmes de convergence ou pour choisir les méthodes d'inférence.

  • workflows.md : Exemples complets de flux de travail et motifs de code pour les types de modèles courants, préparation des données, sélection des priors et validation du modèle. Utiliser comme livre de recettes pour les analyses Bayésienne standard.

Scripts (scripts/)

  • model_diagnostics.py : Vérification diagnostique automatisée et génération de rapports. Fonctions : check_diagnostics() pour les vérifications rapides, create_diagnostic_report() pour l'analyse complète avec graphiques.

  • model_comparison.py : Utilitaires de comparaison de modèles utilisant LOO/WAIC. Fonctions : compare_models(), check_loo_reliability(), model_averaging().

Modèles (assets/)

  • linear_regression_template.py : Modèle complet pour la régression Bayésienne linéaire avec flux de travail complet (préparation des données, vérifications prédictives a priori, ajustement, diagnostics, prédictions).

  • hierarchical_model_template.py : Modèle complet pour les modèles hiérarchiques/multiniveaux avec paramétrisation non-centrée et analyse au niveau des groupes.

Référence rapide

Construction du modèle

with pm.Model(coords={'var': names}) as model:
    # Priors
    param = pm.Normal('param', mu=0, sigma=1, dims='var')
    # Vraisemblance
    y = pm.Normal('y', mu=..., sigma=..., observed=data)

Échantillonnage

idata = pm.sample(draws=2000, tune=1000, chains=4, target_accept=0.9)

Diagnostics

from scripts.model_diagnostics import check_diagnostics
check_diagnostics(idata)

Comparaison de modèles

from scripts.model_comparison import compare_models
compare_models({'m1': idata1, 'm2': idata2}, ic='loo')

Prédictions

with model:
    pm.set_data({'X': X_new})
    pred = pm.sample_posterior_predictive(idata.posterior)

Notes supplémentaires

  • PyMC s'intègre avec ArviZ pour la visualisation et les diagnostics
  • Utiliser pm.model_to_graphviz(model) pour visualiser la structure du modèle
  • Sauvegarder les résultats avec idata.to_netcdf('results.nc')
  • Charger avec az.from_netcdf('results.nc')
  • Pour très grands modèles, considérer minibatch ADVI ou sous-échantillonnage des données

Skills similaires