
Simplify (and fix) passing of guided decoding backend options #17008


Merged
Changes from all commits

26 commits
360026f
Split `guided_decoding_backend` into `guided_decoding_backend` and `g…
hmellor Apr 22, 2025
0390b6f
Typo
hmellor Apr 22, 2025
dc9f57b
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 24, 2025
49b89e5
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 25, 2025
8a1ed06
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 28, 2025
7a660b0
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 28, 2025
592ff3d
Fix typo
hmellor Apr 28, 2025
2979c31
Use bool flags instead of cramming many flags into a dict
hmellor Apr 28, 2025
8a367ba
Fix word order in args
hmellor Apr 28, 2025
b9d4074
Fix missing arg
hmellor Apr 28, 2025
0884458
Enforce disable additional properties only works for guidance
hmellor Apr 28, 2025
4ebe137
Missed the EngineArgs field
hmellor Apr 28, 2025
05ab20b
Add backward compatible deprecated `guided_decoding_backend`
hmellor Apr 28, 2025
a34a434
Fix incorrect attribute in xgrammar
hmellor Apr 28, 2025
ae2f1c3
Fix test parameters
hmellor Apr 29, 2025
666ba39
Merge `Literal`s with `Literal` not `Union`
hmellor Apr 29, 2025
6df5dbb
Enforce that `Literal`s are merged with `Literal` not `Union`
hmellor Apr 29, 2025
08d5e20
Add tests for `config` decorator
hmellor Apr 29, 2025
92300b6
Create new helper function to handle sequences of literals
hmellor Apr 29, 2025
88c1479
Add test for literal to kwarg
hmellor Apr 29, 2025
4037013
Add test cases for `list[Literal]` and `Literal[Literal, Literal]`
hmellor Apr 29, 2025
0ecf76e
Fix pre-commit
hmellor Apr 29, 2025
8beb8df
Respond to comment
hmellor Apr 29, 2025
e44135c
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 29, 2025
008b037
Merge branch 'config-literal-handling' into split-guided-decoding-bac…
hmellor Apr 29, 2025
39557d7
Merge branch 'main' into split-guided-decoding-backend-options
hmellor Apr 29, 2025
@@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
"[email protected]\n")

try:
# The no-fallback option forces vLLM to use xgrammar, so when it fails
# you get a 400 with the reason why
# The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create(
model=model,
messages=[{
@@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str):
extra_body={
"guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"],
"guided_decoding_backend": "xgrammar:no-fallback"
"guided_decoding_backend": "xgrammar",
"guided_decoding_disable_fallback": True,
},
)
return completion.choices[0].message.content
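For reference, a minimal sketch of the new request shape after this change; the server URL and model name are placeholders, and only the `extra_body` keys come from the diff above.

```python
from openai import OpenAI

# Placeholder endpoint and model; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
    messages=[{
        "role": "user",
        "content": "Give me an example email address.",
    }],
    extra_body={
        "guided_regex": r"\w+@\w+\.com\n",
        "stop": ["\n"],
        # Previously: "guided_decoding_backend": "xgrammar:no-fallback"
        # Now the backend and its options are separate fields:
        "guided_decoding_backend": "xgrammar",
        "guided_decoding_disable_fallback": True,
    },
)
print(completion.choices[0].message.content)
```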
206 changes: 125 additions & 81 deletions tests/entrypoints/llm/test_guided_generate.py

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions tests/model_executor/test_guided_processors.py
@@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):

def test_guided_decoding_backend_options():
"""Test backend-specific options"""
params = GuidedDecodingParams(
backend="xgrammar:option-1,option-2,option-3")
assert params.backend_options() == ["option-1", "option-2", "option-3"]

no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback")
assert no_fallback.no_fallback()
with pytest.warns(DeprecationWarning):
guided_decoding_params = GuidedDecodingParams(
backend=
"xgrammar:no-fallback,disable-any-whitespace,no-additional-properties"
)
assert guided_decoding_params.backend == "xgrammar"
assert guided_decoding_params.disable_fallback
assert guided_decoding_params.disable_any_whitespace
assert guided_decoding_params.disable_additional_properties


def test_pickle_xgrammar_tokenizer_data():
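The deprecated colon syntax is still accepted (with a `DeprecationWarning`) and is mapped onto the new flags. A minimal sketch of the two equivalent constructions, based on the test above:

```python
from vllm.sampling_params import GuidedDecodingParams

# Deprecated: options packed into the backend string (emits DeprecationWarning).
old_style = GuidedDecodingParams(
    backend="xgrammar:no-fallback,disable-any-whitespace")
assert old_style.backend == "xgrammar" and old_style.disable_fallback

# Preferred: backend name and boolean option flags passed separately.
new_style = GuidedDecodingParams(
    backend="xgrammar",
    disable_fallback=True,
    disable_any_whitespace=True)
assert new_style.backend == "xgrammar"
assert new_style.disable_fallback and new_style.disable_any_whitespace
```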
31 changes: 16 additions & 15 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -17,15 +17,12 @@
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
"auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace",
"auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
"mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
]

PARAMS_MODELS_TOKENIZER_MODE = [
@@ -73,6 +70,7 @@ def test_structured_output(
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode)

#
@@ -98,8 +96,7 @@

generated_text = output.outputs[0].text
assert generated_text is not None
if 'disable-any-whitespace' in guided_decoding_backend:
assert "\n" not in generated_text
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@@ -520,10 +517,11 @@ def test_structured_output_auto_mode(
def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "1")

backend = 'guidance:no-additional-properties,disable-any-whitespace'
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=1024,
guided_decoding_backend=backend)
guided_decoding_backend="guidance",
guided_decoding_disable_any_whitespace=True,
guided_decoding_disable_additional_properties=True)

schema = {
'type': 'object',
@@ -548,7 +546,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
"<|im_end|>\n<|im_start|>assistant\n")

def generate_with_backend(backend):
guided_params = GuidedDecodingParams(json=schema, backend=backend)
guided_params = GuidedDecodingParams(
json=schema,
backend=backend,
disable_any_whitespace=True,
disable_additional_properties=True)
sampling_params = SamplingParams(temperature=0,
max_tokens=256,
guided_decoding=guided_params)
@@ -562,8 +564,7 @@ def generate_with_backend(backend):
jsonschema.validate(instance=parsed_json, schema=schema)
return parsed_json

generated = generate_with_backend(
'guidance:no-additional-properties,disable-any-whitespace')
generated = generate_with_backend("guidance")
assert "a1" in generated
assert "a2" in generated
assert "a3" in generated
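A sketch of offline usage with the split engine arguments, mirroring the test above; the model, prompt, and schema are placeholders, and a GPU-backed environment is assumed.

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    max_model_len=1024,
    guided_decoding_backend="guidance",
    guided_decoding_disable_any_whitespace=True,
    guided_decoding_disable_additional_properties=True,
)

# Placeholder schema; the per-request params mirror the engine-level options.
schema = {"type": "object", "properties": {"a1": {"type": "string"}}}
guided = GuidedDecodingParams(
    json=schema,
    backend="guidance",
    disable_any_whitespace=True,
    disable_additional_properties=True,
)
sampling_params = SamplingParams(temperature=0,
                                 max_tokens=256,
                                 guided_decoding=guided)

outputs = llm.generate("Return a JSON object with key a1.", sampling_params)
print(outputs[0].outputs[0].text)
```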
3 changes: 2 additions & 1 deletion tests/v1/test_oracle.py
@@ -57,7 +57,8 @@ def test_unsupported_configs(monkeypatch):
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
guided_decoding_backend="lm-format-enforcer:no-fallback",
guided_decoding_backend="lm-format-enforcer",
guided_decoding_disable_fallback=True,
).create_engine_config()

with pytest.raises(NotImplementedError):
70 changes: 60 additions & 10 deletions vllm/config.py
@@ -17,12 +17,14 @@
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, get_args, get_origin)
Optional, Protocol, TypeVar, Union, cast, get_args,
get_origin)

import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@@ -32,7 +34,6 @@
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.sampling_params import GuidedDecodingParams
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@@ -344,7 +345,7 @@ def compute_hash(self) -> str:
def __init__(
self,
model: str,
task: Union[TaskOption, Literal["draft"]],
task: Literal[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@@ -701,7 +702,7 @@ def _resolve_task(

def _resolve_task(
self,
task_option: Union[TaskOption, Literal["draft"]],
task_option: Literal[TaskOption, Literal["draft"]],
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
if task_option == "draft":
return {"draft"}, "draft"
@@ -3185,13 +3186,36 @@ def get_served_model_name(model: str,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine."""

guided_decoding_backend: GuidedDecodingBackend = \
"auto" if envs.VLLM_USE_V1 else "xgrammar"
@property
@deprecated(
"`guided_decoding_backend` is deprecated and has been renamed to "
"`backend`. This will be removed in v0.10.0. Please use the "
"`backend` argument instead.")
def guided_decoding_backend(self) -> GuidedDecodingBackend:
return self.backend

@guided_decoding_backend.setter
def guided_decoding_backend(self, value: GuidedDecodingBackend):
self.backend = value

backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
is subject to change in each release."""

disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error."""

disable_any_whitespace: bool = False
"""If `True`, the model will not generate any whitespace during guided
decoding. This is only supported for xgrammar and guidance backends."""

disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""

reasoning_backend: Optional[str] = None
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format.
@@ -3217,15 +3241,41 @@ def compute_hash(self) -> str:
return hash_str

def __post_init__(self):
backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if ":" in self.backend:
self._extract_backend_options()

if envs.VLLM_USE_V1:
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
if self.backend not in valid_guided_backends:
raise ValueError(f"Invalid backend '{self.backend}',"
f" must be one of {valid_guided_backends}")
if (self.disable_any_whitespace
and self.backend not in ("xgrammar", "guidance")):
raise ValueError("disable_any_whitespace is only supported for "
"xgrammar and guidance backends.")
if (self.disable_additional_properties and self.backend != "guidance"):
raise ValueError("disable_additional_properties is only supported "
"for the guidance backend.")

@deprecated(
"Passing guided decoding backend options inside backend in the format "
"'backend:...' is deprecated. This will be removed in v0.10.0. Please "
"use the dedicated arguments '--disable-fallback', "
"'--disable-any-whitespace' and '--disable-additional-properties' "
"instead.")
def _extract_backend_options(self):
"""Extract backend options from the backend string."""
backend, options = self.backend.split(":")
self.backend = cast(GuidedDecodingBackend, backend)
options_set = set(options.strip().split(","))
if "no-fallback" in options_set:
self.disable_fallback = True
if "disable-any-whitespace" in options_set:
self.disable_any_whitespace = True
if "no-additional-properties" in options_set:
self.disable_additional_properties = True


@dataclass
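For clarity, a standalone sketch of the mapping performed by `_extract_backend_options` for legacy `backend:option,...` strings; the helper name below is chosen for illustration only.

```python
def split_legacy_backend(value: str) -> tuple[str, dict[str, bool]]:
    """Map a legacy string such as "xgrammar:no-fallback" onto the new flags."""
    backend, _, options = value.partition(":")
    opts = set(options.strip().split(",")) if options else set()
    return backend, {
        "disable_fallback": "no-fallback" in opts,
        "disable_any_whitespace": "disable-any-whitespace" in opts,
        "disable_additional_properties": "no-additional-properties" in opts,
    }

backend, flags = split_legacy_backend("guidance:no-additional-properties")
assert backend == "guidance" and flags["disable_additional_properties"]
```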
36 changes: 26 additions & 10 deletions vllm/engine/arg_utils.py
@@ -18,9 +18,9 @@
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackendV1, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, MultiModalConfig,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
@@ -317,7 +317,12 @@ class EngineArgs:
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
guided_decoding_disable_any_whitespace: bool = \
DecodingConfig.disable_any_whitespace
guided_decoding_disable_additional_properties: bool = \
DecodingConfig.disable_additional_properties
logits_processor_pattern: Optional[str] = None

speculative_config: Optional[Dict[str, Any]] = None
@@ -498,9 +503,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
title="DecodingConfig",
description=DecodingConfig.__doc__,
)
guided_decoding_group.add_argument("--guided-decoding-backend",
**guided_decoding_kwargs["backend"])
guided_decoding_group.add_argument(
'--guided-decoding-backend',
**guided_decoding_kwargs["guided_decoding_backend"])
"--guided-decoding-disable-fallback",
**guided_decoding_kwargs["disable_fallback"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-any-whitespace",
**guided_decoding_kwargs["disable_any_whitespace"])
guided_decoding_group.add_argument(
"--guided-decoding-disable-additional-properties",
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
"--reasoning-parser",
# This choices is a special case because it's not static
@@ -1244,7 +1257,11 @@ def create_engine_config(
if self.enable_prompt_adapter else None

decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend,
backend=self.guided_decoding_backend,
disable_fallback=self.guided_decoding_disable_fallback,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
disable_additional_properties=\
self.guided_decoding_disable_additional_properties,
reasoning_backend=self.reasoning_parser
if self.enable_reasoning else None,
)
@@ -1335,9 +1352,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=True)
return False

# remove backend options when doing this check
if self.guided_decoding_backend.split(':')[0] \
not in get_args(GuidedDecodingBackendV1):
if self.guided_decoding_backend not in get_args(
GuidedDecodingBackendV1):
_raise_or_fallback(
feature_name=
f"--guided-decoding-backend={self.guided_decoding_backend}",
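Engine-level configuration follows the same split; a sketch of the `EngineArgs` path, where the CLI flags added above map one-to-one onto these fields. The `decoding_config` attribute access on the resulting config object is an assumption, and building the config requires the placeholder model to be resolvable locally.

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model
    guided_decoding_backend="xgrammar",
    guided_decoding_disable_fallback=True,
    guided_decoding_disable_any_whitespace=True,
)

# Roughly equivalent to: --guided-decoding-backend xgrammar
#                        --guided-decoding-disable-fallback
#                        --guided-decoding-disable-any-whitespace
vllm_config = engine_args.create_engine_config()
assert vllm_config.decoding_config.backend == "xgrammar"
assert vllm_config.decoding_config.disable_fallback
```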
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -2091,7 +2091,7 @@ def _build_logits_processors(

tokenizer = self.get_tokenizer(lora_request=lora_request)
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
self.decoding_config.backend

if self.decoding_config.reasoning_backend is not None:
logger.debug("Building with reasoning backend %s",
4 changes: 2 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -615,9 +615,9 @@ async def _process_request(
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=(self.decoding_config.guided_decoding_backend
default_guided_backend=(self.decoding_config.backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
else DecodingConfig.backend),
model_config=self.model_config,
reasoning_backend=self.decoding_config.reasoning_backend,
)