2525 TripletLike ,
2626)
2727from torch_geometric .data .large_graph_indexer import EDGE_RELATION
28- from torch_geometric .datasets .web_qsp_dataset import retrieval_via_pcst
2928from torch_geometric .distributed import (
3029 LocalFeatureStore ,
3130 LocalGraphStore ,
@@ -50,6 +49,129 @@ def preprocess_triplet(triplet: TripletLike) -> TripletLike:
5049 h , r , t = triplet
5150 return str (h ).lower (), str (r ).lower (), str (t ).lower ()
5251
@no_type_check
def retrieval_via_pcst(
    data: Data,
    q_emb: Tensor,
    textual_nodes: Any,
    textual_edges: Any,
    topk: int = 3,
    topk_e: int = 5,
    cost_e: float = 0.5,
    num_clusters: int = 1,
) -> Tuple[Data, str]:
    """Prune :obj:`data` down to a query-relevant subgraph via the
    Prize-Collecting Steiner Tree (PCST) algorithm and return it together
    with a CSV description of the kept nodes and edges.

    Args:
        data: Graph with :obj:`x` (node embeddings), :obj:`edge_index`,
            :obj:`edge_attr` (edge embeddings), and the bookkeeping
            attributes :obj:`node_idx` / :obj:`edge_idx` that map local ids
            back to the global graph.
        q_emb: Query embedding, compared against node/edge embeddings with
            cosine similarity to assign prizes.
        textual_nodes: DataFrame of node descriptions, row-aligned with
            :obj:`data.x`.
        textual_edges: DataFrame of edge descriptions with columns
            ``src``/``edge_attr``/``dst``, row-aligned with
            :obj:`data.edge_attr`.
        topk: Number of most query-similar nodes to receive a prize.
        topk_e: Number of top unique edge-similarity values to receive a
            prize.
        cost_e: Base cost for including an edge in the tree.
        num_clusters: Number of connected components PCST may return.

    Returns:
        Tuple of the pruned :class:`Data` subgraph and its textual
        description.
    """
    # Skip PCST for bad graphs (missing/empty features or edges) and return
    # them untouched, described in full.
    # BUGFIX: the previous version fell through to the re-indexing code below
    # with `selected_nodes`/`selected_edges`/`edge_index` undefined on this
    # path, raising a NameError. Return early instead.
    bad_graph = (data.edge_attr is None or data.edge_attr.numel() == 0
                 or data.x is None or data.x.numel() == 0
                 or data.edge_index is None or data.edge_index.numel() == 0)
    if bad_graph:
        desc = textual_nodes.to_csv(index=False) + '\n' + textual_edges.to_csv(
            index=False, columns=['src', 'edge_attr', 'dst'])
        return data, desc

    from pcst_fast import pcst_fast

    c = 0.01  # margin separating consecutive edge-prize tiers
    root = -1  # unrooted PCST
    pruning = 'gw'
    verbosity_level = 0

    # Node prizes: the `topk` most query-similar nodes get prizes
    # topk, topk-1, ..., 1; all other nodes get 0.
    if topk > 0:
        n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)
        topk = min(topk, data.num_nodes)
        _, topk_n_indices = torch.topk(n_prizes, topk, largest=True)

        n_prizes = torch.zeros_like(n_prizes)
        n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float()
    else:
        n_prizes = torch.zeros(data.num_nodes)

    # Edge prizes: spread the `topk_e` highest unique similarity values into
    # strictly decreasing tiers so higher-similarity edges always dominate.
    if topk_e > 0:
        e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)
        topk_e = min(topk_e, e_prizes.unique().size(0))

        topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True)
        e_prizes[e_prizes < topk_e_values[-1]] = 0.0
        last_topk_e_value = topk_e
        for k in range(topk_e):
            indices = e_prizes == topk_e_values[k]
            value = min((topk_e - k) / sum(indices), last_topk_e_value - c)
            e_prizes[indices] = value
            last_topk_e_value = value * (1 - c)
        # Reduce the cost of the edges so that at least one edge is chosen:
        cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))
    else:
        e_prizes = torch.zeros(data.num_edges)

    # pcst_fast only knows node prizes, so an edge whose prize exceeds its
    # cost is modeled as a zero-cost *virtual node* bridging src and dst,
    # carrying the surplus (prize - cost) as its node prize.
    costs = []
    edges = []
    virtual_n_prizes = []
    virtual_edges = []
    virtual_costs = []
    mapping_n = {}  # virtual node id -> original edge id
    mapping_e = {}  # compact edge id -> original edge id
    for i, (src, dst) in enumerate(data.edge_index.t().numpy()):
        prize_e = e_prizes[i]
        if prize_e <= cost_e:
            mapping_e[len(edges)] = i
            edges.append((src, dst))
            costs.append(cost_e - prize_e)
        else:
            virtual_node_id = data.num_nodes + len(virtual_n_prizes)
            mapping_n[virtual_node_id] = i
            virtual_edges.append((src, virtual_node_id))
            virtual_edges.append((virtual_node_id, dst))
            virtual_costs.append(0)
            virtual_costs.append(0)
            virtual_n_prizes.append(prize_e - cost_e)

    prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])
    num_edges = len(edges)
    if len(virtual_costs) > 0:
        costs = np.array(costs + virtual_costs)
        edges = np.array(edges + virtual_edges)

    vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,
                                pruning, verbosity_level)

    # Translate the PCST solution back to original node/edge ids:
    selected_nodes = vertices[vertices < data.num_nodes]
    selected_edges = [mapping_e[e] for e in edges if e < num_edges]
    virtual_vertices = vertices[vertices >= data.num_nodes]
    if len(virtual_vertices) > 0:
        # A selected virtual node means its underlying edge was selected:
        virtual_edges = [mapping_n[i] for i in virtual_vertices]
        selected_edges = np.array(selected_edges + virtual_edges)

    edge_index = data.edge_index[:, selected_edges]
    # Make sure both endpoints of every selected edge are kept:
    selected_nodes = np.unique(
        np.concatenate(
            [selected_nodes, edge_index[0].numpy(),
             edge_index[1].numpy()]))

    n = textual_nodes.iloc[selected_nodes]
    e = textual_edges.iloc[selected_edges]
    desc = n.to_csv(index=False) + '\n' + e.to_csv(
        index=False, columns=['src', 'edge_attr', 'dst'])

    # Relabel the kept nodes to a compact 0..N-1 range for the subgraph:
    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}
    src = [mapping[i] for i in edge_index[0].tolist()]
    dst = [mapping[i] for i in edge_index[1].tolist()]

    # HACK Added so that the subset of nodes and edges selected can be
    # tracked against the global graph:
    node_idx = np.array(data.node_idx)[selected_nodes]
    edge_idx = np.array(data.edge_idx)[selected_edges]

    data = Data(
        x=data.x[selected_nodes],
        edge_index=torch.tensor([src, dst]).to(torch.long),
        edge_attr=data.edge_attr[selected_edges],
        node_idx=node_idx,
        edge_idx=edge_idx,
    )

    return data, desc
53175
54176def batch_knn (query_enc : Tensor , embeds : Tensor ,
55177 k : int ) -> Iterator [InputNodes ]:
0 commit comments