Skip to content

Commit cd28763

Browse files
authored
Merge branch 'master' into gcn
2 parents 2517f5c + 11e576b commit cd28763

File tree

10 files changed

+42
-50
lines changed

10 files changed

+42
-50
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66
## [2.3.0] - 2023-MM-DD
77
### Added
88
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
9+
- Add inputs_channels back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
910
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
1011
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
1112
### Changed
1213
- [Breaking Change] Moved PyTorch Lightning data modules to `torch_geometric.data.lightning` ([#6140](https://github.com/pyg-team/pytorch_geometric/pull/6140))
1314
- Make `torch_sparse` an optional dependency ([#6132](https://github.com/pyg-team/pytorch_geometric/pull/6132), [#6134](https://github.com/pyg-team/pytorch_geometric/pull/6134), [#6138](https://github.com/pyg-team/pytorch_geometric/pull/6138), [#6139](https://github.com/pyg-team/pytorch_geometric/pull/6139))
14-
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113))
15+
- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155))
1516
- Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123))
1617
### Removed
1718

benchmark/training/training_benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def run(args: argparse.ArgumentParser) -> None:
8080
else:
8181
amp = torch.cpu.amp.autocast(enabled=args.bf16)
8282

83+
inputs_channels = data[
84+
'paper'].num_features if dataset_name == 'ogbn-mag' \
85+
else data.num_features
86+
8387
for model_name in args.models:
8488
if model_name not in supported_sets[dataset_name]:
8589
print(f'Configuration of {dataset_name} + {model_name} '
@@ -124,6 +128,7 @@ def run(args: argparse.ArgumentParser) -> None:
124128
f'Sparse tensor={args.use_sparse_tensor}')
125129

126130
params = {
131+
'inputs_channels': inputs_channels,
127132
'hidden_channels': hidden_channels,
128133
'output_channels': num_classes,
129134
'num_heads': args.num_heads,

test/loader/test_neighbor_loader.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
from torch_sparse import SparseTensor
1111

12+
import torch_geometric.typing
1213
from torch_geometric.data import Data, HeteroData
1314
from torch_geometric.loader import NeighborLoader
1415
from torch_geometric.nn import GraphConv, to_hetero
@@ -20,12 +21,6 @@
2021
)
2122
from torch_geometric.utils import k_hop_subgraph
2223

23-
try:
24-
import pyg_lib # noqa
25-
_WITH_PYG_LIB = True
26-
except ImportError:
27-
_WITH_PYG_LIB = False
28-
2924

3025
def get_edge_index(num_src_nodes, num_dst_nodes, num_edges, dtype=torch.int64):
3126
row = torch.randint(num_src_nodes, (num_edges, ), dtype=dtype)
@@ -44,7 +39,7 @@ def is_subset(subedge_index, edge_index, src_idx, dst_idx):
4439
@pytest.mark.parametrize('directed', [True]) # TODO re-enable undirected mode
4540
@pytest.mark.parametrize('dtype', [torch.int64, torch.int32])
4641
def test_homogeneous_neighbor_loader(directed, dtype):
47-
if dtype != torch.int64 and not _WITH_PYG_LIB:
42+
if dtype != torch.int64 and not torch_geometric.typing.WITH_PYG_LIB:
4843
return
4944

5045
torch.manual_seed(12345)
@@ -83,7 +78,7 @@ def test_homogeneous_neighbor_loader(directed, dtype):
8378
@pytest.mark.parametrize('directed', [True]) # TODO re-enable undirected mode
8479
@pytest.mark.parametrize('dtype', [torch.int64, torch.int32])
8580
def test_heterogeneous_neighbor_loader(directed, dtype):
86-
if dtype != torch.int64 and not _WITH_PYG_LIB:
81+
if dtype != torch.int64 and not torch_geometric.typing.WITH_PYG_LIB:
8782
return
8883

8984
torch.manual_seed(12345)

test/nn/norm/test_layer_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_layer_norm(affine, mode):
1212
batch = torch.zeros(100, dtype=torch.long)
1313

