Commit 122da1c

russellb and JC1DA committed
[V1] guidance backend for structured output
This is the V1 integration for [guidance](https://github.com/guidance-ai/llguidance) as a backend for structured output. There is a V0 integration in vllm-project#14589.

This backend provides some key benefits to V1:

* Broader jsonschema support
* Quick startup performance for large schemas

Instead of precomputing the masks for all states, this is done on the fly. We see very fast request startup times, even for large schemas. This should make V1 roughly feature-equivalent to V0 in terms of the types of schemas it can support. More technical details are available in the llguidance git repo.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
1 parent 3a1e648 commit 122da1c
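For context, a minimal usage sketch follows (not part of this commit). It assumes the V1 engine is enabled (`VLLM_USE_V1=1`) and uses vLLM's existing offline `LLM` API and `GuidedDecodingParams`; the model name and schema are illustrative only.

```python
# Sketch: request JSON-constrained output through the new "guidance" backend.
# Assumes VLLM_USE_V1=1; the model name and schema are illustrative.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          guided_decoding_backend="guidance")
params = SamplingParams(
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate(["Produce a JSON record for a person."], params)
print(outputs[0].outputs[0].text)
```

The serving benchmark below gains a matching `guidance` choice for its `--structured-output-backend` flag.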

File tree

7 files changed (+163 / -10 lines)


benchmarks/benchmark_serving_structured_output.py

Lines changed: 6 additions & 5 deletions
@@ -989,11 +989,12 @@ def main(args: argparse.Namespace):
                         type=float,
                         default=1.0,
                         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")
 
     args = parser.parse_args()
     main(args)

requirements/common.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ lm-format-enforcer >= 0.10.11, < 0.11
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
+llguidance==0.6.31
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -2785,10 +2785,17 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        v0_valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar'
+        ]
+        v1_valid_guided_backends = ['xgrammar', 'guidance']
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
+        if envs.VLLM_USE_V1:
+            valid_guided_backends = v1_valid_guided_backends
+        else:
+            valid_guided_backends = v0_valid_guided_backends
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
                              f" must be one of {valid_guided_backends}")

vllm/v1/engine/processor.py

Lines changed: 5 additions & 3 deletions
@@ -20,7 +20,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.structured_output.utils import validate_structured_output_request
+from vllm.v1.structured_output.utils import (
+    validate_structured_output_request_xgrammar)
 
 
 class Processor:
@@ -120,7 +121,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = ["xgrammar"]
+        supported_backends = ["xgrammar", "guidance"]
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "
@@ -137,7 +138,8 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if vllm.platforms.current_platform.is_tpu():
             raise ValueError("Structured output is not supported on TPU.")
 
-        validate_structured_output_request(params)
+        if engine_level_backend == "xgrammar":
+            validate_structured_output_request_xgrammar(params)
 
     def process_inputs(
         self,

vllm/v1/structured_output/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.structured_output.backend_guidance import GuidanceBackend
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar)
 from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
@@ -48,6 +49,8 @@ def grammar_init(self, request: Request) -> None:
         backend_name = request.sampling_params.guided_decoding.backend_name
         if backend_name == "xgrammar":
             self.backend = XgrammarBackend(self.vllm_config)
+        elif backend_name == "guidance":
+            self.backend = GuidanceBackend(self.vllm_config)
         else:
             raise ValueError(
                 f"Unsupported structured output backend: {backend_name}")
vllm/v1/structured_output/backend_guidance.py (new file)

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import os
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import LazyLoader
+from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
+                                                     StructuredOutputGrammar,
+                                                     StructuredOutputOptions)
+
+if TYPE_CHECKING:
+    import llguidance
+    import llguidance.hf as llguidance_hf
+    import llguidance.torch as llguidance_torch
+else:
+    llguidance = LazyLoader("llguidance", globals(), "llguidance")
+    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
+    llguidance_torch = LazyLoader("llguidance.torch", globals(),
+                                  "llguidance.torch")
+
+logger = init_logger(__name__)
+
+
+class GuidanceBackend(StructuredOutputBackend):
+
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=vllm_config.model_config,
+            scheduler_config=vllm_config.scheduler_config,
+            parallel_config=vllm_config.parallel_config,
+            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()
+        self.vllm_config = vllm_config
+        self.vocab_size = vllm_config.model_config.get_vocab_size()
+
+        tokenizer = tokenizer_group.get_lora_tokenizer(None)
+        self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
+
+    def compile_grammar(self, request_type: StructuredOutputOptions,
+                        grammar_spec: str) -> StructuredOutputGrammar:
+
+        if request_type == StructuredOutputOptions.JSON:
+            if isinstance(grammar_spec, dict):
+                schema = json.dumps(grammar_spec)
+            else:
+                schema = str(grammar_spec)
+
+            # TODO: make whitespace_flexible configurable
+            compiler = llguidance.JsonCompiler(whitespace_flexible=False)
+            self.serialized_grammar = compiler.compile(schema)
+        elif request_type == StructuredOutputOptions.JSON_OBJECT:
+            compiler = llguidance.JsonCompiler(whitespace_flexible=False)
+            self.serialized_grammar = compiler.compile('{"type": "object"}')
+        elif (request_type == StructuredOutputOptions.REGEX
+              or request_type == StructuredOutputOptions.CHOICE):
+            compiler = llguidance.RegexCompiler()
+            self.serialized_grammar = compiler.compile(regex=grammar_spec)
+        elif request_type == StructuredOutputOptions.GRAMMAR:
+            if isinstance(grammar_spec, dict):
+                self.serialized_grammar = json.dumps(grammar_spec)
+            else:
+                self.serialized_grammar = str(grammar_spec)
+        else:
+            logger.error(
+                "Validation should have already occurred. Please file an issue."
+            )
+            raise ValueError(
+                f"grammar is not of valid supported types. ({request_type!s})")
+
+        ll_interpreter = llguidance.LLInterpreter(
+            self.ll_tokenizer,
+            self.serialized_grammar,
+            enable_backtrack=False,
+            enable_ff_tokens=False,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        return GuidanceGrammar(
+            ll_interpreter=ll_interpreter,
+            ll_tokenizer=self.ll_tokenizer,
+            vocab_size=self.vocab_size,
+        )
+
+    def allocate_token_bitmask(self, max_num_seqs: int):
+        return llguidance_torch.allocate_token_bitmask(
+            max_num_seqs, self.ll_tokenizer.vocab_size)
+
+
+@dataclass
+class GuidanceGrammar(StructuredOutputGrammar):
+
+    ll_interpreter: llguidance.LLInterpreter
+    ll_tokenizer: llguidance_hf.LLTokenizer
+    vocab_size: int
+    stopped: bool = False
+
+    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
+        """Accepts a list of tokens and advances the FSM.
+
+        Returns True if the FSM was advanced successfully.
+        Returns False if the FSM failed to advance.
+        """
+
+        if self.stopped:
+            return True
+
+        for token in tokens:
+            # TODO - Add jump decoding support in the future.
+            # For now we turn this off when creating the LLInterpreter.
+            # backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
+            self.ll_interpreter.commit_token(token)
+
+        return True
+
+    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
+        if self.ll_interpreter.has_pending_stop():
+            # Fill the bitmask with only the EOS token before
+            # is_terminated() returns True.
+            eos_token = self.ll_tokenizer.eos_token
+            bitmask[idx, :] = 0
+            bitmask[idx, eos_token // 32] = 1 << (eos_token % 32)
+            self.stopped = True
+        else:
+            llguidance_torch.fill_next_token_bitmask(self.ll_interpreter,
+                                                     bitmask, idx)
+
+    def is_terminated(self) -> bool:
+        return self.stopped
+
+    def reset(self):
+        # This method may not be needed anymore? TODO: confirm and remove.
+        pass
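One detail worth noting in `fill_bitmask` above: the token bitmask is packed, one bit per vocabulary entry in 32-bit words, so forcing EOS means zeroing the row and setting bit `eos_token % 32` of word `eos_token // 32`. The standalone sketch below illustrates that packing convention with plain torch; the `allowed_tokens` helper is for illustration only and is not part of llguidance or vLLM.

```python
# Illustration of the packed token bitmask convention used in fill_bitmask:
# each int32 word carries the allow bits for 32 consecutive token ids.
import torch

vocab_size = 100
num_words = (vocab_size + 31) // 32  # ceil(vocab_size / 32) words per row
bitmask = torch.zeros((1, num_words), dtype=torch.int32)

# Allow exactly one token, the way fill_bitmask forces EOS on a pending stop.
eos_token = 70
bitmask[0, :] = 0
bitmask[0, eos_token // 32] = 1 << (eos_token % 32)

def allowed_tokens(row: torch.Tensor, vocab_size: int) -> list[int]:
    """Unpack one bitmask row back into the list of allowed token ids."""
    return [t for t in range(vocab_size) if (int(row[t // 32]) >> (t % 32)) & 1]

assert allowed_tokens(bitmask[0], vocab_size) == [eos_token]
```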

vllm/v1/structured_output/utils.py

Lines changed: 1 addition & 1 deletion
@@ -239,7 +239,7 @@ def escape_ebnf_string(s: str) -> str:
     return grammar
 
 
-def validate_structured_output_request(
+def validate_structured_output_request_xgrammar(
         sampling_params: SamplingParams) -> None:
     """Validate that the request is supported by structured output.
