Skip to content

Commit de43c5c

Browse files
authored
Don't cast dtype right after tensor creation in to_dense_batch
1 parent b0b6d4e commit de43c5c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch_geometric/utils/_to_dense_batch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def to_dense_batch(
123123
x, idx = x[mask], idx[mask]
124124

125125
size = [batch_size * max_num_nodes] + list(x.size())[1:]
126-
out = torch.as_tensor(fill_value, device=x.device)
127-
out = out.to(x.dtype).repeat(size)
126+
out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype)
127+
out = out.repeat(size)
128128
out[idx] = x
129129
out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])
130130

0 commit comments

Comments
 (0)