 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import gc
 import json
 import os
 import re
-import weakref
 
 import jsonschema
 import pytest
-import torch
-from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
+from tests.conftest import VllmRunner
+
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = [
+GuidedDecodingBackendV0 = [
     "outlines",
     "lm-format-enforcer",
-    "xgrammar:disable-any-whitespace",
+    "xgrammar",
 ]
+GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"]
+GuidedDecodingBackend = list(
+    set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))
 
 
 @pytest.fixture(scope="module")
@@ -86,84 +87,89 @@ def sample_json_schema():
     }
 
 
-def clean_up():
-    gc.collect()
-    torch.npu.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def llm():
-    # pytest caches the fixture so we use weakref.proxy to
-    # enable garbage collection
-    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
-    with llm.deprecate_legacy_api():
-        yield weakref.proxy(llm)
-    del llm
-    clean_up()
-
-
-# TODO: Add v1 fully tested
-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
-                    reason="v1 does not support guided decoding")
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
-    print(f"Using backend: {guided_decoding_backend}")
-    outputs = llm.generate(prompts=[
-        f"Give an example IPv4 address with this regex: {sample_regex}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(generated_text)
-        assert generated_text is not None
-        assert re.fullmatch(sample_regex, generated_text) is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
-                    reason="v1 does not support guided decoding")
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_json_completion(sample_json_schema, llm,
-                                guided_decoding_backend: str):
-    if guided_decoding_backend == "xgrammar:disable-any-whitespace":
+@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
+def test_guided_json_completion(guided_decoding_backend: str,
+                                sample_json_schema):
+    if guided_decoding_backend == "xgrammar":
         # xgrammar does not support json schema, will fall back to outlines, skip it
         pytest.skip(
-            f"{guided_decoding_backend} does not support json schema validation"
-        )
+            f"{guided_decoding_backend} will fall back to outlines, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
+            "VLLM_USE_V1") == "0":
+        # guidance does not support on v0, skip it
+        pytest.skip(
+            f"{guided_decoding_backend} does not support on v0, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
+            "VLLM_USE_V1") == "1":
+        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+    with VllmRunner(
+            MODEL_NAME,
+            seed=0,
+            dtype="auto",
+            guided_decoding_backend=guided_decoding_backend,
+    ) as vllm_model:
+        prompts = [
+            f"Give an example JSON for an employee profile "
+            f"that fits this schema: {sample_json_schema}"
+        ] * 2
+        inputs = vllm_model.get_inputs(prompts)
+        outputs = vllm_model.model.generate(inputs,
+                                            sampling_params=sampling_params)
+
+        assert outputs is not None
+
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            prompt = output.prompt
+
+            generated_text = output.outputs[0].text
+            assert generated_text is not None
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            output_json = json.loads(generated_text)
+            jsonschema.validate(instance=output_json,
+                                schema=sample_json_schema)
+
+
+@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
+def test_guided_regex(guided_decoding_backend: str, sample_regex):
+    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
+            "VLLM_USE_V1") == "0":
+        # guidance does not support on v0, skip it
+        pytest.skip(
+            f"{guided_decoding_backend} does not support on v0, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
+            "VLLM_USE_V1") == "1":
+        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
 
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
-    print(f"Using backend: {guided_decoding_backend}")
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
+                                         regex=sample_regex, ))
+    with VllmRunner(
+            MODEL_NAME,
+            seed=0,
+            dtype="auto",
+            guided_decoding_backend=guided_decoding_backend,
+    ) as vllm_model:
+        prompts = [
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2
+        inputs = vllm_model.get_inputs(prompts)
+        outputs = vllm_model.model.generate(inputs,
+                                            sampling_params=sampling_params)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(generated_text)
+            assert generated_text is not None
+            assert re.fullmatch(".*", generated_text) is not None
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
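
For reviewers, a minimal standalone sketch (not part of the patch; it only reuses the backend lists defined in the diff above) of which parametrized backends each engine version actually runs once the new skip conditions apply:

GuidedDecodingBackendV0 = ["outlines", "lm-format-enforcer", "xgrammar"]
GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"]
# "xgrammar" appears in both lists, so the de-duplicated parametrization has three entries.
GuidedDecodingBackend = list(set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))

for use_v1 in ("0", "1"):
    allowed = GuidedDecodingBackendV1 if use_v1 == "1" else GuidedDecodingBackendV0
    # Mirrors the tests' skip logic: backends outside the active engine's list are skipped.
    executed = sorted(b for b in GuidedDecodingBackend if b in allowed)
    print(f"VLLM_USE_V1={use_v1}: {executed}")

Under V0 this exercises outlines, lm-format-enforcer, and xgrammar (with xgrammar's JSON case skipped separately in test_guided_json_completion); under V1 it exercises xgrammar and guidance:disable-any-whitespace.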