Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skyrl-train/skyrl_train/inference_engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ class InferenceEngineInput(TypedDict):
class InferenceEngineOutput(TypedDict):
# We always return both tokens and text outputs. The tokens are the outputs
# of inference engine, and the text is the decoded text output. Therefore,
- # it is guaranteed that tokenizer.decode(response_token_ids) == responses,
+ # it is guaranteed that tokenizer.decode(response_token_ids, skip_special_tokens=True) == responses,
# but the reverse is not guaranteed, since there are multiple ways to
# represent the same text with tokens. Therefore, for multi-turn generation,
# please use token-in-token-out to ensure correctness.
+ # `skip_special_tokens=True` is needed because string responses do not include EOS tokens like `<|im_end|>`
responses: List[str]
response_ids: List[List[int]]
stop_reasons: List[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
output_ids.append(cur_output_ids)
# SGLang only returns tokens not text when skip_tokenizer_init is True, so
# we manually decode it.
- outputs.append(self.tokenizer.decode(cur_output_ids))
+ outputs.append(self.tokenizer.decode(cur_output_ids, skip_special_tokens=True))
finish_reasons.append(output["meta_info"]["finish_reason"]["type"])
else:
raise ValueError(f"Invalid engine backend: {self.engine_backend}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _postprocess_outputs(self, outputs):

for output in outputs:
response_ids.append(output["output_ids"])
- responses.append(self.tokenizer.decode(output["output_ids"]))
+ responses.append(self.tokenizer.decode(output["output_ids"], skip_special_tokens=True))
stop_reasons.append(output["meta_info"]["finish_reason"]["type"])

return InferenceEngineOutput(
Expand Down