
Commit 04872a1

Merge branch 'master' into issue/10463

2 parents 57d2cda + aae8bd2

File tree

9 files changed: +60 −38 lines changed


.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@ repos:
         args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}']

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.20.0
+    rev: v3.21.0
     hooks:
       - id: pyupgrade
         name: Upgrade Python syntax
@@ -55,7 +55,7 @@ repos:
         additional_dependencies: [toml]

   - repo: https://github.com/pycqa/isort
-    rev: 6.1.0
+    rev: 7.0.0
     hooks:
       - id: isort
         name: Sort imports
@@ -68,7 +68,7 @@ repos:
         additional_dependencies: [Flake8-pyproject]

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.13.3
+    rev: v0.14.3
     hooks:
       - id: ruff
         name: Ruff formatting
@@ -85,7 +85,7 @@ repos:
         - mdformat_footnote

   - repo: https://github.com/sphinx-contrib/sphinx-lint
-    rev: v1.0.0
+    rev: v1.0.1
     hooks:
       - id: sphinx-lint
         name: Check Sphinx

CHANGELOG.md

Lines changed: 15 additions & 1 deletion
@@ -1,7 +1,21 @@
 # Changelog

 All notable changes to this project will be documented in this file.
-The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
+
+## [Unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+### Deprecated
+
+### Removed
+
+### Fixed
+
+### Security

 ## [2.7.0] - 2025-10-14

examples/llm/glem.py

Lines changed: 7 additions & 6 deletions
@@ -60,7 +60,8 @@ def main(args):
     token_on_disk = args.token_on_disk
     num_em_iters = args.num_em_iters
     start_time = time.time()
-    train_without_ext_pred = args.train_without_ext_pred
+    train_with_ext_pred = not args.train_without_ext_pred and \
+        dataset_name == 'products'
     ext_pred = None
     pretrain_augmented = False
     ext_pseudo_labels = None
@@ -69,7 +70,7 @@ def main(args):
     print(f'Running on: {torch.cuda.get_device_name({gpu})}')
     torch.cuda.empty_cache()

-    if not train_without_ext_pred:
+    if train_with_ext_pred:
         ext_pred_path = download_google_url(
             id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY',
             folder='data/ogb/ogbn_products/ext_preds',
@@ -262,7 +263,7 @@ def load_model(em_phase):
     if pretrain_phase == 'gnn':
         model.gnn = model.gnn.to(device)
         print('pretraining gnn to generate pseudo labels')
-        if not train_without_ext_pred:
+        if train_with_ext_pred:
             pretrain_loader = graph_train_loader
         preds_filename = 'gnn_pretrain'
     elif pretrain_phase == 'lm':
@@ -272,7 +273,7 @@ def load_model(em_phase):
         pretrain_loader = text_pretrain_loader
         test_loader = text_test_loader
         pretrain_opt = lm_opt
-        if not train_without_ext_pred:
+        if train_with_ext_pred:
             pretrain_loader = text_train_loader
         preds_filename = 'lm_pretrain'

@@ -404,10 +405,10 @@ def load_model(em_phase):
                         help='number of runs')
     parser.add_argument('--num_em_iters', type=int, default=1,
                         help='number of iterations')
-    parser.add_argument("--dataset", type=str, default='arxiv',
+    parser.add_argument("--dataset", type=str, default='products',
                         help='arxiv or products')
     parser.add_argument(
-        "--text_type", type=str, default='llm_explanation',
+        "--text_type", type=str, default='raw_text',
         help="type of text, support raw_text, llm_explanation,"
         "all for arxiv and raw_text for products")
     parser.add_argument("--pl_ratio", type=float, default=0.5,

examples/llm/nvtx_examples/nvtx_rag_backend_example.py

Lines changed: 1 addition & 1 deletion
@@ -12,8 +12,8 @@
     preprocess_triplet,
     retrieval_via_pcst,
 )
+from torch_geometric.llm import SentenceTransformer
 from torch_geometric.loader import rag_loader
-from torch_geometric.nn.nlp import SentenceTransformer
 from torch_geometric.profile.nvtx import nvtxit

 sys.path.append('..')
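The example now imports SentenceTransformer from its new torch_geometric.llm location. Code that has to run against both the old and new module layout can hedge with a fallback import; a minimal sketch (not part of the commit):

    # Compatibility shim: prefer the new torch_geometric.llm location and
    # fall back to the older torch_geometric.nn.nlp module path.
    try:
        from torch_geometric.llm import SentenceTransformer
    except ImportError:
        from torch_geometric.nn.nlp import SentenceTransformer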

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend="flit_core.buildapi"

 [project]
 name="torch-geometric"
-version="2.7.0"
+version="2.8.0"
 authors=[
     {name="Matthias Fey", email="[email protected]"},
 ]

torch_geometric/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 contrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')
 graphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')

-__version__ = '2.7.0'
+__version__ = '2.8.0'

 __all__ = [
     'Index',
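The version is bumped in two places, pyproject.toml and torch_geometric/__init__.py. A quick sanity check, as a sketch, that an installed build keeps the two in sync:

    # The wheel metadata (from pyproject.toml) and the runtime __version__
    # string should both report 2.8.0 after this bump.
    from importlib.metadata import version

    import torch_geometric

    assert torch_geometric.__version__ == version('torch-geometric') == '2.8.0'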

torch_geometric/datasets/tag_dataset.py

Lines changed: 28 additions & 20 deletions
@@ -137,10 +137,13 @@ def __init__(
         self.token_on_disk = token_on_disk
         self.tokenize_batch_size = tokenize_batch_size
         self._token = self.tokenize_graph(self.tokenize_batch_size)
-        self._llm_explanation_token = self.tokenize_graph(
-            self.tokenize_batch_size, text_type='llm_explanation')
-        self._all_token = self.tokenize_graph(self.tokenize_batch_size,
-                                              text_type='all')
+        self._llm_explanation_token: Dict[str, Tensor] = {}
+        self._all_token: Dict[str, Tensor] = {}
+        if self.name in self.llm_explanation_id:
+            self._llm_explanation_token = self.tokenize_graph(
+                self.tokenize_batch_size, text_type='llm_explanation')
+            self._all_token = self.tokenize_graph(self.tokenize_batch_size,
+                                                  text_type='all')
         self.__num_classes__ = dataset.num_classes

     @property
@@ -170,14 +173,16 @@ def token(self) -> Dict[str, Tensor]:

     @property
     def llm_explanation_token(self) -> Dict[str, Tensor]:
-        if self._llm_explanation_token is None:  # lazy load
+        if self._llm_explanation_token is None and \
+                self.name in self.llm_explanation_id:
             self._llm_explanation_token = self.tokenize_graph(
                 text_type='llm_explanation')
         return self._llm_explanation_token

     @property
     def all_token(self) -> Dict[str, Tensor]:
-        if self._all_token is None:  # lazy load
+        if self._all_token is None and \
+                self.name in self.llm_explanation_id:
             self._all_token = self.tokenize_graph(text_type='all')
         return self._all_token

@@ -230,13 +235,15 @@ def download(self) -> None:
             filename='node-text.csv.gz',
             log=True)
         self.text = list(read_csv(raw_text_path)['text'])
-        print('downloading llm explanations')
-        llm_explanation_path = download_google_url(
-            id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
-            filename='node-gpt-response.csv.gz', log=True)
-        self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
-        print('downloading llm predictions')
-        fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)
+        if self.name in self.llm_explanation_id:
+            print('downloading llm explanations')
+            llm_explanation_path = download_google_url(
+                id=self.llm_explanation_id[self.name],
+                folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz',
+                log=True)
+            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
+            print('downloading llm predictions')
+            fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)

     def process(self) -> None:
         # process Title and Abstraction
@@ -276,20 +283,21 @@ def process(self) -> None:
             for i, pred in enumerate(preds):
                 pl[i][:len(pred)] = torch.tensor(
                     pred[:self.llm_prediction_topk], dtype=torch.long) + 1
+
+            if self.llm_explanation is None or pl is None:
+                raise ValueError(
+                    "The TAGDataset only have ogbn-arxiv LLM explanations"
+                    "and predictions in default. The llm explanation and"
+                    "prediction of each node is not specified.Please pass in"
+                    "'llm_explanation' and 'llm_prediction' when"
+                    "convert your dataset to Text Attribute Graph Dataset")
         elif self.name in self.llm_explanation_id:
             self.download()
         else:
             print(
                 'The dataset is not ogbn-arxiv,'
                 'please pass in your llm explanation list to `llm_explanation`'
                 'and llm prediction list to `llm_prediction`')
-        if self.llm_explanation is None or pl is None:
-            raise ValueError(
-                "The TAGDataset only have ogbn-arxiv LLM explanations"
-                "and predictions in default. The llm explanation and"
-                "prediction of each node is not specified."
-                "Please pass in 'llm_explanation' and 'llm_prediction' when"
-                "convert your dataset to Text Attribute Graph Dataset")

     def save_node_text(self, text: List[str]) -> None:
         node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
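With these guards, bundled LLM explanations and predictions are only tokenized and downloaded for datasets listed in llm_explanation_id (i.e. ogbn-arxiv); any other dataset must supply its own. A hedged usage sketch: the keyword names follow the error message in the diff, while the tokenizer choice and the load_my_explanations/load_my_predictions helpers are hypothetical stand-ins for user code:

    from ogb.nodeproppred import PygNodePropPredDataset

    from torch_geometric.datasets import TAGDataset

    dataset = PygNodePropPredDataset('ogbn-products', root='data/ogb')

    # ogbn-products ships no bundled LLM artifacts, so both lists are
    # user-supplied: one explanation string and one top-k label list per node
    # (hypothetical helpers).
    my_explanations = load_my_explanations()
    my_predictions = load_my_predictions()

    tag_dataset = TAGDataset(root='data/ogb', dataset=dataset,
                             tokenizer_name='microsoft/deberta-base',  # assumed
                             llm_explanation=my_explanations,
                             llm_prediction=my_predictions)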

torch_geometric/nn/model_hub.py

Lines changed: 2 additions & 2 deletions
@@ -144,10 +144,10 @@ def _from_pretrained(
         revision,
         cache_dir,
         force_download,
-        proxies,
-        resume_download,
         local_files_only,
         token,
+        proxies=None,
+        resume_download=False,
         dataset_name='',
         model_name='',
         map_location='cpu',
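The change turns proxies and resume_download from positional parameters into trailing keyword parameters with defaults, so callers that no longer pass them positionally keep working. The general pattern, as a minimal sketch with hypothetical names:

    # Sketch of the pattern: arguments a caller may omit move behind the
    # still-required positional ones and gain defaults.
    def _from_pretrained(model_id, revision, cache_dir, force_download,
                         local_files_only, token,
                         proxies=None, resume_download=False):
        # Old keyword-style call sites keep working; positional call sites
        # no longer need to thread proxies/resume_download through.
        ...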

torch_geometric/nn/pool/cluster_pool.py

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@ class UnpoolInfo(NamedTuple):

 class ClusterPooling(torch.nn.Module):
     r"""The cluster pooling operator from the `"Edge-Based Graph Component
-    Pooling" <paper url>`_ paper.
-
+    Pooling" <https://arxiv.org/abs/2409.11856>`_ paper.
     :class:`ClusterPooling` computes a score for each edge.
     Based on the selected edges, graph clusters are calculated and compressed
     to one node using the injective :obj:`"sum"` aggregation function.
