Overview
Captum (Latin for "comprehension") is a model interpretability library for PyTorch. It implements attribution methods including Integrated Gradients, Gradient SHAP, DeepLIFT, Occlusion, Feature Ablation, and Layer Conductance, and it supports computer vision, NLP, and tabular models.
Installation
uv pip install captum
Integrated Gradients
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients
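# Toy model: a single linear layer mapping 10 features to 2 output classes.
model = nn.Linear(10, 2)
input = torch.randn(1, 10)
# All-zeros baseline: the reference point where the integration path starts.
baseline = torch.zeros(1, 10)
ig = IntegratedGradients(model)
# Attribute the class-0 output to each of the 10 input features.
attrs = ig.attribute(input, baselines=baseline, target=0)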
print(f"Feature attributions: {attrs}")
Occlusion
from captum.attr import Occlusion
occ = Occlusion(model)
# Occlude one feature at a time; the window shape excludes the batch dimension.
attrs = occ.attribute(input, target=0, sliding_window_shapes=(1,))
print(attrs)
Visualization
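from captum.attr import visualization as viz
# visualize_image_attr expects (H, W, C) NumPy arrays, so reshape the
# (1, 10) attributions into a 1 x 10 single-channel "image" before plotting.
attr_img = attrs.detach().numpy().reshape(1, -1, 1)
# The "heat_map" method does not need the original image.
_ = viz.visualize_image_attr(
    attr_img,
    method="heat_map",
    sign="absolute_value",
    show_colorbar=True,
)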