Skip to content

Commit a541a68

Browse files
committed
update
1 parent 6721d0e commit a541a68

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

torch_geometric/metrics/link_pred.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ class LinkPredDiversity(_LinkPredMetric):
670670
def __init__(self, k: int, category: Tensor) -> None:
671671
super().__init__(k)
672672

673+
self.accum: Tensor
674+
self.total: Tensor
675+
673676
if WITH_TORCHMETRICS:
674677
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
675678
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
@@ -736,11 +739,14 @@ def __init__(
736739
self.max_src_nodes = max_src_nodes
737740
self.batch_size = batch_size
738741

742+
self.preds: List[Tensor]
743+
self.total: Tensor
744+
739745
if WITH_TORCHMETRICS:
740746
self.add_state('preds', default=[], dist_reduce_fx='cat')
741747
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
742748
else:
743-
self.preds: List[Tensor] = []
749+
self.preds = []
744750
self.register_buffer('total', torch.tensor(0), persistent=False)
745751

746752
def update(
@@ -826,6 +832,9 @@ class LinkPredAveragePopularity(_LinkPredMetric):
826832
def __init__(self, k: int, popularity: Tensor) -> None:
827833
super().__init__(k)
828834

835+
self.accum: Tensor
836+
self.total: Tensor
837+
829838
if WITH_TORCHMETRICS:
830839
self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')
831840
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')

torch_geometric/transforms/rooted_subgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def map(
9494
arange = torch.arange(n_id.size(0), device=data.edge_index.device)
9595
node_map = data.edge_index.new_ones(num_nodes, num_nodes)
9696
node_map[n_sub_batch, n_id] = arange
97-
sub_edge_index += (arange * data.num_nodes)[e_sub_batch]
97+
sub_edge_index += (arange * num_nodes)[e_sub_batch]
9898
sub_edge_index = node_map.view(-1)[sub_edge_index]
9999

100100
return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch

torch_geometric/utils/_negative_sampling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,9 @@ def negative_sampling(
109109
idx = idx.to('cpu')
110110
for _ in range(3): # Number of tries to sample negative indices.
111111
rnd = sample(population, sample_size, device='cpu')
112-
mask = np.isin(rnd.numpy(), idx.numpy()) # type: ignore
112+
mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()
113113
if neg_idx is not None:
114-
mask |= np.isin(rnd, neg_idx.to('cpu'))
115-
mask = torch.from_numpy(mask).to(torch.bool)
114+
mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()
116115
rnd = rnd[~mask].to(edge_index.device)
117116
neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])
118117
if neg_idx.numel() >= num_neg_samples:

0 commit comments

Comments
 (0)