2 files changed: +17 -7 lines changed
paddleformers/transformers — 2 files changed, +17 -7 lines changed
Columns: original file line number | diff line number | diff line change
@@ -3496,10 +3496,15 @@ def decode_token(
34963496 # from byte fallback tokenization.
34973497 # If it's in the middle, it's probably a real invalid id generated
34983498 # by the model
3499- prefix_index = new_text .index (prefix_text )
3500- new_text = new_text [prefix_index + len (prefix_text ) :]
3501- return new_text , read_offset , len (all_input_ids )
3499+ if new_text .startswith (prefix_text ):
3500+ prefix_index = new_text .index (prefix_text )
3501+ new_text = new_text [prefix_index + len (prefix_text ) :]
3502+ return new_text , read_offset , len (all_input_ids )
3503+ else :
3504+ return "" , prefix_offset , len (all_input_ids )
35023505 else :
3506+ if len (all_input_ids [prefix_offset :]) > 3 :
3507+ return new_text , len (all_input_ids ), len (all_input_ids )
35033508 return "" , prefix_offset , read_offset
35043509
35053510 def batch_decode (
Columns: original file line number | diff line number | diff line change
@@ -487,15 +487,20 @@ def decode_token(
487487 all_input_ids [prefix_offset :], skip_special_tokens = skip_special_tokens , clean_up_tokenization_spaces = False
488488 )
489489
490- if len (new_text ) > len (prefix_text ) and not prefix_text . endswith ( "�" ) and not new_text . endswith ( "�" ) :
490+ if len (new_text ) > len (prefix_text ) and "�" not in prefix_text and "�" not in new_text :
491491 # utf-8 char at the end means it's a potential unfinished byte sequence
492492 # from byte fallback tokenization.
493493 # If it's in the middle, it's probably a real invalid id generated
494494 # by the model
495- prefix_index = new_text .index (prefix_text )
496- new_text = new_text [prefix_index + len (prefix_text ) :]
497- return new_text , read_offset , len (all_input_ids )
495+ if new_text .startswith (prefix_text ):
496+ prefix_index = new_text .index (prefix_text )
497+ new_text = new_text [prefix_index + len (prefix_text ) :]
498+ return new_text , read_offset , len (all_input_ids )
499+ else :
500+ return "" , prefix_offset , len (all_input_ids )
498501 else :
502+ if len (all_input_ids [prefix_offset :]) > 3 :
503+ return new_text , len (all_input_ids ), len (all_input_ids )
499504 return "" , prefix_offset , read_offset
500505
501506
You can’t perform that action at this time.
0 commit comments