torch-geometric

Graph Neural Networks with PyTorch Geometric (PyG): node/graph classification, link prediction, GCN, GAT, GraphSAGE, heterogeneous graphs, and molecular property prediction for geometric deep learning.

Install

mkdir -p .claude/skills/torch-geometric && curl -L -o skill.zip "https://mcp.directory/api/skills/download/2611" && unzip -o skill.zip -d .claude/skills/torch-geometric && rm skill.zip

Installs to .claude/skills/torch-geometric

About this skill

PyTorch Geometric (PyG)

Overview

PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.

When to Use This Skill

This skill should be used when working with:

  • Graph-based machine learning: Node classification, graph classification, link prediction
  • Molecular property prediction: Drug discovery, chemical property prediction
  • Social network analysis: Community detection, influence prediction
  • Citation networks: Paper classification, recommendation systems
  • 3D geometric data: Point clouds, meshes, molecular structures
  • Heterogeneous graphs: Multi-type nodes and edges (e.g., knowledge graphs)
  • Large-scale graph learning: Neighbor sampling, distributed training

Quick Start

Installation

uv pip install torch_geometric

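For the optional extension packages (sparse operations, clustering), install from the wheel index that matches your setup. Replace ${TORCH} and ${CUDA} with the installed PyTorch version and CUDA tag (for example 2.5.0 and cu121, or cpu):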

uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
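
The TORCH and CUDA placeholders in the URL stand for the installed PyTorch release and its CUDA build tag. As a rough sketch, the helper below (a hypothetical function, not part of PyG) assembles the index URL from two version strings. In practice these would come from torch.__version__ and torch.version.cuda. Wheel indexes are published only for particular PyTorch releases, so treat the result as a candidate to check, not a guaranteed location.

```python
from typing import Optional


def pyg_wheel_index(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a candidate wheel index URL (hypothetical helper, not a PyG API).

    torch_version: e.g. torch.__version__, which may carry a "+cu121" suffix.
    cuda_version:  e.g. torch.version.cuda ("12.1"), or None for CPU-only builds.
    """
    base = torch_version.split("+")[0]  # drop a local suffix such as "+cu121"
    if cuda_version is None:
        cuda_tag = "cpu"
    else:
        cuda_tag = "cu" + cuda_version.replace(".", "")  # "12.1" -> "cu121"
    return f"https://data.pyg.org/whl/torch-{base}+{cuda_tag}.html"


print(pyg_wheel_index("2.5.0+cu121", "12.1"))
# https://data.pyg.org/whl/torch-2.5.0+cu121.html
```

The printed URL can be passed to the -f flag of the install command.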

Basic Graph Creation

import torch
from torch_geometric.data import Data

# Create a simple graph with 3 nodes
edge_index = torch.tensor([[0, 1, 1, 2],  # source nodes
                           [1, 0, 2, 1]], dtype=torch.long)  # target nodes
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)  # node features

data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")

Loading a Benchmark Dataset

from torch_geometric.datasets import Planetoid

# Load Cora citation network
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Get the first (and only) graph

print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")

Core Concepts

Data Structure

PyG represents graphs using the torch_geometric.data.Data class with these key attributes:

  • data.x: Node feature matrix [num_nodes, num_node_features]
  • data.edge_index: Graph connectivity in COO format [2, num_edges]
  • data.edge_attr: Edge feature matrix [num_edges, num_edge_features] (optional)
  • data.y: Target labels for nodes or graphs
  • data.pos: Node spatial positions [num_nodes, num_dimensions] (optional)
  • Custom attributes: Can add any attribute (e.g., data.train_mask, data.batch)

Important: None of these attributes is mandatory, and Data objects can be extended with custom attributes as needed.

Edge Index Format

Edges are stored in COO (coordinate) format as a [2, num_edges] tensor:

  • First row: source node indices
  • Second row: target node indices

# Edge list: (0→1), (1→0), (1→2), (2→1)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
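
Edges are often available as a list of (source, target) pairs, not as two rows. A minimal sketch, assuming a small undirected edge list (the variable names are illustrative), shows how to store each edge in both directions and transpose the result into the [2, num_edges] layout:

```python
import torch

# Undirected edges given as (u, v) pairs
undirected_edges = [(0, 1), (1, 2)]

# An undirected edge is represented by two directed edges, one per direction
directed = undirected_edges + [(v, u) for u, v in undirected_edges]

# Shape [num_edges, 2] -> transpose to [2, num_edges] and make memory contiguous
edge_index = torch.tensor(directed, dtype=torch.long).t().contiguous()

print(edge_index.shape)  # torch.Size([2, 4])
```

The resulting tensor can be passed as edge_index when constructing a Data object.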

Mini-Batch Processing

PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:

  • Adjacency matrices are stacked diagonally
  • Node features are concatenated along the node dimension
  • A batch vector maps each node to its source graph
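  • No padding is needed, which keeps batching memory-efficient

from torch_geometric.loader import DataLoader

# Assumes `dataset` contains many graphs (e.g. a TUDataset); Cora holds a single graph
loader = DataLoader(dataset, batch_size=32, shuffle=True)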
for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
    # batch.batch maps nodes to graphs
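
To make the block-diagonal idea concrete, the sketch below rebuilds the batching steps by hand in plain PyTorch. It is an illustration of the principle, not PyG's actual collate code, and collate_graphs is a hypothetical helper. Node indices of each graph are shifted by the number of nodes that precede it, and the batch vector is then used for a per-graph mean readout.

```python
import torch


def collate_graphs(graphs):
    """Merge (x, edge_index) pairs into one disconnected graph (illustrative helper)."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edge_indices.append(edge_index + offset)  # shift node ids into the merged graph
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edge_indices, dim=1), torch.cat(batch)


g1 = (torch.randn(3, 4), torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))
g2 = (torch.randn(2, 4), torch.tensor([[0, 1], [1, 0]]))
x, edge_index, batch = collate_graphs([g1, g2])

print(edge_index.tolist())  # [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]
print(batch.tolist())       # [0, 0, 0, 1, 1]

# Graph-level readout: average node features per graph using the batch vector
num_graphs = int(batch.max()) + 1
sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1)
graph_mean = sums / counts  # shape [num_graphs, num_features]
```

Because no edge connects nodes of different graphs, message passing on the merged graph never mixes information between them.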

Building Graph Neural Networks

Message Passing Paradigm

GNNs in PyG follow a neighborhood aggregation scheme:

  1. Transform node features
  2. Propagate messages along edges
  3. Aggregate messages from neighbors
  4. Update node representations
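
The four steps can be sketched in a few lines of plain PyTorch. This is a conceptual illustration with arbitrary choices (sum aggregation, ReLU update, toy dimensions), not how any particular PyG layer is implemented:

```python
import torch

torch.manual_seed(0)
num_nodes, in_dim, out_dim = 3, 4, 2
x = torch.randn(num_nodes, in_dim)
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]])  # target nodes
src, dst = edge_index
lin = torch.nn.Linear(in_dim, out_dim)

h = lin(x)                 # 1. transform node features
messages = h[src]          # 2. propagate: one message per edge, taken from its source
agg = torch.zeros(num_nodes, out_dim).index_add(0, dst, messages)  # 3. sum at target nodes
out = torch.relu(agg)      # 4. update node representations

print(out.shape)  # torch.Size([3, 2])
```

Node 1 has two incoming edges, so its aggregated value is the sum of the transformed features of nodes 0 and 2.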

Using Pre-Built Layers

PyG provides 40+ convolutional layers. Common ones include:

GCNConv (Graph Convolutional Network):

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

GATConv (Graph Attention Network):

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

GraphSAGE:

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

Custom Message Passing Layers

For custom layers, inherit from MessagePassing:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Add self-loops to adjacency matrix
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Transform node features
        x = self.lin(x)

        # Compute normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Propagate messages
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: features of source nodes
        return norm.view(-1, 1) * x_j

Key methods:

  • forward(): Main entry point
  • message(): Constructs messages from source to target nodes
  • aggregate(): Aggregates messages (rarely overridden; set the aggr parameter instead)
  • update(): Updates node embeddings after aggregation

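Variable naming convention: Appending _i or _j to the name of a tensor passed to propagate() maps it per edge. With the default flow ("source_to_target"), _i selects the target node of each edge and _j selects the source node.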
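
As a rough illustration of what this mapping amounts to under the default source-to-target flow, the per-edge tensors can be reproduced by indexing the node features with the two rows of edge_index. The snippet is plain PyTorch and mirrors the convention; it does not call PyG itself:

```python
import torch

x = torch.tensor([[10.0], [20.0], [30.0]])  # one feature per node
edge_index = torch.tensor([[0, 1, 1, 2],    # source nodes
                           [1, 0, 2, 1]])   # target nodes

x_j = x[edge_index[0]]  # features of the source node of every edge
x_i = x[edge_index[1]]  # features of the target node of every edge

print(x_j.squeeze(1).tolist())  # [10.0, 20.0, 20.0, 30.0]
print(x_i.squeeze(1).tolist())  # [20.0, 10.0, 30.0, 20.0]
```

Both tensors have one row per edge, which is why message() can combine them element-wise with edge-level quantities such as norm.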

Working with Datasets

Loading Built-in Datasets

PyG provides extensive benchmark datasets:

# Citation networks (node classification)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')  # or 'CiteSeer', 'PubMed'

# Graph classification
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Molecular datasets
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')

# Large-scale datasets
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')

Check references/datasets_reference.md for a comprehensive list.

Creating Custom Datasets

For datasets that fit in memory, inherit from InMemoryDataset:

from torch_geometric.data import InMemoryDataset, Data
import torch

class MyOwnDataset(InMemoryDataset):
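    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        # Forward pre_filter so that self.pre_filter in process() can actually be set
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load the collated graphs written by process()
        self.load(self.processed_paths[0])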

    @property
    def raw_file_names(self):
        return ['my_data.csv']  # Files needed in raw_dir

    @property
    def processed_file_names(self):
        return ['data.pt']  # Files in processed_dir

    def download(self):
        # Download raw data to self.raw_dir
        pass

    def process(self):
        # Read data, create Data objects
        data_list = []

        # Example: Create a simple graph
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)

        # Apply pre_filter and pre_transform
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # Save processed data
        self.save(data_list, self.processed_paths[0])

For large datasets that don't fit in memory, inherit from Dataset and implement len() and get(idx).

Loading Graphs from CSV

import pandas as pd
import torch
from torch_geometric.data import Data

# Load nodes
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)

# Load edges
edges_df = pd.read_csv('edges.csv')
# 'source' and 'target' are assumed to hold integer row indices into nodes_df
edge_index = torch.tensor(edges_df[['source', 'target']].values.T, dtype=torch.long)

data = Data(x=x, edge_index=edge_index)

Training Workflows

Node Classification (Single Graph)

import torch
import torch.nn.functional as F

---

*Content truncated.*
