captum

Par mkurman · zorai

Captum (PyTorch) — interprétabilité des modèles et attribution de features. Integrated Gradients, DeepLIFT, SmoothGrad, Occlusion, approximation SHAP et Layer-wise Relevance Propagation. Pour les modèles vision et texte.

npx skills add https://github.com/mkurman/zorai --skill captum

Aperçu

Captum (Comprehension in PyTorch) fournit l'interprétabilité des modèles pour les modèles PyTorch. Implémente Integrated Gradients, Gradient SHAP, DeepLIFT, Occlusion, Feature Ablation et Layer Conductance. Supporte les modèles de vision par ordinateur, NLP et tabulaires.

Installation

uv pip install captum

Integrated Gradients

import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

model = nn.Linear(10, 2)
input = torch.randn(1, 10)
baseline = torch.zeros(1, 10)

ig = IntegratedGradients(model)
attrs = ig.attribute(input, baseline, target=0)
print(f"Feature attributions: {attrs}")

Occlusion

from captum.attr import Occlusion

occ = Occlusion(model)
attrs = occ.attribute(input, target=0, sliding_window_shapes=(1,))  # 1D
print(attrs)

Visualisation

from captum.attr import visualization as viz

_ = viz.visualize_image_attr(
    attrs.squeeze().numpy(),
    original_image=input.squeeze().numpy(),
    method="heat_map",
    sign="absolute_value",
    show_colorbar=True,
)

Références

Skills similaires