@@ -145,15 +145,31 @@ def _get_num_new_tokens(self, request, token_budget):
145145 if inputs .get ("patch_idx" , None ) is not None and inputs .get ("patch_map" , None ) is not None :
146146 pre_end_idx = request .num_computed_tokens
147147 new_end_idx = pre_end_idx + num_new_tokens
148+
149+ prompt_token_ids_len = len (request .prompt_token_ids )
150+ assert prompt_token_ids_len == len (inputs ["patch_idx" ]), (prompt_token_ids_len , len (inputs ["patch_idx" ]))
151+
148152 # start
149- start_patch_idx = inputs ["patch_idx" ][pre_end_idx ]
153+ if pre_end_idx >= prompt_token_ids_len :
154+ start_patch_idx = inputs ["patch_idx" ][- 1 ]
155+ else :
156+ start_patch_idx = inputs ["patch_idx" ][pre_end_idx ]
150157 start_patch_map = inputs ["patch_map" ][start_patch_idx ]
151158 request .image_start = start_patch_map ["image_num" ]
152159 request .video_start = start_patch_map ["video_num" ]
153160 request .audio_start = start_patch_map ["audio_num" ]
154161
155162 # end
156- end_patch_idx = inputs ["patch_idx" ][new_end_idx ]
163+ if new_end_idx >= prompt_token_ids_len :
164+ end_patch_idx = inputs ["patch_idx" ][- 1 ]
165+ else :
166+ end_patch_idx = inputs ["patch_idx" ][new_end_idx ]
167+ if request .prompt_token_ids [new_end_idx ] in [
168+ inputs ["image_end_id" ],
169+ inputs ["video_end_id" ],
170+ inputs ["audio_end_id" ],
171+ ]:
172+ end_patch_idx -= 1
157173 end_patch_map = inputs ["patch_map" ][end_patch_idx ]
158174 end_modal_id = end_patch_map ["modal_id" ]
159175 if end_modal_id > 0 :
0 commit comments