
[Feature] Prefill assistant response - add continue_final_message parameter #4226

Merged: 9 commits, Apr 21, 2025
1 change: 1 addition & 0 deletions docs/backend/sampling_params.md
@@ -51,6 +51,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
* `n: int = 1`: Specifies the number of output sequences to generate per request. (Generating multiple outputs in one request (n > 1) is discouraged; repeating the same prompt several times offers better control and efficiency.)
* `spaces_between_special_tokens: bool = True`: Whether or not to add spaces between special tokens during detokenization.
* `no_stop_trim: bool = False`: Don't trim stop words or EOS token from the generated text.
* `continue_final_message: bool = False` : When enabled, the final assistant message is removed and its content is used as a prefill so that the model continues that message instead of starting a new turn. See [openai_chat_with_response_prefill.py](https://github.com/sgl-project/sglang/blob/main/examples/runtime/openai_chat_with_response_prefill.py) for examples.
* `ignore_eos: bool = False`: Don't stop generation when EOS token is sampled.
* `skip_special_tokens: bool = True`: Remove special tokens during decoding.
* `custom_params: Optional[List[Optional[Dict[str, Any]]]] = None`: Used when employing `CustomLogitProcessor`. For usage see below.
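For the `continue_final_message` flag above, here is a minimal request sketch (not part of this change set): it assumes a local SGLang server at `127.0.0.1:30000` and sends the flag as a top-level key on the OpenAI-compatible chat completions endpoint; the model name is illustrative.

```python
# Minimal sketch: send a partial assistant message and ask the server to continue it.
# Assumes a local SGLang server at 127.0.0.1:30000; model name is illustrative.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "user", "content": "List three primary colors as a JSON array."},
            {"role": "assistant", "content": "[\n"},  # partial message to be continued
        ],
        "continue_final_message": True,
        "temperature": 0,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```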
4 changes: 3 additions & 1 deletion examples/runtime/README.md
@@ -8,7 +8,9 @@ The below examples will mostly need you to start a server in a separate terminal
* `multimodal_embedding.py`: An example of how to perform [multimodal embedding](Alibaba-NLP/gme-Qwen2-VL-2B-Instruct).
* `openai_batch_chat.py`: An example of how to process batch requests for chat completions.
* `openai_batch_complete.py`: An example of how to process batch requests for text completions.
* `openai_chat_with_response_prefill.py`: An example how to [prefill](https://eugeneyan.com/writing/prompting/#prefill-claudes-responses) a response using OpenAI API.
* **`openai_chat_with_response_prefill.py`**:
An example that demonstrates how to [prefill a response](https://eugeneyan.com/writing/prompting/#prefill-claudes-responses) using the OpenAI API by enabling the `continue_final_message` parameter.
When enabled, the final (partial) assistant message is removed and its content is used as a prefill so that the model continues that message rather than starting a new turn. See [Anthropic's prefill example](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/prefill-claudes-response#example-structured-data-extraction-with-prefilling) for more context.
* `reward_model.py`: An example of how to extract scores from a reward model.
* `vertex_predict.py`: An example of how to deploy a model to [Vertex AI](https://cloud.google.com/vertex-ai?hl=en).

54 changes: 37 additions & 17 deletions examples/runtime/openai_chat_with_response_prefill.py
@@ -1,33 +1,53 @@
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
python openai_chat.py
1) Launch the server in one terminal:
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000

2) Run this script in another terminal:
python openai_chat_with_response_prefill.py

This example demonstrates two chat completion calls:
- One with continue_final_message enabled (the final assistant message is used as a prefill).
- One without continue_final_message (the final assistant message remains, starting a new turn).
"""

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{
"role": "user",
"content": """
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{
"role": "user",
"content": """
Extract the name, size, price, and color from this product description as a JSON object:

<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99.
At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—
no matter where you place it in your home.
This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
},
{
"role": "assistant",
"content": "{\n",
},
],
},
{"role": "assistant", "content": "{\n"},
]

# Calling the API with continue_final_message enabled.
print("=== Prefill with continue_final_messagem ===")
response_with = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
temperature=0,
extra_body={"continue_final_message": True},
)
print(response_with.choices[0].message.content)

print(response.choices[0].message.content)
# Calling the API without continue_final_message (using default behavior).
print("\n=== Prefill without continue_final_message ===")
response_without = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
messages=messages,
temperature=0,
)
print(response_without.choices[0].message.content)
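As a rough, optional sanity check (not part of this PR), the same prompt-level difference can be previewed client-side with Hugging Face `tokenizer.apply_chat_template`, which exposes its own `continue_final_message` switch in recent `transformers` releases. The snippet below assumes access to the gated Llama 3.1 tokenizer and uses a trimmed-down stand-in for the messages above.

```python
# Illustrative check of what the rendered prompt looks like with vs. without
# continuing the final assistant message. Requires a recent `transformers` release
# and access to the (gated) Llama 3.1 tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
msgs = [
    {"role": "user", "content": "Reply with a JSON object."},
    {"role": "assistant", "content": "{\n"},  # partial assistant message
]

# Prompt ends with the partial "{\n", so generation would continue that message.
continued = tok.apply_chat_template(msgs, tokenize=False, continue_final_message=True)
# Prompt closes the assistant turn and opens a fresh one, so generation starts over.
fresh = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

print(repr(continued[-60:]))
print(repr(fresh[-60:]))
```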
42 changes: 38 additions & 4 deletions python/sglang/srt/openai_api/adapter.py
@@ -950,9 +950,16 @@ def v1_chat_generate_request(
openai_compatible_messages.append(
{"role": message.role, "content": content["text"]}
)
if openai_compatible_messages[-1]["role"] == "assistant":
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
if (
openai_compatible_messages
and openai_compatible_messages[-1]["role"] == "assistant"
):
if request.continue_final_message:
# Remove the final assistant message so its content can be continued.
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
else:
assistant_prefix = None
else:
assistant_prefix = None

@@ -991,7 +998,33 @@ def v1_chat_generate_request(
modalities = []
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
# If we should continue the final assistant message, adjust the conversation.
if (
request.continue_final_message
and request.messages
and request.messages[-1].role == "assistant"
):
# Remove the auto-added blank assistant turn, if present.
if conv.messages and conv.messages[-1][1] is None:
conv.messages.pop()
# Rebuild the prompt from the conversation.
prompt = conv.get_prompt()
# Strip any trailing stop tokens or separators that indicate end-of-assistant.
if isinstance(conv.stop_str, list):
for stop_token in conv.stop_str:
if prompt.endswith(stop_token):
prompt = prompt[: -len(stop_token)]
elif isinstance(conv.stop_str, str) and prompt.endswith(
conv.stop_str
):
prompt = prompt[: -len(conv.stop_str)]
if conv.sep and prompt.endswith(conv.sep):
prompt = prompt[: -len(conv.sep)]
if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
prompt = prompt[: -len(conv.sep2)]
else:
prompt = conv.get_prompt()

image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
@@ -1002,6 +1035,7 @@
else:
stop.extend(request.stop)
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)

else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
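The trailing-marker trimming in `v1_chat_generate_request` above can be hard to picture without a concrete template, so here is a standalone sketch of the same idea on a toy prompt. The stop string and rendered text are illustrative, not taken from any specific conversation object in the codebase.

```python
# Standalone sketch of the trimming step: after rendering the chat template with a
# final assistant message, strip whatever end-of-turn markers the template appended
# so the model continues that message instead of seeing a closed turn.
def strip_trailing_markers(prompt, stop_str=None, sep="", sep2=""):
    stops = stop_str if isinstance(stop_str, list) else ([stop_str] if stop_str else [])
    for token in stops:
        if token and prompt.endswith(token):
            prompt = prompt[: -len(token)]
    for s in (sep, sep2):
        if s and prompt.endswith(s):
            prompt = prompt[: -len(s)]
    return prompt

# Example: a Llama-3-style rendering whose final assistant turn was closed with <|eot_id|>.
rendered = "<|start_header_id|>assistant<|end_header_id|>\n\n{\n<|eot_id|>"
print(repr(strip_trailing_markers(rendered, stop_str="<|eot_id|>")))  # ends with '{\n'
```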
1 change: 1 addition & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -355,6 +355,7 @@ def set_tool_choice_default(cls, values):
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None
1 change: 1 addition & 0 deletions test/srt/test_openai_server.py
@@ -534,6 +534,7 @@ def test_response_prefill(self):
},
],
temperature=0,
extra_body={"continue_final_message": True},
)

assert (