
Commit c7535a0

Merge branch 'master' into models/polynormer
2 parents: 4843c02 + e7050f1

File tree: 19 files changed, +2357 / −84 lines


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  ### Added

  - Added `Polynormer` model and example ([#9908](https://github.com/pyg-team/pytorch_geometric/pull/9908))
+ - Added support for heterogeneous graphs in `GNNExplainer` ([#10158](https://github.com/pyg-team/pytorch_geometric/pull/10158))
+ - Added Graph Positional and Structural Encoder (GPSE) ([#9018](https://github.com/pyg-team/pytorch_geometric/pull/9018))
+ - Added attract-repel link prediction example ([#10107](https://github.com/pyg-team/pytorch_geometric/pull/10107))
+ - Added `ARLinkPredictor`, implementing Attract-Repel embeddings for link prediction ([#10105](https://github.com/pyg-team/pytorch_geometric/pull/10105))
  - Improved documentation for [cuGraph](https://github.com/rapidsai/cugraph) ([#10083](https://github.com/pyg-team/pytorch_geometric/pull/10083))
  - Added `HashTensor` ([#10072](https://github.com/pyg-team/pytorch_geometric/pull/10072))
  - Added `SGFormer` model and example ([#9904](https://github.com/pyg-team/pytorch_geometric/pull/9904))

examples/README.md

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ A great and simple example to start with is [`gcn.py`](./gcn.py), showing a user
  For a simple link prediction example, see [`link_pred.py`](./link_pred.py).

+ For an improved link prediction approach using Attract-Repel embeddings, which can significantly boost accuracy (up to a 23% improvement in AUC), see [`ar_link_pred.py`](./ar_link_pred.py). This approach is based on [Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs](https://arxiv.org/abs/2106.09671).
+
  For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see the `ogbn_*.py` examples:

  - [`ogbn_train.py`](./ogbn_train.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset (~1.6B edges) or the medium-scale `ogbn-products` dataset (~62M edges).
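Before the full example, the Attract-Repel scoring rule itself is worth stating inline: each node embedding is split into an "attract" half and a "repel" half, and a candidate edge is scored as the attract inner product minus the repel inner product. A minimal sketch on toy tensors (shapes chosen arbitrarily for illustration; this mirrors `ARLinkPredictor.forward` in the new file below):

import torch

z_i = torch.randn(4, 64)  # embeddings of 4 source nodes (illustrative)
z_j = torch.randn(4, 64)  # embeddings of 4 target nodes (illustrative)

d = z_i.size(1) // 2  # first half attracts, second half repels
attract = (z_i[:, :d] * z_j[:, :d]).sum(dim=1)
repel = (z_i[:, d:] * z_j[:, d:]).sum(dim=1)
score = attract - repel  # higher score means the edge is more likely

The repel term lets two nodes with similar embeddings still receive a low edge score, a kind of structure a plain inner product cannot express.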

examples/ar_link_pred.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import argparse
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling, train_test_split_edges


class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_channels * 2, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_i, z_j):
        x = torch.cat([z_i, z_j], dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return x.view(-1)


class ARLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # Split embedding dimensions between attract and repel parts:
        self.attract_dim = in_channels // 2
        self.repel_dim = in_channels - self.attract_dim

    def forward(self, z_i, z_j):
        # Split each embedding into its attract and repel parts:
        z_i_attr = z_i[:, :self.attract_dim]
        z_i_repel = z_i[:, self.attract_dim:]

        z_j_attr = z_j[:, :self.attract_dim]
        z_j_repel = z_j[:, self.attract_dim:]

        # The AR score is the attract inner product minus the repel one:
        attract_score = (z_i_attr * z_j_attr).sum(dim=1)
        repel_score = (z_i_repel * z_j_repel).sum(dim=1)

        return attract_score - repel_score


def train(encoder, predictor, data, optimizer):
    encoder.train()
    predictor.train()

    optimizer.zero_grad()
    z = encoder(data.x, data.train_pos_edge_index)

    # Predict on positive edges:
    pos_out = predictor(z[data.train_pos_edge_index[0]],
                        z[data.train_pos_edge_index[1]])

    # Sample and predict on negative edges:
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1),
    )
    neg_out = predictor(z[neg_edge_index[0]], z[neg_edge_index[1]])

    # Binary cross-entropy over positive and negative edges:
    pos_loss = F.binary_cross_entropy_with_logits(pos_out,
                                                  torch.ones_like(pos_out))
    neg_loss = F.binary_cross_entropy_with_logits(neg_out,
                                                  torch.zeros_like(neg_out))
    loss = pos_loss + neg_loss

    loss.backward()
    optimizer.step()

    return loss.item()


@torch.no_grad()
def test(encoder, predictor, data):
    encoder.eval()
    predictor.eval()

    z = encoder(data.x, data.train_pos_edge_index)

    pos_val_out = predictor(z[data.val_pos_edge_index[0]],
                            z[data.val_pos_edge_index[1]])
    neg_val_out = predictor(z[data.val_neg_edge_index[0]],
                            z[data.val_neg_edge_index[1]])

    pos_test_out = predictor(z[data.test_pos_edge_index[0]],
                             z[data.test_pos_edge_index[1]])
    neg_test_out = predictor(z[data.test_neg_edge_index[0]],
                             z[data.test_neg_edge_index[1]])

    val_auc = compute_auc(pos_val_out, neg_val_out)
    test_auc = compute_auc(pos_test_out, neg_test_out)

    return val_auc, test_auc


def compute_auc(pos_out, neg_out):
    from sklearn.metrics import roc_auc_score

    pos_score = torch.sigmoid(pos_out).cpu()
    neg_score = torch.sigmoid(neg_out).cpu()

    y_true = torch.cat(
        [torch.ones(pos_score.size(0)),
         torch.zeros(neg_score.size(0))])
    y_score = torch.cat([pos_score, neg_score])

    return roc_auc_score(y_true.numpy(), y_score.numpy())


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Cora',
                        choices=['Cora', 'CiteSeer', 'PubMed'])
    parser.add_argument('--hidden_channels', type=int, default=128)
    parser.add_argument('--out_channels', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--use_ar', action='store_true',
                        help='Use Attract-Repel embeddings')
    parser.add_argument('--lr', type=float, default=0.01)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the dataset:
    transform = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice(device),
    ])

    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                    args.dataset)
    dataset = Planetoid(path, args.dataset, transform=transform)
    data = dataset[0]

    # Split edges into train/val/test sets for link prediction:
    data = train_test_split_edges(data)

    encoder = GCNEncoder(
        in_channels=dataset.num_features,
        hidden_channels=args.hidden_channels,
        out_channels=args.out_channels,
    ).to(device)

    # Choose the predictor based on args:
    if args.use_ar:
        predictor = ARLinkPredictor(in_channels=args.out_channels).to(device)
        print(f"Running link prediction on {args.dataset} "
              "with Attract-Repel embeddings")
    else:
        predictor = LinkPredictor(
            in_channels=args.out_channels,
            hidden_channels=args.hidden_channels).to(device)
        print(f"Running link prediction on {args.dataset} "
              "with traditional embeddings")

    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(predictor.parameters()), lr=args.lr)

    best_val_auc = 0
    final_test_auc = 0

    for epoch in range(1, args.epochs + 1):
        loss = train(encoder, predictor, data, optimizer)
        val_auc, test_auc = test(encoder, predictor, data)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc

        if epoch % 10 == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
                  f'Val AUC: {val_auc:.4f}, '
                  f'Test AUC: {test_auc:.4f}')

    print(f'Final results - Val AUC: {best_val_auc:.4f}, '
          f'Test AUC: {final_test_auc:.4f}')

    # Report the R-fraction (the share of total embedding norm held by the
    # repel part) when using AR embeddings:
    if args.use_ar:
        with torch.no_grad():
            z = encoder(data.x, data.train_pos_edge_index)
            attr_dim = args.out_channels // 2

            z_attr = z[:, :attr_dim]
            z_repel = z[:, attr_dim:]

            attract_norm_squared = torch.sum(z_attr**2)
            repel_norm_squared = torch.sum(z_repel**2)

            r_fraction = repel_norm_squared / (attract_norm_squared +
                                               repel_norm_squared)
            print(f"R-fraction: {r_fraction.item():.4f}")


if __name__ == '__main__':
    main()
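Assuming the script is saved as `examples/ar_link_pred.py` as above, the two predictors can be compared directly from the command line (a sketch; both flags come from the argument parser in `main()`):

python examples/ar_link_pred.py --dataset Cora           # baseline MLP LinkPredictor
python examples/ar_link_pred.py --dataset Cora --use_ar  # ARLinkPredictor

With `--use_ar`, the script additionally reports the R-fraction, i.e. the share of total embedding norm assigned to the repel half.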

examples/llm/README.md

Lines changed: 9 additions & 9 deletions
@@ -1,11 +1,11 @@
  # Examples for Co-training LLMs and GNNs

Of the nine rewritten table rows, only the `g_retriever.py` description changes in substance: it now notes that the example uses the toy WebQSP dataset and links to an accompanying blog post. The remaining rows are merely re-padded to the new column width. The updated table:

| Example | Description |
| ------- | ----------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA3` with `GAT` for answering questions based on knowledge graph information from the toy WebQSP dataset. We also have an [example repo](https://github.com/neo4j-product-examples/neo4j-gnn-llm-example) for integration with [Neo4j Graph DBs][neo4j.com], along with an associated [blog](https://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/) showing 2x accuracy gains over LLMs on real medical data. |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA. |
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports the MoleculeGPT and InstructMol datasets. |
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via a variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results. |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text. |

examples/llm/g_retriever.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
  Example repo for integration with Neo4j Graph DB:
  https://github.com/neo4j-product-examples/neo4j-gnn-llm-example
+ Example blog showing 2x accuracy over LLM on real medical data:
+ https://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/
  """
  import argparse
  import gc

test/contrib/explain/test_pgm_explainer.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@
  from torch_geometric.explain import Explainer
  from torch_geometric.explain.config import ModelConfig
  from torch_geometric.nn import GCNConv, global_add_pool
- from torch_geometric.testing import withPackage
+ from torch_geometric.testing import minPython, withPackage


  class GCN(torch.nn.Module):

@@ -45,6 +45,7 @@ def forward(self, x, edge_index, edge_weight=None, batch=None, **kwargs):

  edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])


+ @minPython('3.10')
  @withPackage('pgmpy', 'pandas')
  @pytest.mark.parametrize('node_idx', [2, 6])
  @pytest.mark.parametrize('task_level, perturbation_mode', [
