
Commit ba96098

Merge branch 'zaristei/bidirectional-sampler' of https://github.com/zaristei/pytorch_geometric into zaristei/bidirectional-sampler
2 parents: 22ddb1d + e6788d7

File tree: 11 files changed, +303 −15 lines

.github/workflows/rag_testing.yml

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+name: Testing RAG on PyTorch 2.5
+
+on:  # yamllint disable-line rule:truthy
+  push:
+    branches:
+      - master
+  pull_request:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }}  # yamllint disable-line
+  # Only cancel intermediate builds if on a PR:
+  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
+
+jobs:
+
+  rag_pytest:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 40
+
+      # Run workflow if only certain files have been changed.
+      - name: Get changed files
+        id: changed-files-specific-rag
+        uses: tj-actions/changed-files@v41
+        with:
+          files: |
+            torch_geometric/datasets/web_qsp_dataset.py
+            torch_geometric/nn/nlp/**
+            torch_geometric/nn/models/g_retriever.py
+            torch_geometric/loader/rag_loader.py
+
+      - name: Setup packages
+        if: steps.changed-files-specific-rag.outputs.any_changed == 'true'
+        uses: ./.github/actions/setup
+        with:
+          full_install: false
+
+      - name: Install main package
+        if: steps.changed-files-specific-rag.outputs.any_changed == 'true'
+        run: |
+          pip install -e .[test,rag]
+
+      - name: Run tests
+        if: steps.changed-files-specific-rag.outputs.any_changed == 'true'
+        timeout-minutes: 10
+        run: |
+          RAG_TEST=1 pytest -m rag
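
For local debugging, the final `RAG_TEST=1 pytest -m rag` step can be mirrored from Python. A minimal sketch, not part of the commit; the `test` directory path is an assumption about the checkout layout:

# Rough local equivalent of the workflow's final step (sketch only).
import os

import pytest

os.environ["RAG_TEST"] = "1"  # enable the env-gated RAG tests, as CI does
# Select only tests carrying the `rag` marker:
raise SystemExit(pytest.main(["-m", "rag", "test"]))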

.github/workflows/testing_full.yml

Lines changed: 10 additions & 5 deletions
@@ -63,13 +63,18 @@ jobs:
         run: |
           sudo apt-get install graphviz

-      - name: Install mpmath
-        if: ${{ matrix.torch-version == 'nightly' }}
+      - name: Install main package (torch!=nightly)
+        if: ${{ matrix.torch-version != 'nightly' }}
         run: |
-          pip install mpmath==1.3.0
+          echo "torch==${{ matrix.torch-version }}" > requirements-constraint.txt
+          pip install -e ".[full,test]" --constraint requirements-constraint.txt
+          python -c "import torch; print('PyTorch:', torch.__version__)"
+          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
+          python -c "import torch; print('CUDA:', torch.version.cuda)"
         shell: bash

-      - name: Install main package
+      - name: Install main package (torch==nightly)
+        if: ${{ matrix.torch-version == 'nightly' }}
         run: |
           pip install -e ".[full,test]"
           python -c "import torch; print('PyTorch:', torch.__version__)"
@@ -80,7 +85,7 @@ jobs:
       - name: Run tests
         timeout-minutes: 20
         run: |
-          FULL_TEST=1 pytest --cov --cov-report=xml
+          FULL_TEST=1 pytest --cov --cov-report=xml --durations 10
         shell: bash

       - name: Upload coverage

pyproject.toml

Lines changed: 11 additions & 0 deletions
@@ -58,6 +58,14 @@ benchmark=[
     "protobuf<4.21",
     "wandb",
 ]
+rag=[
+    "pcst_fast",
+    "datasets",
+    "transformers",
+    "pandas",
+    "sentencepiece",
+    "accelerate",
+]
 test=[
     "onnx",
     "onnxruntime",
@@ -192,6 +200,9 @@ filterwarnings = [
     # Filter `pytorch_lightning` warnings:
     "ignore:GPU available but not used:UserWarning",
 ]
+markers = [
+    "rag: mark test as RAG test",
+]

 [tool.coverage.run]
 source = ["torch_geometric"]
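
The `rag` marker registered above pairs with the `onlyRAG` decorator used in the tests below. A plausible minimal sketch of such an env-gated decorator, assuming it keys off the `RAG_TEST` variable set in the new workflow; this is not necessarily how `torch_geometric.testing.onlyRAG` is actually implemented:

# Sketch only: combine the `rag` marker with an environment-variable skip.
import os
from typing import Callable

import pytest


def only_rag_sketch(func: Callable) -> Callable:
    func = pytest.mark.rag(func)  # selectable via `pytest -m rag`
    return pytest.mark.skipif(
        os.getenv("RAG_TEST", "0") != "1",
        reason="RAG tests only run when RAG_TEST=1",
    )(func)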

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
+import os
+import random
+import string
+
+import pytest
+
+from torch_geometric.datasets import WebQSPDataset
+from torch_geometric.datasets.web_qsp_dataset import KGQABaseDataset
+from torch_geometric.testing import (
+    onlyFullTest,
+    onlyOnline,
+    onlyRAG,
+    withPackage,
+)
+
+
+@pytest.mark.skip(reason="Times out")
+@onlyOnline
+@onlyFullTest
+def test_web_qsp_dataset(tmp_path):
+    dataset = WebQSPDataset(root=tmp_path)
+    # Split for this dataset is 2826 train | 246 val | 1628 test
+    # default split is train
+    assert len(dataset) == 2826
+    assert str(dataset) == "WebQSPDataset(2826)"
+
+    dataset_train = WebQSPDataset(root=tmp_path, split="train")
+    assert len(dataset_train) == 2826
+    assert str(dataset_train) == "WebQSPDataset(2826)"
+
+    dataset_val = WebQSPDataset(root=tmp_path, split="val")
+    assert len(dataset_val) == 246
+    assert str(dataset_val) == "WebQSPDataset(246)"
+
+    dataset_test = WebQSPDataset(root=tmp_path, split="test")
+    assert len(dataset_test) == 1628
+    assert str(dataset_test) == "WebQSPDataset(1628)"
+
+
+class MockSentenceTransformer:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def to(self, device):
+        return self
+
+    def eval(self):
+        return self
+
+    def encode(self, sentences, batch_size=None, output_device=None):
+        import torch
+
+        def string_to_tensor(s: str) -> torch.Tensor:
+            return torch.ones(1024).float()
+
+        if isinstance(sentences, str):
+            return string_to_tensor(sentences)
+        return torch.stack([string_to_tensor(s) for s in sentences])
+
+
+def create_mock_graphs(tmp_path: str, train_size: int, val_size: int,
+                       test_size: int, num_nodes: int, num_edge_types: int,
+                       num_trips: int, seed: int = 42):
+    random.seed(seed)
+    strkeys = string.ascii_letters + string.digits
+    qa_strkeys = string.ascii_letters + string.digits + " "
+
+    def create_mock_triplets(num_nodes: int, num_edges: int, num_trips: int):
+        nodes = list(
+            {"".join(random.sample(strkeys, 10))
+             for i in range(num_nodes)})
+        edges = list(
+            {"".join(random.sample(strkeys, 10))
+             for i in range(num_edges)})
+        triplets = []
+
+        for i in range(num_trips):
+            h = random.randint(0, num_nodes - 1)
+            t = random.randint(0, num_nodes - 1)
+            r = random.randint(0, num_edge_types - 1)
+            triplets.append((nodes[h], edges[r], nodes[t]))
+        return triplets
+
+    train_triplets = [
+        create_mock_triplets(num_nodes, num_edge_types, num_trips)
+        for _ in range(train_size)
+    ]
+    val_triplets = [
+        create_mock_triplets(num_nodes, num_edge_types, num_trips)
+        for _ in range(val_size)
+    ]
+    test_triplets = [
+        create_mock_triplets(num_nodes, num_edge_types, num_trips)
+        for _ in range(test_size)
+    ]
+
+    train_questions = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(train_size)
+    ]
+    val_questions = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(val_size)
+    ]
+    test_questions = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(test_size)
+    ]
+
+    train_answers = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(train_size)
+    ]
+    val_answers = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(val_size)
+    ]
+    test_answers = [
+        "".join(random.sample(qa_strkeys, 10)) for _ in range(test_size)
+    ]
+
+    train_graphs = {
+        "graph": train_triplets,
+        "question": train_questions,
+        "answer": train_answers
+    }
+    val_graphs = {
+        "graph": val_triplets,
+        "question": val_questions,
+        "answer": val_answers
+    }
+    test_graphs = {
+        "graph": test_triplets,
+        "question": test_questions,
+        "answer": test_answers
+    }
+
+    from datasets import Dataset, DatasetDict, load_from_disk
+
+    ds_train = Dataset.from_dict(train_graphs, split="train")
+    ds_val = Dataset.from_dict(val_graphs, split="validation")
+    ds_test = Dataset.from_dict(test_graphs, split="test")
+
+    ds = DatasetDict({
+        "train": ds_train,
+        "validation": ds_val,
+        "test": ds_test
+    })
+
+    def mock_load_dataset(name: str):
+        # Save the dataset and then load it to emulate downloading from HF
+        DATASET_CACHE_DIR = os.path.join(tmp_path,
+                                         ".cache/huggingface/datasets", name)
+        os.makedirs(DATASET_CACHE_DIR, exist_ok=True)
+
+        ds.save_to_disk(DATASET_CACHE_DIR)
+        dataset_remote = load_from_disk(DATASET_CACHE_DIR)
+        return dataset_remote
+
+    return mock_load_dataset, ds
+
+
+@onlyRAG
+@withPackage("datasets", "pandas")
+def test_kgqa_base_dataset(tmp_path, monkeypatch):
+
+    num_nodes = 500
+    num_edge_types = 25
+    num_trips = 5000
+
+    # Mock the dataset graphs
+    mock_load_dataset_func, expected_result = create_mock_graphs(
+        tmp_path, train_size=10, val_size=5, test_size=5, num_nodes=num_nodes,
+        num_edge_types=num_edge_types, num_trips=num_trips)
+
+    import datasets
+
+    monkeypatch.setattr(datasets, "load_dataset", mock_load_dataset_func)
+
+    # Mock the SentenceTransformer
+    import torch_geometric.datasets.web_qsp_dataset
+    monkeypatch.setattr(torch_geometric.datasets.web_qsp_dataset,
+                        "SentenceTransformer", MockSentenceTransformer)
+
+    dataset_train = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset",
+                                    split="train", use_pcst=False)
+    assert len(dataset_train) == 10
+    assert str(dataset_train) == "KGQABaseDataset(10)"
+    for graph in dataset_train:
+        assert graph.x.shape == (num_nodes, 1024)
+        assert graph.edge_index.shape == (2, num_trips)
+        assert graph.edge_attr.shape == (
+            num_trips, 1024)  # Reminder: edge_attr encodes the entire triplet
+
+    dataset_val = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset",
+                                  split="val", use_pcst=False)
+    assert len(dataset_val) == 5
+    assert str(dataset_val) == "KGQABaseDataset(5)"
+
+    dataset_test = KGQABaseDataset(root=tmp_path, dataset_name="TestDataset",
+                                   split="test", use_pcst=False)
+    assert len(dataset_test) == 5
+    assert str(dataset_test) == "KGQABaseDataset(5)"
+
+# TODO(zaristei): More rigorous tests to validate that values are correct
+# TODO(zaristei): Proper tests for PCST and CWQ

test/nn/models/test_g_retriever.py

Lines changed: 3 additions & 3 deletions
@@ -2,10 +2,10 @@

 from torch_geometric.nn import GAT, GRetriever
 from torch_geometric.nn.nlp import LLM
-from torch_geometric.testing import onlyFullTest, withPackage
+from torch_geometric.testing import onlyRAG, withPackage


-@onlyFullTest
+@onlyRAG
 @withPackage('transformers', 'sentencepiece', 'accelerate')
 def test_g_retriever() -> None:
     llm = LLM(
@@ -53,7 +53,7 @@ def test_g_retriever() -> None:
     assert len(pred) == 1


-@onlyFullTest
+@onlyRAG
 @withPackage('transformers', 'sentencepiece', 'accelerate')
 def test_g_retriever_many_tokens() -> None:
     llm = LLM(
test/nn/models/test_gpse.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ def test_gpse_training():

     data = Data(x=x, y=y, edge_index=edge_index)
     data = VirtualNode()(data)
-    data.y_graph = torch.tensor(torch.randn(11))
+    data.y_graph = torch.randn(11)

     batch = Batch.from_data_list([data])
     model = GPSE()

test/nn/nlp/test_llm.py

Lines changed: 2 additions & 2 deletions
@@ -2,10 +2,10 @@
 from torch import Tensor

 from torch_geometric.nn.nlp import LLM
-from torch_geometric.testing import onlyFullTest, withPackage
+from torch_geometric.testing import onlyRAG, withPackage


-@onlyFullTest
+@onlyRAG
 @withPackage('transformers', 'accelerate')
 def test_llm() -> None:
     question = ["Is PyG the best open-source GNN library?"]

test/nn/nlp/test_sentence_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
 import pytest

 from torch_geometric.nn.nlp import SentenceTransformer
-from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
+from torch_geometric.testing import onlyRAG, withCUDA, withPackage


 @withCUDA
-@onlyFullTest
+@onlyRAG
 @withPackage('transformers')
 @pytest.mark.parametrize('batch_size', [None, 1])
 @pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls'])

torch_geometric/data/large_graph_indexer.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@
 from tqdm import tqdm

 from torch_geometric.data import Data
+from torch_geometric.io import fs
 from torch_geometric.typing import WITH_PT24

 # Could be any hashable type
@@ -505,13 +506,13 @@ def from_disk(cls, path: str) -> "LargeGraphIndexer":
         for fname in os.listdir(node_attr_path):
             full_fname = f"{node_attr_path}/{fname}"
             key = fname.split(".")[0]
-            indexer.node_attr[key] = torch.load(full_fname)
+            indexer.node_attr[key] = fs.torch_load(full_fname)

         edge_attr_path = path + "/edge_attr"
         for fname in os.listdir(edge_attr_path):
             full_fname = f"{edge_attr_path}/{fname}"
             key = fname.split(".")[0]
-            indexer.edge_attr[key] = torch.load(full_fname)
+            indexer.edge_attr[key] = fs.torch_load(full_fname)

         return indexer
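
The switch from `torch.load` to `fs.torch_load` routes deserialization through PyG's fsspec-based filesystem helpers, so the same call can handle local paths and remote URIs. A rough illustration of the intent, not the actual `torch_geometric.io.fs` source:

# Sketch only: open the file via fsspec, then defer to torch.load.
import fsspec
import torch


def torch_load_sketch(path: str, map_location=None):
    with fsspec.open(path, "rb") as f:  # works for local paths and e.g. s3:// URIs
        return torch.load(f, map_location=map_location)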
