PyG 2.4.0: Model compilation, on-disk datasets, hierarchical sampling
We are excited to announce the release of PyG 2.4 🎉🎉🎉
PyG 2.4 is the culmination of work from 62 contributors who have worked on features and bug-fixes for a total of over 500 commits since torch-geometric==2.3.1.
Highlights
PyTorch 2.1 and torch.compile(dynamic=True) support
The long wait is over! With the release of PyTorch 2.1, PyG 2.4 now brings full support for torch.compile to graphs of varying size via the dynamic=True option, which is especially useful for use cases that involve DataLoader or NeighborLoader. Examples and tutorials have been updated to reflect this support (#8134), and models and layers in torch_geometric.nn have been tested to produce zero graph breaks:
```python
import torch_geometric

model = torch_geometric.compile(model, dynamic=True)
```

When enabling the dynamic=True option, PyTorch will up-front attempt to generate a kernel that is as dynamic as possible, in order to avoid recompilations when sizes change across mini-batches. As such, you should only omit dynamic=True when graph sizes are guaranteed to never change. Note that dynamic=True requires PyTorch >= 2.1.0 to be installed.
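For instance, a model compiled with dynamic=True can be reused across mini-batches of varying size without triggering recompilation. A minimal sketch (the dataset, model and loader choices below are placeholders for illustration):

```python
import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphSAGE

# Hypothetical setup; any PyG model/loader combination works analogously:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=2,
                  out_channels=dataset.num_classes)

# Compile once; `dynamic=True` avoids recompilation as sizes change:
model = torch_geometric.compile(model, dynamic=True)

loader = NeighborLoader(dataset[0], num_neighbors=[10, 10], batch_size=128)
for batch in loader:  # Mini-batches differ in size across iterations.
    out = model(batch.x, batch.edge_index)
```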
PyG 2.4 is fully compatible with PyTorch 2.1, and supports the following combinations:
| PyTorch 2.1 | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
| Windows | ✅ | ✅ | ✅ |
You can still install PyG 2.4 on older PyTorch releases, going back as far as PyTorch 1.11, in case you are not eager to update your PyTorch version.
OnDiskDataset Interface
We added the OnDiskDataset base class for creating large graph datasets (e.g., molecular databases with billions of graphs), which do not easily fit into CPU memory at once (#8028, #8044, #8046, #8051, #8052, #8054, #8057, #8058, #8066, #8088, #8092, #8106). OnDiskDataset leverages our newly introduced Database backend (sqlite3 by default) for on-disk storage and access of graphs, supports DataLoader out-of-the-box, and is optimized for maximum performance.
OnDiskDataset utilizes a user-specified schema to store data as efficiently as possible (instead of Python pickling). The schema can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a dictionary. For example,
```python
dataset = OnDiskDataset(root, schema={
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'edge_index': dict(dtype=torch.long, size=(2, -1)),
    'y': float,
})
```

creates a database with three columns, where x and edge_index are stored as binary data, and y is stored as a float.
Afterwards, you can append data to the OnDiskDataset via dataset.append()/dataset.extend() and retrieve data from it via dataset.get()/dataset.multi_get(), as sketched below. We added a fully working example of how to set up your own OnDiskDataset here (#8102). You can also convert in-memory dataset instances to an OnDiskDataset instance by running InMemoryDataset.to_on_disk_dataset() (#8116).
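A minimal sketch of this workflow (the subclass below is a hypothetical helper that simply disables downloading and processing, and relies on the default object schema, which stores whole Data objects):

```python
import torch
from torch_geometric.data import Data, OnDiskDataset


class ToyOnDiskDataset(OnDiskDataset):
    # Hypothetical helper: no raw files to download or process.
    @property
    def raw_file_names(self):
        return []

    def download(self):
        pass

    def process(self):
        pass


dataset = ToyOnDiskDataset(root='/tmp/toy_dataset')

data = Data(x=torch.randn(8, 16), edge_index=torch.randint(0, 8, (2, 24)))
dataset.append(data)          # Insert a single graph.
dataset.extend([data, data])  # Insert multiple graphs at once.

print(len(dataset))               # 3
print(dataset.get(0))             # Read back a single graph.
print(dataset.multi_get([0, 2]))  # Batched read of multiple graphs.
```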
Neighbor Sampling Improvements
Hierarchical Sampling
One drawback of NeighborLoader is that it computes representations for all sampled nodes at all depths of the network. However, nodes sampled in later hops no longer contribute to the node representations of seed nodes in later GNN layers, thus wasting computation: NeighborLoader is marginally slower since we compute node embeddings for nodes that are no longer needed. This is a trade-off we made to obtain a clean, modular and experimentation-friendly GNN design, which does not tie the definition of the model to its utilized data loader routine.
With PyG 2.4, we introduced the option to eliminate this overhead and further speed up training and inference of mini-batch GNNs, which we call "Hierarchical Neighborhood Sampling" (see here for the full tutorial) (#6661, #7089, #7244, #7425, #7594, #7942). Its main idea is to progressively trim the adjacency matrix of the returned subgraph before inputting it to each GNN layer, and it works seamlessly across several models, both in the homogeneous and heterogeneous graph setting. To support this trimming and implement it effectively, the NeighborLoader implementations in PyG and in pyg-lib additionally return the number of nodes and edges sampled in each hop, which are then used on a per-layer basis to trim the adjacency matrix and the various feature matrices to only maintain the required entries (see the trim_to_layer method):
```python
from typing import List

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer


class GNN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__()

        self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        num_sampled_nodes_per_hop: List[int],
        num_sampled_edges_per_hop: List[int],
    ) -> Tensor:
        for i, conv in enumerate(self.convs):
            # Trim edge and node information to the current layer `i`:
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)
            x = conv(x, edge_index).relu()
        return self.lin(x)
```

Corresponding examples can be found here and here.
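On the loader side, the mini-batches returned by NeighborLoader carry the per-hop sampling metadata in their num_sampled_nodes and num_sampled_edges attributes. A minimal sketch of wiring the two together (data and the hyper-parameters are placeholders):

```python
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(data, num_neighbors=[10, 10], batch_size=128)
model = GNN(in_channels=16, hidden_channels=64, out_channels=10, num_layers=2)

for batch in loader:
    out = model(
        batch.x,
        batch.edge_index,
        num_sampled_nodes_per_hop=batch.num_sampled_nodes,
        num_sampled_edges_per_hop=batch.num_sampled_edges,
    )
    out = out[:batch.batch_size]  # Keep predictions for seed nodes only.
```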
Biased Sampling
Additionally, we added support for weighted/biased sampling in NeighborLoader/LinkNeighborLoader scenarios. For this, simply point the loader at your edge weight attribute via weight_attr during initialization, and PyG will pick up these weights to perform weighted/biased sampling (#8038):
```python
data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    weight_attr='edge_weight',
)
batch = next(iter(loader))
```

New models, datasets, examples & tutorials
As part of our algorithm and documentation sprints (#7892), we have added:
- Model components:
  - MixHopConv: "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" (examples/mixhop.py) (#8025)
  - LCMAggregation: "Learnable Commutative Monoids for Graph Neural Networks" (examples/lcm_aggr_2nd_min.py) (#7976, #8020, #8023, #8026, #8075)
  - DirGNNConv: "Edge Directionality Improves Learning on Heterophilic Graphs" (examples/dir_gnn.py) (#7458)
  - Support for Performer in GPSConv: "Recipe for a General, Powerful, Scalable Graph Transformer" (examples/graph_gps.py) (#7465)
  - PMLP: "Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs" (examples/pmlp.py) (#7470, #7543)
  - RotatE: "RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space" (examples/kge_fb15k_237.py) (#7026)
  - NeuralFingerprint: "Convolutional Networks on Graphs for Learning Molecular Fingerprints" (#7919)
- Datasets: HM (#7515), BrcaTcga (#7994), MyketDataset (#7959), Wikidata5M (#7864), OSE_GVCS (#7811), MovieLens1M (#7479), AmazonBook (#7483), GDELTLite (#7442), IGMCDataset (#7441), MovieLens100K (#7398), EllipticBitcoinTemporalDataset (#7011), NeuroGraphDataset (#8112), PCQM4Mv2 (#8102)
- Tutorials:
- Examples:
  - Heterogeneous link-level GNN explanations via CaptumExplainer (examples/captum_explainer_hetero_link.py) (#7096)
  - Training LightGCN on AmazonBook for recommendation (examples/lightgcn.py) (#7603)
  - Using the Kùzu remote backend as a FeatureStore (examples/kuzu) (#7298)
  - Multi-GPU training on ogbn-papers100M (examples/papers100m_multigpu.py) (#7921)
  - The OGC model on Cora (examples/ogc.py) (#8168)
  - Distributed training via graphlearn-for-pytorch (examples/distributed/graphlearn_for_pytorch) (#7402)
Join our Slack here if you're interested in joining community sprints in the future!
Breaking Changes
- Data.keys() is now a method instead of a property (#7629):

| <=2.3 | 2.4 |
|---|---|
| data = Data(x=x, edge_index=edge_index)<br>print(data.keys)<br># ['x', 'edge_index'] | data = Data(x=x, edge_index=edge_index)<br>print(data.keys())<br># ['x', 'edge_index'] |
- Dropped Python 3.7 support (#7939)
- Removed FastHGTConv in favor of HGTConv (#7117)
- Removed the layer_type argument from GraphMaskExplainer (#7445)
- Renamed the dest argument to dst in utils.geodesic_distance (#7708)
Deprecations
- Deprecated contrib.explain.GraphMaskExplainer in favor of explain.algorithm.GraphMaskExplainer (#7779)
Features
Data and HeteroData improvements
- Added a warning for isolated/non-existing node types in HeteroData.validate() (#7995)
- Added HeteroData support in to_networkx (#7713)
- Added Data.sort() and HeteroData.sort() (#7649)
- Added padding capabilities to HeteroData.to_homogeneous() in case feature dimensionalities do not match (#7374)
- Added torch.nested_tensor support in Data and Batch (#7643, #7647)
- Added the keep_inter_cluster_edges option to ClusterData to support inter-subgraph edge connections when doing graph partitioning (#7326)
Data-loading improvements
- Added support for floating-point slicing in Dataset, e.g., dataset[:0.9] (#7915) (see the sketch after this list)
- Added save and load methods to InMemoryDataset (#7250, #7413)
- Beta: Added IBMBNodeLoader and IBMBBatchLoader data loaders (#6230)
- Beta: Added HyperGraphData to support hypergraphs (#7611)
- Added CachedLoader (#7896, #7897)
- Allowed GPU tensors as input to NodeLoader and LinkLoader (#7572)
- Added PrefetchLoader capabilities (#7376, #7378, #7383)
- Added manual sampling interface to NodeLoader and LinkLoader (#7197)
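As referenced above, floating-point slicing makes proportional train/test splits a one-liner. A minimal sketch, assuming dataset is any Dataset instance:

```python
train_dataset = dataset[:0.9]  # First 90% of the dataset.
test_dataset = dataset[0.9:]   # Remaining 10%.
```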
Better support for sparse tensors
- Added SparseTensor support to WLConvContinuous, GeneralConv, PDNConv and ARMAConv (#8013)
- Changed torch_sparse.SparseTensor logic to utilize torch.sparse_csr instead (#7041)
- Added support for torch.sparse.Tensor in DataLoader (#7252)
- Added support for torch.jit.script within MessagePassing layers without torch_sparse being installed (#7061, #7062)
- Added unbatching logic for torch.sparse.Tensor (#7037)
- Added support for Data.num_edges for native torch.sparse.Tensor adjacency matrices (#7104)
- Accelerated sparse tensor conversion routines (#7042, #7043)
- Added a sparse cross_entropy implementation (#7447, #7466)
Integration with 3rd-party libraries
- Added FlopsCount support via fvcore (#7693)
- Added to_dgl and from_dgl conversion functions (#7053)
torch_geometric.transforms
- All transforms are now immutable, i.e., they perform a shallow copy of the data and therefore no longer modify it in-place (#7429) (see the sketch after this list)
- Added the HalfHop graph upsampling augmentation (#7827)
- Added an interval argument to the Cartesian, LocalCartesian and Distance transformations (#7533, #7614, #7700)
- Added an optional add_pad_mask argument to the Pad transform (#7339)
- Added the NodePropertySplit transformation for creating node-level splits using structural node properties (#6894)
- Added an AddRemainingSelfLoops transformation (#7192)
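As referenced above, transforms now return a (shallow-copied) object instead of mutating their input, so the return value must be used. A minimal sketch using the existing AddSelfLoops transform, assuming data is a Data object:

```python
from torch_geometric.transforms import AddSelfLoops

transform = AddSelfLoops()
out = transform(data)  # `data` itself is left unchanged; use `out` instead.
```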
Bugfixes
- Fixed HeteroConv for layers that have a non-default argument order, e.g., GCN2Conv (#8166)
- Handle reserved keywords as keys in ModuleDict and ParameterDict (#8163)
- Fixed DynamicBatchSampler.__len__ to raise an error in case num_steps is undefined (#8137)
- Enabled pickling of DimeNet models (#8019)
- Fixed a bug in which batch.e_id was not correctly computed on unsorted graph inputs (#7953)
- Fixed from_networkx conversion from nx.stochastic_block_model graphs (#7941)
- Fixed the usage of bias_initializer in HeteroLinear (#7923)
- Fixed broken URLs in HGBDataset (#7907)
- Fixed an issue where SetTransformerAggregation produced NaN values for isolated nodes (#7902)
- Fixed summary on modules with uninitialized parameters (#7884)
- Fixed tracing of add_self_loops for a dynamic number of nodes (#7330)
- Fixed device issue in PNAConv.get_degree_histogram (#7830)
- Fixed the shape of edge_label_time when using temporal sampling on homogeneous graphs (#7807)
- Fixed edge_label_index computation in LinkNeighborLoader for the homogeneous+disjoint mode (#7791)
- Fixed CaptumExplainer for binary classification tasks (#7787)
- Raise an error when collecting non-existing attributes in HeteroData (#7714)
- Fixed get_mesh_laplacian for normalization="sym" (#7544)
- Use dim_size to initialize the output size of the EquilibriumAggregation layer (#7530)
- Fixed empty edge indices handling in SparseTensor (#7519)
- Move the scaler tensor in GeneralConv to the correct device (#7484)
- Fixed a HeteroLinear bug when used via mixed precision (#7473)
- Fixed gradient computation of edge weights in utils.spmm (#7428)
- Fixed an index-out-of-range bug in QuantileAggregation when dim_size is passed (#7407)
- Fixed a bug in LightGCN.recommendation_loss() to only use the embeddings of the nodes involved in the current mini-batch (#7384)
- Fixed a bug in which inputs were modified in-place in to_hetero_with_bases (#7363)
- Do not load node_default and edge_default attributes in from_networkx (#7348)
- Fixed the HGTConv utility function _construct_src_node_feat (#7194)
- Fixed subgraph on unordered inputs (#7187)
- Allow missing node types in HeteroDictLinear (#7185)
- Fixed numpy incompatibility when reading files for Planetoid datasets (#7141)
- Fixed crash of heterogeneous data loaders if node or edge types are missing (#7060, #7087)
- Allowed CaptumExplainer to be called multiple times in a row (#7391)
Changes
- Enabled dense eigenvalue computation in AddLaplacianEigenvectorPE for small-scale graphs (#8143)
- Accelerated and simplified top_k computation in TopKPooling (#7737)
- Updated the GIN implementation in benchmarks to apply sequential batch normalization (#7955)
- Updated QM9 data pre-processing to include the SMILES string (#7867)
- Warn users when using the training flag in to_hetero modules (#7772)
- Changed add_random_edge to only add true negative edges (#7654)
- Allowed the usage of BasicGNN models in DeepGraphInfomax (#7648)
- Added a num_edges parameter to the forward method of HypergraphConv (#7560)
- Added a max_num_elements parameter to the forward method of GraphMultisetTransformer, GRUAggregation, LSTMAggregation, SetTransformerAggregation and SortAggregation (#7529, #7367)
- Re-factored ClusterLoader to integrate the pyg-lib METIS routine (#7416)
- The filter_per_worker option will now get automatically inferred by default based on the device of the underlying data (#7399)
- Added the option to pass fill_value as a torch.tensor to utils.to_dense_batch (#7367)
- Updated examples to use NeighborLoader instead of NeighborSampler (#7152)
- Extended the dataset summary to create stats for each node/edge type (#7203)
- Added an optional batch_size argument to avg_pool_x and max_pool_x (#7216)
- Optimized the from_networkx memory footprint by reducing unnecessary copies (#7119)
- Added an optional batch_size argument to LayerNorm, GraphNorm, InstanceNorm, GraphSizeNorm and PairNorm (#7135)
- Accelerated attention-based MultiAggregation (#7077)
- Edges in HeterophilousGraphDataset are now undirected by default (#7065)
- Added optional batch_size and max_num_nodes arguments to the MemPooling layer (#7239)
Full Changelog
Full Changelog: 2.3.0...2.4.0
New Contributors
- @zoryzhang made their first contribution in #7027
- @DomInvivo made their first contribution in #7037
- @OlegPlatonov made their first contribution in #7065
- @hbenedek made their first contribution in #7053
- @rishiagarwal2000 made their first contribution in #7011
- @sisaman made their first contribution in #7104
- @amorehead made their first contribution in #7110
- @EulerPascal404 made their first contribution in #7093
- @Looong01 made their first contribution in #7143
- @kamil-andrzejewski made their first contribution in #7135
- @andreazanetti made their first contribution in #7089
- @akihironitta made their first contribution in #7195
- @kjkozlowski made their first contribution in #7216
- @vstenby made their first contribution in #7221
- @piotrchmiel made their first contribution in #7239
- @vedal made their first contribution in #7272
- @gvbazhenov made their first contribution in #6894
- @saydemr made their first contribution in #7313
- @HaoyuLu1022 made their first contribution in #7325
- @Vuenc made their first contribution in #7330
- @mewim made their first contribution in #7298
- @volltin made their first contribution in #7355
- @kasper-piskorski made their first contribution in #7377
- @happykygo made their first contribution in #7384
- @ThomasKLY made their first contribution in #7398
- @sky-2002 made their first contribution in #7421
- @denadai2 made their first contribution in #7456
- @chrisgo-gc made their first contribution in #7484
- @furkanakkurt1335 made their first contribution in #7507
- @mzamini92 made their first contribution in #7497
- @n-patricia made their first contribution in #7543
- @SalvishGoomanee made their first contribution in #7573
- @emalgorithm made their first contribution in #7458
- @marshka made their first contribution in #7595
- @djm93dev made their first contribution in #7598
- @NripeshN made their first contribution in #7770
- @ATheCoder made their first contribution in #7774
- @ebrahimpichka made their first contribution in #7775
- @kaidic made their first contribution in #7814
- @Wesxdz made their first contribution in #7811
- @daviddavo made their first contribution in #7888
- @frinkleko made their first contribution in #7907
- @chendiqian made their first contribution in #7917
- @rajveer43 made their first contribution in #7885
- @erfanloghmani made their first contribution in #7959
- @xnuohz made their first contribution in #7937
- @Favourj-bit made their first contribution in #7905
- @apfelsinecode made their first contribution in #7996
- @ArchieGertsman made their first contribution in #7976
- @bkmi made their first contribution in #8019
- @harshit5674 made their first contribution in #7919
- @erikhuck made their first contribution in #8024
- @jay-bhambhani made their first contribution in #8028
- @Barcavin made their first contribution in #8049
- @royvelich made their first contribution in #8048
- @CodeTal made their first contribution in #7611
- @filipekstrm made their first contribution in #8117
- @Anwar-Said made their first contribution in #8122
- @xYix made their first contribution in #8168