name: torch-geometric description: "Guide pour construire des réseaux de neurones de graphe avec PyTorch Geometric (PyG). Utilisez cette compétence chaque fois que l'utilisateur pose des questions sur les réseaux de neurones de graphe, GNN, classification de nœuds, prédiction de liens, classification de graphes, réseaux de passage de messages, graphes hétérogènes, échantillonnage de voisins, ou toute tâche impliquant torch_geometric / PyG. Déclenchez aussi quand vous voyez des imports depuis torch_geometric, ou quand l'utilisateur mentionne les convolutions de graphe (GCN, GAT, GraphSAGE, GIN), structures de données graphe, ou travail avec des données relationnelles/réseau. Même si l'utilisateur dit juste « apprentissage sur graphes » ou « apprentissage géométrique profond », utilisez cette compétence."
tags: [graph-neural-networks, gnn-training, message-passing, graph-benchmarks, torch-geometric]
----|----------|----------|
| GCNConv | Classification de nœuds semi-supervisée, homogène | Inspirée spectrale, agrégation normalisée par degré |
| GATConv / GATv2Conv | Quand l'importance des voisins varie | Messages pondérés par attention |
| SAGEConv | Grands graphes, réglages inductifs | Échantillonnage efficace, agrégation apprenante |
| GINConv | Classification de graphes, maximiser l'expressivité | Aussi puissant que le test WL |
| TransformerConv | Attributs de bords riches, interactions complexes | Attention multi-têtes avec attributs de bords |
| EdgeConv | Nuages de points, graphes dynamiques | MLP sur attributs de bords (x_i, x_j - x_i) |
| RGCNConv | Hétérogène avec nombreux types de relations | Matrices de poids spécifiques aux relations |
| HGTConv | Graphes hétérogènes | Attention spécifique aux types |
Toutes les couches de convolution acceptent au minimum (x, edge_index). Beaucoup acceptent aussi edge_attr pour les attributs de bords.
Initialisation Lazy
Utilisez -1 pour les canaux d'entrée afin que PyG déduit les dimensions automatiquement — particulièrement utile pour les modèles hétérogènes :
conv = SAGEConv((-1, -1), 64) # Dimensions d'entrée déduites au premier forward
# Initialiser les modules lazy :
with torch.no_grad():
out = model(data.x, data.edge_index)
APIs de modèles haut niveau
Pour les architectures courantes, PyG fournit des classes de modèle prêtes à l'emploi :
from torch_geometric.nn import GraphSAGE, GCN, GAT, GIN
model = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes,
num_layers=2,
)
Couches personnalisées via MessagePassing
Pour implémenter une nouvelle couche GNN, hériterez de MessagePassing. Le cadre est :
propagate()orchestre le passage des messagesmessage()définit les informations qui circulent le long de chaque arête (la fonction phi)aggregate()combine les messages à chaque nœud (somme/moyenne/max)update()transforme le résultat agrégé (la fonction gamma)
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class MyConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "add", "mean", ou "max"
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Pré-traitement avant passage des messages
x = self.lin(x)
# Commencer le passage des messages
return self.propagate(edge_index, x=x)
def message(self, x_j):
# x_j: caractéristiques des nœuds sources pour chaque arête [num_edges, features]
# Le suffixe _j indexe automatiquement les nœuds sources, _i indexe les nœuds cibles
return x_j
La convention _i / _j : tout tenseur passé à propagate() peut être indexé automatiquement en ajoutant _i (nœud cible/central) ou _j (nœud source/voisin) dans la signature de message(). Donc si vous passez x=... à propagate, vous pouvez accéder à x_i et x_j dans message().
Lisez references/message_passing.md pour des exemples d'implémentation complets de GCN et EdgeConv.
Motifs spécifiques aux tâches
Classification de nœuds
# Entraînement full-batch sur un seul graphe (par exemple, Cora)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
# Évaluation
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
Classification de graphes
Plusieurs graphes — utilisez DataLoader pour le mini-batching et un pooling global pour obtenir des représentations au niveau graphe :
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
loader = DataLoader(dataset, batch_size=32, shuffle=True)
class GraphClassifier(torch.nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch):
super().__init__()
self.conv1 = GCNConv(in_ch, hidden_ch)
self.conv2 = GCNConv(hidden_ch, hidden_ch)
self.lin = torch.nn.Linear(hidden_ch, out_ch)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
x = global_mean_pool(x, batch) # [num_graphs_in_batch, hidden_ch]
return self.lin(x)
# Boucle d'entraînement
for data in loader:
out = model(data.x, data.edge_index, data.batch)
loss = F.cross_entropy(out, data.y)
Le DataLoader de PyG regroupe plusieurs graphes en créant des matrices d'adjacence bloc-diagonales. Le tenseur batch mappe chaque nœud à son indice de graphe. Les opérations de pooling (global_mean_pool, global_max_pool, global_add_pool) utilisent ceci pour agréger par graphe.
Prédiction de liens
Divisez les arêtes en train/val/test, utilisez l'échantillonnage négatif :
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(
num_val=0.1,
num_test=0.1,
is_undirected=True,
add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)
# Encoder les nœuds, puis scorer les arêtes
z = model.encode(train_data.x, train_data.edge_index)
# Arêtes positives
pos_score = (z[train_data.edge_label_index[0]] * z[train_data.edge_label_index[1]]).sum(dim=1)
Lisez references/link_prediction.md pour le guide complet de prédiction de liens : autoencodeurs GAE/VGAE, boucles d'entraînement complètes, LinkNeighborLoader pour les grands graphes, prédiction de liens hétérogènes, et métriques d'évaluation.
Mise à l'échelle vers les grands graphes
Pour les graphes qui ne tiennent pas en mémoire GPU, utilisez l'échantillonnage de voisins via NeighborLoader :
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[15, 10], # Échantillonner 15 voisins au saut 1, 10 au saut 2
batch_size=128, # Nombre de nœuds graine par batch
input_nodes=data.train_mask, # Quels nœuds échantillonner
shuffle=True,
)
for batch in train_loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
# Utiliser uniquement les premiers batch_size nœuds pour la perte (ce sont les nœuds graine)
loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
Points clés sur NeighborLoader :
- La longueur de la liste
num_neighborsdoit correspondre à la profondeur du GNN (nombre de couches de passage de messages) - Les nœuds graine sont toujours les premiers
batch.batch_sizenœuds dans la sortie batch.n_idmappe les indices réétiquetés aux ID de nœuds originaux- Fonctionne pour
DataetHeteroData - Pour la prédiction de liens, utilisez
LinkNeighborLoaderà la place - L'échantillonnage de plus de 2-3 sauts est généralement infaisable (croissance exponentielle)
Autres options de scalabilité : ClusterLoader (ClusterGCN), GraphSAINTSampler, ShaDowKHopSampler. Pour l'entraînement multi-GPU, intégration PyTorch Lightning, et support torch.compile, lisez references/scaling.md.
Graphes hétérogènes
Pour les graphes avec plusieurs types de nœuds et d'arêtes (réseaux sociaux, graphes de connaissances, recommandation) :
from torch_geometric.data import HeteroData
data = HeteroData()
# Caractéristiques de nœuds — indexées par chaîne de type de nœud
data['user'].x = torch.randn(1000, 64)
data['movie'].x = torch.randn(500, 128)
# Indices d'arêtes — indexés par triplet (src_type, edge_type, dst_type)
data['user', 'rates', 'movie'].edge_index = torch.randint(0, 500, (2, 3000))
data['user', 'follows', 'user'].edge_index = torch.randint(0, 1000, (2, 5000))
# Accéder à des dicts de commodité
data.x_dict # {'user': tensor, 'movie': tensor}
data.edge_index_dict # {('user','rates','movie'): tensor, ...}
data.metadata() # ([node_types], [edge_types])
Trois façons de construire des GNN hétérogènes
1. Auto-conversion avec to_hetero() — écrivez un modèle homogène, convertissez automatiquement :
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
model = GNN(64, dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
# Accepte maintenant des dicts :
out = model(data.x_dict, data.edge_index_dict)
Utilisez (-1, -1) pour les canaux d'entrée bipartites (source, cible peuvent différer). L'initialisation lazy gère le reste.
2. Wrapper HeteroConv — convolution différente par type d'arête :
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv
conv = HeteroConv({
('paper', 'cites', 'paper'): GCNConv(-1, 64),
('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
('paper', 'rev_writes', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')
3. Opérateurs hétérogènes natifs comme HGTConv :
from torch_geometric.nn import HGTConv
conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), num_heads=4)
Important pour les graphes hétérogènes :
- Utilisez
T.ToUndirected()pour ajouter les types d'arêtes inverses pour le flux de messages bidirectionnel - Désactivez
add_self_loopsdans les couches de convolution bipartites (types source/dest différents) — utilisez des connexions de saut à la place :conv(x, edge_index) + lin(x) - Pour NeighborLoader sur HeteroData, spécifiez
input_nodescomme tuple('node_type', mask) num_neighborspeut être un dict clé par type d'arête pour un contrôle granulaire
Lisez references/heterogeneous.md pour des exemples complets incluant les boucles d'entraînement et l'utilisation de NeighborLoader avec des graphes hétérogènes.
Ensembles de données personnalisés
Pour charger vos propres données dans PyG :
- Rapide (pas de classe nécessaire) : Créez des objets
Datadirectement et passez une liste àDataLoader - Réutilisable (tient en RAM) : Hériterez de
InMemoryDataset— écrasezraw_file_names,processed_file_names,download(),process() - Grand (sauvegardé sur disque) : Hériterez de
Dataset— écrasez aussilen()etget() - Depuis CSV : Chargez les tables nœud/arête avec pandas, construisez des mappages vers des indices consécutifs, assemblez dans
DataouHeteroData - Depuis NetworkX :
from_networkx(G)convertit un graphe NetworkX directement - Depuis scipy sparse :
from_scipy_sparse_matrix(adj)extrait edge_index
Lisez references/custom_datasets.md pour des exemples complets avec tous les motifs, chargement CSV avec encodeurs, et la procédure pas à pas MovieLens.
Explicabilité
PyG fournit torch_geometric.explain pour interpréter les prédictions GNN :
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
explanation = explainer(data.x, data.edge_index, index=10)
explanation.visualize_graph() # Sous-graphe important
explanation.visualize_feature_importance(top_k=10) # Importance des caractéristiques
Algorithmes disponibles : GNNExplainer (basée optimisation), PGExplainer (paramétrique, entraînée), CaptumExplainer (basée gradient via Captum), AttentionExplainer (poids d'attention). Fonctionne pour les graphes homogènes et hétérogènes.
Lisez references/explainability.md pour tous les algorithmes, explications hétérogènes, métriques d'évaluation, et entraînement PGExplainer.
Pièges courants
- Forme edge_index : Doit être
[2, num_edges], pas[num_edges, 2]. Transposez si nécessaire. - Oublier les activations : Les couches de convolution n'incluent pas ReLU/etc — ajoutez-les manuellement.
- Auto-boucles en bipartite hétéro : N'utilisez pas
add_self_loops=Truequand les types de nœuds source et dest diffèrent. Utilisez plutôt des connexions de saut. - Slicing NeighborLoader : Seuls les premiers
batch.batch_sizenœuds sont vos nœuds graine. Slicez les prédictions et étiquettes en conséquence. - Graphes non orientés : Si votre graphe est non orienté, incluez les arêtes dans les deux directions dans
edge_index, ou utilisezT.ToUndirected(). - Initialisation lazy : Les modèles avec canaux d'entrée
-1ont besoin d'un forward pass avectorch.no_grad()avant l'entraînement pour initialiser les paramètres. - Pooling global pour tâches graphe : Utilisez
global_mean_pool(x, batch)(pas reshape manuel) pour agréger les caractéristiques de nœuds au niveau graphe. - Alignement num_neighbors : Maintenez
len(num_neighbors)égal au nombre de couches GNN. Plus de sauts que de couches gaspille le calcul ; moins de sauts signifie capacité de modèle gaspillée.