@@ -244,7 +244,7 @@ def _prepare_prompt(self, items: List[str]) -> List[Dict[str, str]]:
244
244
]
245
245
return messages
246
246
247
- def _multi_shot_resolution (self , items : List [str ], num_shots : int = 5 ) -> List [str ]:
247
+ def _iterative_resolution (self , items : List [str ], num_iters : int = 5 ) -> List [str ]:
248
248
items_to_original_id = {item : i for i , item in enumerate (items )}
249
249
250
250
# contains the mapping of the new items to the original items
@@ -254,13 +254,14 @@ def _multi_shot_resolution(self, items: List[str], num_shots: int = 5) -> List[s
254
254
255
255
print ("before resolution: " , items_to_original_id .keys ())
256
256
257
- for shots in range (num_shots ):
257
+ for shots in range (num_iters ):
258
258
messages = self ._prepare_prompt (items_to_original_id .keys ())
259
259
raw_result = self .remote_llm_caller (messages )
260
260
print (f"Raw result: { raw_result } " )
261
261
try :
262
262
ret_items , summary_word = self ._process (raw_result )
263
- except :
263
+ except Exception as e :
264
+ print (f"[_iterative_resolution] Failed to process LLM output: { type (e ).__name__ } : { str (e )} " )
264
265
continue
265
266
print (f"Summary word: { summary_word } " )
266
267
# handling diff kind of wrong LLM outputs
@@ -272,7 +273,7 @@ def _multi_shot_resolution(self, items: List[str], num_shots: int = 5) -> List[s
272
273
# handles hallucinated items
273
274
ret_items = [item for item in ret_items if item in items_to_original_id ]
274
275
# handles empty items
275
- ret_items = [item for item in ret_items if item ]
276
+ ret_items = list ( set ( [item for item in ret_items if item ]))
276
277
print (f"Ret items (after filtering hallucinated items): { ret_items } " )
277
278
278
279
summary_word = summary_word .strip (" '*-" ).lower ()
@@ -343,8 +344,8 @@ def __call__(self, triples: List[Tuple[str, str, str]]) -> List[Tuple[str, str,
343
344
# LLM based entity and relation resolution
344
345
# TODO: explore others methods for resolution (like lighter specialized models)
345
346
# we want more shots for relations as they usually are more ambiguous
346
- new_ents , new_ent_mapping = self ._multi_shot_resolution (ents , num_shots = 5 )
347
- new_rels , new_rel_mapping = self ._multi_shot_resolution (rels , num_shots = 10 )
347
+ new_ents , new_ent_mapping = self ._iterative_resolution (ents , num_iters = 5 )
348
+ new_rels , new_rel_mapping = self ._iterative_resolution (rels , num_iters = 5 )
348
349
349
350
consolidated_ent_mapping = {}
350
351
for orig_id , cleaned_id in ent_mapping .items ():
@@ -372,15 +373,27 @@ def _multistage_proc_helper(rank, in_chunks_per_proc,
372
373
actions : List [Action ],
373
374
remote_llm_caller : RemoteLLMCaller ):
374
375
per_chunk_results = []
375
- for chunk in in_chunks_per_proc [rank ]:
376
- out = chunk
377
- for action in actions :
378
- if isinstance (action , RemoteAction ):
379
- messages = action .prepare_prompt (chunk , out )
380
- out = remote_llm_caller (messages )
381
- out = action .parse (out )
382
- per_chunk_results += out
383
- torch .save (per_chunk_results , "/tmp/txt2kg_outs_for_proc_" + str (rank ))
376
+ try :
377
+ for chunk in in_chunks_per_proc [rank ]:
378
+ out = chunk
379
+ for action in actions :
380
+ try :
381
+ if isinstance (action , RemoteAction ):
382
+ messages = action .prepare_prompt (chunk , out )
383
+ out = remote_llm_caller (messages )
384
+ out = action .parse (out )
385
+ except Exception as e :
386
+ print (f"[_multistage_proc_helper] Process { rank } failed on chunk processing: { type (e ).__name__ } : { str (e )} " )
387
+ import traceback
388
+ print (f"[_multistage_proc_helper] Process { rank } traceback: { traceback .format_exc ()} " )
389
+ out = []
390
+ break
391
+ per_chunk_results += out
392
+ torch .save (per_chunk_results , "/tmp/txt2kg_outs_for_proc_" + str (rank ))
393
+ except Exception as e :
394
+ print (f"[_multistage_proc_helper] Process { rank } failed completely: { type (e ).__name__ } : { str (e )} " )
395
+ import traceback
396
+ print (f"[_multistage_proc_helper] Process { rank } complete failure traceback: { traceback .format_exc ()} " )
384
397
385
398
def consume_actions (chunks : Tuple [str ],
386
399
actions : List [Action ],
@@ -412,18 +425,23 @@ def consume_actions(chunks: Tuple[str],
412
425
nprocs = num_procs ,
413
426
join = True )
414
427
break
415
- except : # noqa
428
+ except Exception as e :
416
429
total_num_tries += 1
417
- pass
430
+ print (f"[consume_actions] Process spawn failed on attempt { total_num_tries } : { type (e ).__name__ } : { str (e )} " )
431
+ # For debugging, you might also want to see the full traceback:
432
+ import traceback
433
+ print (f"[consume_actions] Full traceback: { traceback .format_exc ()} " )
418
434
419
435
for rank in range (num_procs ):
420
436
result += torch .load (f"/tmp/txt2kg_outs_for_proc_{ rank } " )
421
437
os .remove (f"/tmp/txt2kg_outs_for_proc_{ rank } " )
422
438
break
423
- except :
439
+ except Exception as e :
424
440
total_num_tries += 1
425
- pass
426
- print (f"[_llm_call_and_consume] Total number of tries: { total_num_tries } " )
441
+ print (f"[consume_actions] Overall retry { retry + 1 } /5 failed: { type (e ).__name__ } : { str (e )} " )
442
+ import traceback
443
+ print (f"[consume_actions] Full traceback: { traceback .format_exc ()} " )
444
+ print (f"[consume_actions] Total number of tries: { total_num_tries } " )
427
445
return result
428
446
429
447
@@ -778,17 +796,22 @@ def _llm_call_and_consume(chunks: Tuple[str], system_prompt: str, NVIDIA_API_KEY
778
796
nprocs = num_procs ,
779
797
join = True )
780
798
break
781
- except : # noqa
799
+ except Exception as e :
782
800
total_num_tries += 1
783
- pass
801
+ print (f"[_llm_call_and_consume] Process spawn failed on attempt { total_num_tries } : { type (e ).__name__ } : { str (e )} " )
802
+ # For debugging, you might also want to see the full traceback:
803
+ import traceback
804
+ print (f"[_llm_call_and_consume] Full traceback: { traceback .format_exc ()} " )
784
805
785
806
for rank in range (num_procs ):
786
807
result += torch .load (f"/tmp/txt2kg_outs_for_proc_{ rank } " )
787
808
os .remove (f"/tmp/txt2kg_outs_for_proc_{ rank } " )
788
809
break
789
- except :
810
+ except Exception as e :
790
811
total_num_tries += 1
791
- pass
812
+ print (f"[_llm_call_and_consume] Overall retry { retry + 1 } /5 failed: { type (e ).__name__ } : { str (e )} " )
813
+ import traceback
814
+ print (f"[_llm_call_and_consume] Full traceback: { traceback .format_exc ()} " )
792
815
print (f"[_llm_call_and_consume] Total number of tries: { total_num_tries } " )
793
816
return result
794
817
0 commit comments