Skip to content

Commit ff9fb3d

Browse files
tingyu66pre-commit-ci[bot]
authored andcommitted
Update cugraph conv layers for pylibcugraphops=23.04 (#7023)
This PR updates cugraph models to reflect breaking changes in `pylibcugraphops=23.04`. ~~Right now, it is **blocked** by RAPIDS 23.04 release.~~ CC: @MatthiasKohl @stadlmax --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b3cba5e commit ff9fb3d

File tree

4 files changed

+69
-26
lines changed

4 files changed

+69
-26
lines changed

torch_geometric/nn/conv/cugraph/base.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,26 @@
88
from torch_geometric.utils.sparse import index2ptr
99

1010
try: # pragma: no cover
11-
from pylibcugraphops import (
12-
make_fg_csr,
13-
make_fg_csr_hg,
14-
make_mfg_csr,
15-
make_mfg_csr_hg,
11+
LEGACY_MODE = False
12+
from pylibcugraphops.pytorch import (
13+
SampledCSC,
14+
SampledHeteroCSC,
15+
StaticCSC,
16+
StaticHeteroCSC,
1617
)
1718
HAS_PYLIBCUGRAPHOPS = True
1819
except ImportError:
1920
HAS_PYLIBCUGRAPHOPS = False
21+
try: # pragma: no cover
22+
from pylibcugraphops import (
23+
make_fg_csr,
24+
make_fg_csr_hg,
25+
make_mfg_csr,
26+
make_mfg_csr_hg,
27+
)
28+
LEGACY_MODE = True
29+
except ImportError:
30+
pass
2031

2132

2233
class CuGraphModule(torch.nn.Module): # pragma: no cover
@@ -25,9 +36,9 @@ class CuGraphModule(torch.nn.Module): # pragma: no cover
2536
def __init__(self):
2637
super().__init__()
2738

28-
if HAS_PYLIBCUGRAPHOPS is False:
39+
if not HAS_PYLIBCUGRAPHOPS and not LEGACY_MODE:
2940
raise ModuleNotFoundError(f"'{self.__class__.__name__}' requires "
30-
f"'pylibcugraphops'")
41+
f"'pylibcugraphops>=23.02'")
3142

3243
def reset_parameters(self):
3344
r"""Resets all learnable parameters of the module."""
@@ -99,12 +110,17 @@ def get_cugraph(
99110
if max_num_neighbors is None:
100111
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())
101112

102-
dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
113+
if LEGACY_MODE:
114+
dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
115+
return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
116+
num_src_nodes)
117+
118+
return SampledCSC(colptr, row, max_num_neighbors, num_src_nodes)
103119

104-
return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,
105-
num_src_nodes)
120+
if LEGACY_MODE:
121+
return make_fg_csr(colptr, row)
106122

107-
return make_fg_csr(colptr, row)
123+
return StaticCSC(colptr, row)
108124

109125
def get_typed_cugraph(
110126
self,
@@ -142,17 +158,24 @@ def get_typed_cugraph(
142158
if max_num_neighbors is None:
143159
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())
144160

145-
dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
161+
if LEGACY_MODE:
162+
dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)
163+
return make_mfg_csr_hg(dst_nodes, colptr, row,
164+
max_num_neighbors, num_src_nodes,
165+
n_node_types=0,
166+
n_edge_types=num_edge_types,
167+
out_node_types=None, in_node_types=None,
168+
edge_types=edge_type)
169+
170+
return SampledHeteroCSC(colptr, row, edge_type, max_num_neighbors,
171+
num_src_nodes, num_edge_types)
146172

147-
return make_mfg_csr_hg(dst_nodes, colptr, row, max_num_neighbors,
148-
num_src_nodes, n_node_types=0,
149-
n_edge_types=num_edge_types,
150-
out_node_types=None, in_node_types=None,
151-
edge_types=edge_type)
173+
if LEGACY_MODE:
174+
return make_fg_csr_hg(colptr, row, n_node_types=0,
175+
n_edge_types=num_edge_types, node_types=None,
176+
edge_types=edge_type)
152177

153-
return make_fg_csr_hg(colptr, row, n_node_types=0,
154-
n_edge_types=num_edge_types, node_types=None,
155-
edge_types=edge_type)
178+
return StaticHeteroCSC(colptr, row, edge_type, num_edge_types)
156179

157180
def forward(
158181
self,

torch_geometric/nn/conv/cugraph/gat_conv.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from torch.nn import Linear, Parameter
66

77
from torch_geometric.nn.conv.cugraph import CuGraphModule
8+
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
89
from torch_geometric.nn.inits import zeros
910

1011
try:
11-
from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
12+
if LEGACY_MODE:
13+
from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg
14+
else:
15+
from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg
1216
except ImportError:
1317
pass
1418

@@ -67,8 +71,13 @@ def forward(
6771
graph = self.get_cugraph(csc, max_num_neighbors)
6872

6973
x = self.lin(x)
70-
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
71-
self.negative_slope, False, self.concat)
74+
75+
if LEGACY_MODE:
76+
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
77+
self.negative_slope, False, self.concat)
78+
else:
79+
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
80+
self.negative_slope, self.concat)
7281

7382
if self.bias is not None:
7483
out = out + self.bias

torch_geometric/nn/conv/cugraph/rgcn_conv.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55
from torch.nn import Parameter
66

77
from torch_geometric.nn.conv.cugraph import CuGraphModule
8+
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
89
from torch_geometric.nn.inits import glorot, zeros
910

1011
try:
11-
from pylibcugraphops.torch.autograd import \
12-
agg_hg_basis_n2n_post as RGCNConvAgg
12+
if LEGACY_MODE:
13+
from pylibcugraphops.torch.autograd import \
14+
agg_hg_basis_n2n_post as RGCNConvAgg
15+
else:
16+
from pylibcugraphops.pytorch.operators import \
17+
agg_hg_basis_n2n_post as RGCNConvAgg
1318
except ImportError:
1419
pass
1520

torch_geometric/nn/conv/cugraph/sage_conv.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55
from torch.nn import Linear
66

77
from torch_geometric.nn.conv.cugraph import CuGraphModule
8+
from torch_geometric.nn.conv.cugraph.base import LEGACY_MODE
89

910
try:
10-
from pylibcugraphops.torch.autograd import agg_concat_n2n as SAGEConvAgg
11+
if LEGACY_MODE:
12+
from pylibcugraphops.torch.autograd import \
13+
agg_concat_n2n as SAGEConvAgg
14+
else:
15+
from pylibcugraphops.pytorch.operators import \
16+
agg_concat_n2n as SAGEConvAgg
1117
except ImportError:
1218
pass
1319

0 commit comments

Comments
 (0)