Commit 86f7adc

update
1 parent e0fb4a1 commit 86f7adc

File tree

1 file changed: +47 −24 lines changed

torch_geometric/nn/nlp/txt2kg.py

Lines changed: 47 additions & 24 deletions
@@ -244,7 +244,7 @@ def _prepare_prompt(self, items: List[str]) -> List[Dict[str, str]]:
         ]
         return messages
 
-    def _multi_shot_resolution(self, items: List[str], num_shots: int = 5) -> List[str]:
+    def _iterative_resolution(self, items: List[str], num_iters: int = 5) -> List[str]:
         items_to_original_id = {item: i for i, item in enumerate(items)}
 
         # contains the mapping of the new items to the original items
@@ -254,13 +254,14 @@ def _multi_shot_resolution(self, items: List[str], num_shots: int = 5) -> List[s
 
         print("before resolution: ", items_to_original_id.keys())
 
-        for shots in range(num_shots):
+        for shots in range(num_iters):
             messages = self._prepare_prompt(items_to_original_id.keys())
             raw_result = self.remote_llm_caller(messages)
             print(f"Raw result: {raw_result}")
             try:
                 ret_items, summary_word = self._process(raw_result)
-            except:
+            except Exception as e:
+                print(f"[_iterative_resolution] Failed to process LLM output: {type(e).__name__}: {str(e)}")
                 continue
             print(f"Summary word: {summary_word}")
             # handling diff kind of wrong LLM outputs
@@ -272,7 +273,7 @@ def _multi_shot_resolution(self, items: List[str], num_shots: int = 5) -> List[s
             # handles hallucinated items
             ret_items = [item for item in ret_items if item in items_to_original_id]
             # handles empty items
-            ret_items = [item for item in ret_items if item]
+            ret_items = list(set([item for item in ret_items if item]))
             print(f"Ret items (after filtering hallucinated items): {ret_items}")
 
             summary_word = summary_word.strip(" '*-").lower()
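
The added `list(set(...))` combines three sanitation steps on the LLM's answer: drop hallucinated items, drop empty strings, and dedupe. As a minimal standalone sketch (function and variable names here are made up for illustration), the same idea looks like this; note that `set()` discards the LLM's ordering, so an order-preserving variant via `dict.fromkeys` is shown instead:

```python
from typing import Iterable, List, Set


def sanitize_llm_items(raw_items: Iterable[str], known_items: Set[str]) -> List[str]:
    """Keep only items the LLM was actually shown, drop empties, and dedupe."""
    filtered = [item for item in raw_items if item and item in known_items]
    # list(set(filtered)) also dedupes, but loses order; dict.fromkeys keeps the
    # first-occurrence order, which makes downstream logs easier to follow.
    return list(dict.fromkeys(filtered))


# Example with made-up values:
known = {"paris", "london", "berlin"}
print(sanitize_llm_items(["paris", "", "atlantis", "paris", "berlin"], known))
# -> ['paris', 'berlin']
```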
@@ -343,8 +344,8 @@ def __call__(self, triples: List[Tuple[str, str, str]]) -> List[Tuple[str, str,
         # LLM based entity and relation resolution
         # TODO: explore others methods for resolution (like lighter specialized models)
         # we want more shots for relations as they usually are more ambiguous
-        new_ents, new_ent_mapping = self._multi_shot_resolution(ents, num_shots=5)
-        new_rels, new_rel_mapping = self._multi_shot_resolution(rels, num_shots=10)
+        new_ents, new_ent_mapping = self._iterative_resolution(ents, num_iters=5)
+        new_rels, new_rel_mapping = self._iterative_resolution(rels, num_iters=5)
 
         consolidated_ent_mapping = {}
         for orig_id, cleaned_id in ent_mapping.items():
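
The `consolidated_ent_mapping` loop that starts here chains the earlier cleaning map with the ids coming back from resolution. Purely as an illustrative sketch, with assumed dict shapes rather than the library's actual ones, composing two such id mappings can be written as:

```python
from typing import Dict


def compose_mappings(first: Dict[int, int], second: Dict[int, int]) -> Dict[int, int]:
    """Chain two id mappings: original id -> cleaned id -> resolved id.

    If the second mapping has no entry for a cleaned id, keep the cleaned id,
    i.e. treat that item as already resolved.
    """
    return {orig: second.get(cleaned, cleaned) for orig, cleaned in first.items()}


# Example with made-up ids: entities 10 and 11 were merged by resolution.
ent_mapping = {0: 10, 1: 11, 2: 12}
new_ent_mapping = {10: 100, 11: 100}
print(compose_mappings(ent_mapping, new_ent_mapping))
# -> {0: 100, 1: 100, 2: 12}
```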
@@ -372,15 +373,27 @@ def _multistage_proc_helper(rank, in_chunks_per_proc,
                             actions: List[Action],
                             remote_llm_caller: RemoteLLMCaller):
     per_chunk_results = []
-    for chunk in in_chunks_per_proc[rank]:
-        out = chunk
-        for action in actions:
-            if isinstance(action, RemoteAction):
-                messages = action.prepare_prompt(chunk, out)
-                out = remote_llm_caller(messages)
-            out = action.parse(out)
-        per_chunk_results += out
-    torch.save(per_chunk_results, "/tmp/txt2kg_outs_for_proc_" + str(rank))
+    try:
+        for chunk in in_chunks_per_proc[rank]:
+            out = chunk
+            for action in actions:
+                try:
+                    if isinstance(action, RemoteAction):
+                        messages = action.prepare_prompt(chunk, out)
+                        out = remote_llm_caller(messages)
+                    out = action.parse(out)
+                except Exception as e:
+                    print(f"[_multistage_proc_helper] Process {rank} failed on chunk processing: {type(e).__name__}: {str(e)}")
+                    import traceback
+                    print(f"[_multistage_proc_helper] Process {rank} traceback: {traceback.format_exc()}")
+                    out = []
+                    break
+            per_chunk_results += out
+        torch.save(per_chunk_results, "/tmp/txt2kg_outs_for_proc_" + str(rank))
+    except Exception as e:
+        print(f"[_multistage_proc_helper] Process {rank} failed completely: {type(e).__name__}: {str(e)}")
+        import traceback
+        print(f"[_multistage_proc_helper] Process {rank} complete failure traceback: {traceback.format_exc()}")
 
 def consume_actions(chunks: Tuple[str],
                     actions: List[Action],
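
The try/except layering added to `_multistage_proc_helper` follows a common pattern for multiprocessing workers: an inner guard so one bad chunk only loses that chunk, and an outer guard so the worker never dies without a trace. A stripped-down sketch of the same pattern, with hypothetical names and assuming each rank persists its partial results to a rank-specific file:

```python
import traceback

import torch


def _worker(rank: int, chunks_per_rank, process_chunk) -> None:
    """Hypothetical worker: process this rank's chunks, isolating per-chunk failures."""
    results = []
    try:
        for chunk in chunks_per_rank[rank]:
            try:
                results += process_chunk(chunk)
            except Exception as e:
                # Log and skip the offending chunk instead of crashing the worker.
                print(f"[worker {rank}] chunk failed: {type(e).__name__}: {e}")
                print(traceback.format_exc())
        # Persist partial results so the parent can collect them even if other ranks fail.
        torch.save(results, f"/tmp/worker_out_{rank}")
    except Exception as e:
        # Last-resort guard: make sure the failure is visible in this rank's own output.
        print(f"[worker {rank}] fatal: {type(e).__name__}: {e}")
        print(traceback.format_exc())
```

Printing the traceback inside the worker keeps the failure readable in that rank's own output, rather than depending on how the parent process re-raises it after the join.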
@@ -412,18 +425,23 @@ def consume_actions(chunks: Tuple[str],
                                 nprocs=num_procs,
                                 join=True)
                     break
-                except: # noqa
+                except Exception as e:
                     total_num_tries += 1
-                    pass
+                    print(f"[consume_actions] Process spawn failed on attempt {total_num_tries}: {type(e).__name__}: {str(e)}")
+                    # For debugging, you might also want to see the full traceback:
+                    import traceback
+                    print(f"[consume_actions] Full traceback: {traceback.format_exc()}")
 
             for rank in range(num_procs):
                 result += torch.load(f"/tmp/txt2kg_outs_for_proc_{rank}")
                 os.remove(f"/tmp/txt2kg_outs_for_proc_{rank}")
             break
-        except:
+        except Exception as e:
             total_num_tries += 1
-            pass
-    print(f"[_llm_call_and_consume] Total number of tries: {total_num_tries}")
+            print(f"[consume_actions] Overall retry {retry+1}/5 failed: {type(e).__name__}: {str(e)}")
+            import traceback
+            print(f"[consume_actions] Full traceback: {traceback.format_exc()}")
+    print(f"[consume_actions] Total number of tries: {total_num_tries}")
     return result
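
Both retry loops touched by this commit (here and in `_llm_call_and_consume` below) swap bare `except: pass` for logged exceptions, keeping the retry behaviour while making the failure reason visible. A generic, hypothetical helper that captures the same pattern:

```python
import time
import traceback
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def retry_with_logging(fn: Callable[[], T], attempts: int = 5, tag: str = "retry") -> T:
    """Call `fn` up to `attempts` times, logging each failure instead of swallowing it."""
    last_exc: Optional[BaseException] = None
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception as e:
            last_exc = e
            print(f"[{tag}] attempt {attempt}/{attempts} failed: {type(e).__name__}: {e}")
            print(traceback.format_exc())
            time.sleep(1)  # small pause between attempts
    # Unlike `except: pass`, surface the final failure instead of hiding it.
    assert last_exc is not None
    raise last_exc
```

Unlike this sketch, the commit keeps the original behaviour of falling through after the last attempt and continuing with whatever partial `result` it has; whether to re-raise instead is a judgement call that depends on whether partial output downstream is acceptable.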

@@ -778,17 +796,22 @@ def _llm_call_and_consume(chunks: Tuple[str], system_prompt: str, NVIDIA_API_KEY
                                 nprocs=num_procs,
                                 join=True)
                     break
-                except: # noqa
+                except Exception as e:
                     total_num_tries += 1
-                    pass
+                    print(f"[_llm_call_and_consume] Process spawn failed on attempt {total_num_tries}: {type(e).__name__}: {str(e)}")
+                    # For debugging, you might also want to see the full traceback:
+                    import traceback
+                    print(f"[_llm_call_and_consume] Full traceback: {traceback.format_exc()}")
 
             for rank in range(num_procs):
                 result += torch.load(f"/tmp/txt2kg_outs_for_proc_{rank}")
                 os.remove(f"/tmp/txt2kg_outs_for_proc_{rank}")
             break
-        except:
+        except Exception as e:
             total_num_tries += 1
-            pass
+            print(f"[_llm_call_and_consume] Overall retry {retry+1}/5 failed: {type(e).__name__}: {str(e)}")
+            import traceback
+            print(f"[_llm_call_and_consume] Full traceback: {traceback.format_exc()}")
     print(f"[_llm_call_and_consume] Total number of tries: {total_num_tries}")
     return result