Skip to content

Commit 0815d35

Browse files
authored
Update backend_utils.py
1 parent 1d2f43b commit 0815d35

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torch_geometric/utils/rag/backend_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def preprocess_triplet(triplet: TripletLike) -> TripletLike:
5050
h, r, t = triplet
5151
return str(h).lower(), str(r).lower(), str(t).lower()
5252

53+
def batch_knn(query_enc: Tensor, embeds: Tensor,
54+
k: int) -> Iterator[InputNodes]:
55+
from torchmetrics.functional import pairwise_cosine_similarity
56+
prizes = pairwise_cosine_similarity(query_enc, embeds.to(query_enc.device))
57+
topk = min(k, len(embeds))
58+
for i, q in enumerate(prizes):
59+
_, indices = torch.topk(q, topk, largest=True)
60+
yield indices, query_enc[i].unsqueeze(0)
5361

5462
# Adapted from LocalGraphStore
5563
@runtime_checkable

0 commit comments

Comments
 (0)