
Commit 0676604

add v1 test
Signed-off-by: wangli <[email protected]>
1 parent 21bc169 commit 0676604

1 file changed: +89 −83 lines changed

tests/singlecard/test_guided_decoding.py

Lines changed: 89 additions & 83 deletions
@@ -16,26 +16,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import gc
 import json
 import os
 import re
-import weakref
 
 import jsonschema
 import pytest
-import torch
-from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
+from tests.conftest import VllmRunner
+
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = [
+GuidedDecodingBackendV0 = [
     "outlines",
     "lm-format-enforcer",
-    "xgrammar:disable-any-whitespace",
+    "xgrammar",
 ]
+GuidedDecodingBackendV1 = ["xgrammar", "guidance:disable-any-whitespace"]
+GuidedDecodingBackend = list(
+    set(GuidedDecodingBackendV0 + GuidedDecodingBackendV1))
 
 
 @pytest.fixture(scope="module")
@@ -86,84 +87,89 @@ def sample_json_schema():
     }
 
 
-def clean_up():
-    gc.collect()
-    torch.npu.empty_cache()
-
-
-@pytest.fixture(scope="module")
-def llm():
-    # pytest caches the fixture so we use weakref.proxy to
-    # enable garbage collection
-    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)
-    with llm.deprecate_legacy_api():
-        yield weakref.proxy(llm)
-    del llm
-    clean_up()
-
-
-# TODO: Add v1 fully tested
-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
-                    reason="v1 does not support guided decoding")
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_regex(sample_regex, llm, guided_decoding_backend: str):
-    sampling_params = SamplingParams(temperature=0.8,
-                                     top_p=0.95,
-                                     guided_decoding=GuidedDecodingParams(
-                                         regex=sample_regex,
-                                         backend=guided_decoding_backend))
-    print(f"Using backend: {guided_decoding_backend}")
-    outputs = llm.generate(prompts=[
-        f"Give an example IPv4 address with this regex: {sample_regex}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        print(generated_text)
-        assert generated_text is not None
-        assert re.fullmatch(sample_regex, generated_text) is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-
-
-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "1",
-                    reason="v1 does not support guided decoding")
-@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
-def test_guided_json_completion(sample_json_schema, llm,
-                                guided_decoding_backend: str):
-    if guided_decoding_backend == "xgrammar:disable-any-whitespace":
+@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
+def test_guided_json_completion(guided_decoding_backend: str,
+                                sample_json_schema):
+    if guided_decoding_backend == "xgrammar":
         # xgrammar does not support json schema, will fall back to outlines, skip it
         pytest.skip(
-            f"{guided_decoding_backend} does not support json schema validation"
-        )
+            f"{guided_decoding_backend} will fall back to outlines, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
+            "VLLM_USE_V1") == "0":
+        # guidance does not support on v0, skip it
+        pytest.skip(
+            f"{guided_decoding_backend} does not support on v0, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
+            "VLLM_USE_V1") == "1":
+        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
+    with VllmRunner(
+            MODEL_NAME,
+            seed=0,
+            dtype="auto",
+            guided_decoding_backend=guided_decoding_backend,
+    ) as vllm_model:
+        prompts = [
+            f"Give an example JSON for an employee profile "
+            f"that fits this schema: {sample_json_schema}"
+        ] * 2
+        inputs = vllm_model.get_inputs(prompts)
+        outputs = vllm_model.model.generate(inputs,
+                                            sampling_params=sampling_params)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json,
+                            schema=sample_json_schema)
+
+
+@pytest.mark.parametrize("guided_decoding_backend", GuidedDecodingBackend)
+def test_guided_regex(guided_decoding_backend: str, sample_regex):
+    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
+            "VLLM_USE_V1") == "0":
+        # guidance does not support on v0, skip it
+        pytest.skip(
+            f"{guided_decoding_backend} does not support on v0, skip it")
+    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
+            "VLLM_USE_V1") == "1":
+        pytest.skip(f"{guided_decoding_backend} does not support v1, skip it")
 
-    sampling_params = SamplingParams(temperature=1.0,
-                                     max_tokens=1000,
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
                                      guided_decoding=GuidedDecodingParams(
-                                         json=sample_json_schema,
-                                         backend=guided_decoding_backend))
-    print(f"Using backend: {guided_decoding_backend}")
-    outputs = llm.generate(prompts=[
-        f"Give an example JSON for an employee profile "
-        f"that fits this schema: {sample_json_schema}"
-    ] * 2,
-                           sampling_params=sampling_params,
-                           use_tqdm=True)
-
-    assert outputs is not None
-
-    for output in outputs:
-        assert output is not None
-        assert isinstance(output, RequestOutput)
-        prompt = output.prompt
-
-        generated_text = output.outputs[0].text
-        assert generated_text is not None
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=sample_json_schema)
+                                         regex=sample_regex, ))
+    with VllmRunner(
+            MODEL_NAME,
+            seed=0,
+            dtype="auto",
+            guided_decoding_backend=guided_decoding_backend,
+    ) as vllm_model:
+        prompts = [
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2
+        inputs = vllm_model.get_inputs(prompts)
+        outputs = vllm_model.model.generate(inputs,
+                                            sampling_params=sampling_params)
+        assert outputs is not None
+        for output in outputs:
+            assert output is not None
+            assert isinstance(output, RequestOutput)
+            prompt = output.prompt
            generated_text = output.outputs[0].text
            print(generated_text)
            assert generated_text is not None
            assert re.fullmatch(".*", generated_text) is not None
            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
