Skip to content

Commit 3e34f79

Browse files
authored
Update backend_utils.py
1 parent 3fd3874 commit 3e34f79

File tree

1 file changed

+123
-1
lines changed

1 file changed

+123
-1
lines changed

torch_geometric/utils/rag/backend_utils.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
TripletLike,
2626
)
2727
from torch_geometric.data.large_graph_indexer import EDGE_RELATION
28-
from torch_geometric.datasets.web_qsp_dataset import retrieval_via_pcst
2928
from torch_geometric.distributed import (
3029
LocalFeatureStore,
3130
LocalGraphStore,
@@ -50,6 +49,129 @@ def preprocess_triplet(triplet: TripletLike) -> TripletLike:
5049
h, r, t = triplet
5150
return str(h).lower(), str(r).lower(), str(t).lower()
5251

52+
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    r"""Retrieve a query-relevant subgraph via Prize-Collecting Steiner
    Tree (PCST) and return it together with a CSV description of its
    textual node/edge attributes.

    Args:
        data (Data): Graph with node features :obj:`x`, edge features
            :obj:`edge_attr` and connectivity :obj:`edge_index`.
            NOTE(review): the PCST path also reads :obj:`data.node_idx`
            and :obj:`data.edge_idx` to track the selected subset —
            confirm callers always provide them.
        q_emb (Tensor): Query embedding, compared against node and edge
            features via cosine similarity.
        textual_nodes (Any): DataFrame-like table of node descriptions
            (must support :meth:`iloc` and :meth:`to_csv`).
        textual_edges (Any): DataFrame-like table of edge descriptions
            with ``'src'``, ``'edge_attr'`` and ``'dst'`` columns.
        topk (int, optional): Number of top query-similar nodes to
            assign prizes to. (default: :obj:`3`)
        topk_e (int, optional): Number of top edge prize levels.
            (default: :obj:`5`)
        cost_e (float, optional): Base cost per edge; lower values keep
            more edges. (default: :obj:`0.5`)
        num_clusters (int, optional): Number of connected components
            requested from PCST. (default: :obj:`1`)

    Returns:
        Tuple[Data, str]: The selected subgraph and its CSV description.
    """
    # Skip PCST for bad graphs (missing or empty node features, edge
    # features, or edges) and return the input graph unchanged.
    # Returning early here is required for correctness: the code below
    # uses `selected_nodes`, `selected_edges` and `edge_index`, which
    # only exist on the PCST path (falling through would raise
    # `NameError`).
    if (data.edge_attr is None or data.edge_attr.numel() == 0
            or data.x is None or data.x.numel() == 0
            or data.edge_index is None or data.edge_index.numel() == 0):
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(
            index=False, columns=['src', 'edge_attr', 'dst'])
        return data, desc

    from pcst_fast import pcst_fast

    c = 0.01  # Separation margin between consecutive prize levels.
    root = -1  # -1 requests an unrooted PCST solution.
    pruning = 'gw'
    verbosity_level = 0

    # Node prizes: the `topk` most query-similar nodes receive
    # descending integer prizes (topk, topk - 1, ..., 1); others zero.
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    # Edge prizes: the `topk_e` highest unique similarity values map to
    # strictly decreasing prize levels; edges tied at a level share that
    # level's prize mass.
    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,
                                      largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # reduce the cost of the edges so that at least one edge is chosen
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    # PCST only understands node prizes, so an edge whose prize exceeds
    # its cost is rewritten as a "virtual node" (src -> virtual -> dst
    # with zero-cost connector edges) carrying the surplus as its prize.
    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}  # virtual node id -> original edge index
    mapping_e = {}  # compacted edge position -> original edge index
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    # Map the PCST solution back to the original graph: real vertices
    # are selected nodes; selected virtual vertices and retained real
    # edges both map back to original edge indices.
    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        virtual_vertices = vertices[vertices >= data.num_nodes]
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    # Make sure every endpoint of a selected edge is in the node set:
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(),
             edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    # Relabel selected node ids into a compact [0, num_selected) range:
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be tracked
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        # HACK Added so that the subset of nodes and edges selected can be tracked
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc
53175

54176
def batch_knn(query_enc: Tensor, embeds: Tensor,
55177
k: int) -> Iterator[InputNodes]:

0 commit comments

Comments
 (0)