Skip to content

Commit bad72a2

Browse files
committed
order nodes by batch when creating dense adj
1 parent 49288d6 commit bad72a2

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

test/utils/test_to_dense_adj.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,28 @@ def test_to_dense_adj_with_duplicate_entries():
106106
[0.0, 0.0, 13.0],
107107
[8.0, 0.0, 0.0],
108108
]
109+
110+
111+
def test_to_dense_adj_with_unordered_batch():
112+
edge_index = torch.tensor([
113+
[0, 1, 2, 3],
114+
[3, 2, 1, 0],
115+
])
116+
batch = torch.tensor([0, 1, 1, 0])
117+
118+
adj = to_dense_adj(edge_index, batch)
119+
assert adj.size() == (2, 2, 2)
120+
assert adj[0].tolist() == [[0.0, 1.0], [1.0, 0.0]]
121+
assert adj[1].tolist() == [[0.0, 1.0], [1.0, 0.0]]
122+
123+
edge_index = torch.tensor([
124+
[0, 1, 2, 3],
125+
[3, 2, 1, 0],
126+
])
127+
batch = torch.tensor([0, 1, 1, 0])
128+
edge_attr = torch.tensor([1.0, 3.0, 4.0, 2.0])
129+
130+
adj = to_dense_adj(edge_index, batch, edge_attr)
131+
assert adj.size() == (2, 2, 2)
132+
assert adj[0].tolist() == [[0.0, 1.0], [2.0, 0.0]]
133+
assert adj[1].tolist() == [[0.0, 3.0], [4.0, 0.0]]

torch_geometric/utils/_to_dense_adj.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def to_dense_adj(
6767
if batch_size is None:
6868
batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
6969

70+
perm = batch.argsort()
71+
batch = batch[perm]
72+
new_index_map = torch.empty_like(perm)
73+
new_index_map[perm] = torch.arange(perm.size(0))
74+
edge_index = new_index_map[edge_index]
75+
7076
one = batch.new_ones(batch.size(0))
7177
num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')
7278
cum_nodes = cumsum(num_nodes)

0 commit comments

Comments
 (0)