1414
norm = LayerNorm(16, affine=affine, mode=mode)
15-
assert norm.__repr__() == f'LayerNorm(16, mode={mode})'
15+
assert norm.__repr__() == f'LayerNorm(16, affine={affine}, mode={mode})'
1616

1717
if is_full_test():
1818
torch.jit.script(norm)

torch_geometric/nn/conv/rgcn_conv.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,12 @@
88
from torch_scatter import scatter
99
from torch_sparse import SparseTensor, masked_select_nnz, matmul
1010

11+
import torch_geometric.typing
1112
from torch_geometric.nn.conv import MessagePassing
12-
from torch_geometric.typing import Adj, OptTensor
13+
from torch_geometric.typing import Adj, OptTensor, pyg_lib
1314

1415
from ..inits import glorot, zeros
1516

16-
try:
17-
from pyg_lib.ops import segment_matmul # noqa
18-
_WITH_PYG_LIB = True
19-
except ImportError:
20-
_WITH_PYG_LIB = False
21-
22-
def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
23-
raise NotImplementedError
24-
2517

2618
@torch.jit._overload
2719
def masked_edge_index(edge_index, edge_mask):
@@ -112,7 +104,6 @@ def __init__(
112104
):
113105
kwargs.setdefault('aggr', aggr)
114106
super().__init__(node_dim=0, **kwargs)
115-
self._WITH_PYG_LIB = _WITH_PYG_LIB
116107

117108
if num_bases is not None and num_blocks is not None:
118109
raise ValueError('Can not apply both basis-decomposition and '
@@ -225,7 +216,7 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
225216
out = out + h.contiguous().view(-1, self.out_channels)
226217

227218
else: # No regularization/Basis-decomposition ========================
228-
if (self._WITH_PYG_LIB and self.num_bases is None
219+
if (torch_geometric.typing.WITH_PYG_LIB and self.num_bases is None
229220
and x_l.is_floating_point()
230221
and isinstance(edge_index, Tensor)):
231222
if not self.is_sorted:
@@ -264,7 +255,7 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
264255
def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
265256
if edge_type_ptr is not None:
266257
# TODO Re-weight according to edge type degree for `aggr=mean`.
267-
return segment_matmul(x_j, edge_type_ptr, self.weight)
258+
return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)
268259

269260
return x_j
270261

torch_geometric/nn/dense/linear.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,9 @@
77
from torch import Tensor, nn
88
from torch.nn.parameter import Parameter
99

10+
import torch_geometric.typing
1011
from torch_geometric.nn import inits
11-
12-
try:
13-
from pyg_lib.ops import segment_matmul # noqa
14-
_WITH_PYG_LIB = True
15-
except ImportError:
16-
_WITH_PYG_LIB = False
17-
18-
def segment_matmul(inputs: Tensor, ptr: Tensor, other: Tensor) -> Tensor:
19-
raise NotImplementedError
12+
from torch_geometric.typing import pyg_lib
2013

2114

2215
def is_uninitialized_parameter(x: Any) -> bool:
@@ -220,9 +213,7 @@ def __init__(self, in_channels: int, out_channels: int, num_types: int,
220213
self.is_sorted = is_sorted
221214
self.kwargs = kwargs
222215

223-
self._WITH_PYG_LIB = _WITH_PYG_LIB
224-
225-
if self._WITH_PYG_LIB:
216+
if torch_geometric.typing.WITH_PYG_LIB:
226217
self.lins = None
227218
self.weight = torch.nn.Parameter(
228219
torch.Tensor(num_types, in_channels, out_channels))
@@ -241,7 +232,7 @@ def __init__(self, in_channels: int, out_channels: int, num_types: int,
241232
self.reset_parameters()
242233

243234
def reset_parameters(self):
244-
if self._WITH_PYG_LIB:
235+
if torch_geometric.typing.WITH_PYG_LIB:
245236
reset_weight_(self.weight, self.in_channels,
246237
self.kwargs.get('weight_initializer', None))
247238
reset_weight_(self.bias, self.in_channels,
@@ -256,7 +247,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
256247
x (Tensor): The input features.
257248
type_vec (LongTensor): A vector that maps each entry to a type.
258249
"""
259-
if self._WITH_PYG_LIB:
250+
if torch_geometric.typing.WITH_PYG_LIB:
260251
assert self.weight is not None
261252

262253
if not self.is_sorted:
@@ -266,7 +257,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
266257

267258
type_vec_ptr = torch.ops.torch_sparse.ind2ptr(
268259
type_vec, self.num_types)
269-
out = segment_matmul(x, type_vec_ptr, self.weight)
260+
out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
270261
if self.bias is not None:
271262
out += self.bias[type_vec]
272263
else:

torch_geometric/nn/norm/layer_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(self, in_channels: int, eps: float = 1e-5,
4242

4343
self.in_channels = in_channels
4444
self.eps = eps
45+
self.affine = affine
4546
self.mode = mode
4647

4748
if affine:
@@ -94,4 +95,4 @@ def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
9495

9596
def __repr__(self):
9697
return (f'{self.__class__.__name__}({self.in_channels}, '
97-
f'mode={self.mode})')
98+
f'affine={self.affine}, mode={self.mode})')

torch_geometric/sampler/neighbor_sampler.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch import Tensor
66

7+
import torch_geometric.typing
78
from torch_geometric.data import (
89
Data,
910
FeatureStore,
@@ -25,12 +26,6 @@
2526
from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
2627
from torch_geometric.typing import EdgeType, NodeType, NumNeighbors, OptTensor
2728

28-
try:
29-
import pyg_lib # noqa
30-
_WITH_PYG_LIB = True
31-
except ImportError:
32-
_WITH_PYG_LIB = False
33-
3429

3530
class NeighborSampler(BaseSampler):
3631
r"""An implementation of an in-memory (heterogeneous) neighbor sampler used
@@ -227,7 +222,7 @@ def _sample(
227222
loaders."""
228223
# TODO(manan): remote backends only support heterogeneous graphs:
229224
if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData):
230-
if _WITH_PYG_LIB:
225+
if torch_geometric.typing.WITH_PYG_LIB:
231226
# TODO (matthias) `return_edge_id` if edge features present
232227
# TODO (matthias) Ideally, `seed` should inherit the type of
233228
# `colptr_dict` and `row_dict`.
@@ -285,7 +280,7 @@ def _sample(
285280
)
286281

287282
if issubclass(self.data_cls, Data):
288-
if _WITH_PYG_LIB:
283+
if torch_geometric.typing.WITH_PYG_LIB:
289284
# TODO (matthias) `return_edge_id` if edge features present
290285
# TODO (matthias) Ideally, `seed` should inherit the type of
291286
# `colptr` and `row`.

torch_geometric/typing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
from typing import Dict, List, Optional, Tuple, Union
1+
from typing import Dict, Final, List, Optional, Tuple, Union
22

33
import numpy as np
44
from torch import Tensor
55

6+
try:
7+
import pyg_lib # noqa
8+
WITH_PYG_LIB: Final[bool] = True
9+
except ImportError:
10+
pyg_lib = object
11+
WITH_PYG_LIB: Final[bool] = False
12+
613
try:
714
from torch_sparse import SparseTensor
815
except ImportError:

torch_geometric/utils/softmax.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from torch import Tensor
55
from torch_scatter import gather_csr, segment_csr
66

7+
import torch_geometric.typing
8+
from torch_geometric.typing import pyg_lib
79
from torch_geometric.utils import scatter
810

911
from .num_nodes import maybe_num_nodes
@@ -67,8 +69,12 @@ def softmax(
6769
N = maybe_num_nodes(index, num_nodes)
6870
with torch.no_grad():
6971
src_max = scatter(src, index, dim, dim_size=N, reduce='max')
70-
src_max = src_max.index_select(dim, index)
71-
out = (src - src_max).exp()
72+
if (torch_geometric.typing.WITH_PYG_LIB and src.dim() == 2
73+
and (dim == 0 or dim == -2)):
74+
out = pyg_lib.ops.sampled_sub(src, src_max, right_index=index)
75+
else:
76+
out = src - src_max.index_select(dim, index)
77+
out = out.exp()
7278
out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16
7379
out_sum = out_sum.index_select(dim, index)
7480
else:

0 commit comments

Comments
 (0)