Skip to content

Commit d401407

Browse files
yyihuangyhyang201hnyls2002zhaochenyang20zhyncs
authored andcommitted
[Model Support] unsloth/Phi-4-mini bnb model (sgl-project#4982)
Co-authored-by: yhyang201 <[email protected]> Co-authored-by: Liangsheng Yin <[email protected]> Co-authored-by: Chayenne <[email protected]> Co-authored-by: Yineng Zhang <[email protected]>
1 parent adf9fd8 commit d401407

File tree

3 files changed

+235
-8
lines changed

3 files changed

+235
-8
lines changed

python/sglang/srt/layers/linear.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""
22

3+
import itertools
34
import logging
45
from abc import abstractmethod
56
from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
6162

6263

6364
def adjust_bitsandbytes_4bit_shard(
64-
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
65+
param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
6566
) -> Tuple[int, int]:
6667
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
6768

68-
total, _ = qkv_offsets["total"]
69-
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
69+
total, _ = shard_offsets["total"]
70+
orig_offset, orig_size = shard_offsets[loaded_shard_id]
7071

7172
quantized_total = param.data.shape[0]
7273
quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ def weight_loader(
573574
shard_offsets.append((i, current_shard_offset, output_size))
574575
current_shard_offset += output_size
575576
packed_dim = getattr(param, "packed_dim", None)
577+
578+
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
576579
for shard_id, shard_offset, shard_size in shard_offsets:
577580
# Special case for Quantization.
578581
# If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@ def weight_loader(
585588
param, shard_size, shard_offset
586589
)
587590

591+
if use_bitsandbytes_4bit:
592+
index = list(itertools.accumulate([0] + self.output_sizes))
593+
orig_offsets = {
594+
str(i): (index[i], size)
595+
for i, size in enumerate(self.output_sizes)
596+
}
597+
orig_offsets["total"] = (self.output_size, 0)
598+
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
599+
param, orig_offsets, str(shard_id)
600+
)
601+
588602
loaded_weight_shard = loaded_weight.narrow(
589603
output_dim, shard_offset, shard_size
590604
)

python/sglang/srt/models/llama.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
362362
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
363363
bitsandbytes_stacked_params_mapping = {
364364
# shard_name, weight_name, index
365-
"q_proj": ("qkv_proj", 0),
366-
"k_proj": ("qkv_proj", 1),
367-
"v_proj": ("qkv_proj", 2),
368-
"gate_proj": ("gate_up_proj", 0),
369-
"up_proj": ("gate_up_proj", 1),
365+
".q_proj": (".qkv_proj", 0),
366+
".k_proj": (".qkv_proj", 1),
367+
".v_proj": (".qkv_proj", 2),
368+
".gate_proj": (".gate_up_proj", 0),
369+
".up_proj": (".gate_up_proj", 1),
370370
}
371371

372372
def __init__(
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import unittest
2+
from types import SimpleNamespace
3+
4+
from sglang.srt.utils import kill_process_tree
5+
from sglang.test.few_shot_gsm8k import run_eval
6+
from sglang.test.test_utils import (
7+
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
8+
DEFAULT_URL_FOR_TEST,
9+
CustomTestCase,
10+
popen_launch_server,
11+
)
12+
13+
14+
class TestUnslothPhi4(CustomTestCase):
15+
@classmethod
16+
def setUpClass(cls):
17+
cls.model = "unsloth/phi-4"
18+
cls.base_url = DEFAULT_URL_FOR_TEST
19+
cls.process = popen_launch_server(
20+
cls.model,
21+
cls.base_url,
22+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
23+
other_args=[],
24+
)
25+
26+
@classmethod
27+
def tearDownClass(cls):
28+
kill_process_tree(cls.process.pid)
29+
30+
def test_gsm8k(self):
31+
args = SimpleNamespace(
32+
num_shots=5,
33+
data_path=None,
34+
num_questions=200,
35+
max_new_tokens=512,
36+
parallel=128,
37+
host="http://127.0.0.1",
38+
port=int(self.base_url.split(":")[-1]),
39+
)
40+
metrics = run_eval(args)
41+
print(f"{metrics=}")
42+
self.assertGreater(metrics["accuracy"], 0.78)
43+
44+
45+
class TestUnslothPhi4Bnb4bit(CustomTestCase):
46+
@classmethod
47+
def setUpClass(cls):
48+
cls.model = "unsloth/phi-4-bnb-4bit"
49+
cls.base_url = DEFAULT_URL_FOR_TEST
50+
cls.process = popen_launch_server(
51+
cls.model,
52+
cls.base_url,
53+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
54+
other_args=[
55+
"--load-format",
56+
"bitsandbytes",
57+
],
58+
)
59+
60+
@classmethod
61+
def tearDownClass(cls):
62+
kill_process_tree(cls.process.pid)
63+
64+
def test_gsm8k(self):
65+
args = SimpleNamespace(
66+
num_shots=5,
67+
data_path=None,
68+
num_questions=200,
69+
max_new_tokens=512,
70+
parallel=128,
71+
host="http://127.0.0.1",
72+
port=int(self.base_url.split(":")[-1]),
73+
)
74+
metrics = run_eval(args)
75+
print(f"{metrics=}")
76+
self.assertGreater(metrics["accuracy"], 0.75)
77+
78+
79+
class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
80+
@classmethod
81+
def setUpClass(cls):
82+
cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
83+
cls.base_url = DEFAULT_URL_FOR_TEST
84+
cls.process = popen_launch_server(
85+
cls.model,
86+
cls.base_url,
87+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
88+
other_args=[
89+
"--load-format",
90+
"bitsandbytes",
91+
],
92+
)
93+
94+
@classmethod
95+
def tearDownClass(cls):
96+
kill_process_tree(cls.process.pid)
97+
98+
def test_gsm8k(self):
99+
args = SimpleNamespace(
100+
num_shots=5,
101+
data_path=None,
102+
num_questions=200,
103+
max_new_tokens=512,
104+
parallel=128,
105+
host="http://127.0.0.1",
106+
port=int(self.base_url.split(":")[-1]),
107+
)
108+
metrics = run_eval(args)
109+
print(f"{metrics=}")
110+
self.assertGreater(metrics["accuracy"], 0.75)
111+
112+
113+
class TestUnslothPhi4MiniInstruct(CustomTestCase):
114+
@classmethod
115+
def setUpClass(cls):
116+
cls.model = "unsloth/Phi-4-mini-instruct"
117+
cls.base_url = DEFAULT_URL_FOR_TEST
118+
cls.process = popen_launch_server(
119+
cls.model,
120+
cls.base_url,
121+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
122+
other_args=[],
123+
)
124+
125+
@classmethod
126+
def tearDownClass(cls):
127+
kill_process_tree(cls.process.pid)
128+
129+
def test_gsm8k(self):
130+
args = SimpleNamespace(
131+
num_shots=5,
132+
data_path=None,
133+
num_questions=200,
134+
max_new_tokens=512,
135+
parallel=128,
136+
host="http://127.0.0.1",
137+
port=int(self.base_url.split(":")[-1]),
138+
)
139+
metrics = run_eval(args)
140+
print(f"{metrics=}")
141+
self.assertGreater(metrics["accuracy"], 0.65)
142+
143+
144+
class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
145+
@classmethod
146+
def setUpClass(cls):
147+
cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
148+
cls.base_url = DEFAULT_URL_FOR_TEST
149+
cls.process = popen_launch_server(
150+
cls.model,
151+
cls.base_url,
152+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
153+
other_args=[
154+
"--load-format",
155+
"bitsandbytes",
156+
],
157+
)
158+
159+
@classmethod
160+
def tearDownClass(cls):
161+
kill_process_tree(cls.process.pid)
162+
163+
def test_gsm8k(self):
164+
args = SimpleNamespace(
165+
num_shots=5,
166+
data_path=None,
167+
num_questions=200,
168+
max_new_tokens=512,
169+
parallel=128,
170+
host="http://127.0.0.1",
171+
port=int(self.base_url.split(":")[-1]),
172+
)
173+
metrics = run_eval(args)
174+
print(f"{metrics=}")
175+
self.assertGreater(metrics["accuracy"], 0.6)
176+
177+
178+
class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
179+
@classmethod
180+
def setUpClass(cls):
181+
cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
182+
cls.base_url = DEFAULT_URL_FOR_TEST
183+
cls.process = popen_launch_server(
184+
cls.model,
185+
cls.base_url,
186+
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
187+
other_args=[
188+
"--load-format",
189+
"bitsandbytes",
190+
],
191+
)
192+
193+
@classmethod
194+
def tearDownClass(cls):
195+
kill_process_tree(cls.process.pid)
196+
197+
def test_gsm8k(self):
198+
args = SimpleNamespace(
199+
num_shots=5,
200+
data_path=None,
201+
num_questions=200,
202+
max_new_tokens=512,
203+
parallel=128,
204+
host="http://127.0.0.1",
205+
port=int(self.base_url.split(":")[-1]),
206+
)
207+
metrics = run_eval(args)
208+
print(f"{metrics=}")
209+
self.assertGreater(metrics["accuracy"], 0.6)
210+
211+
212+
if __name__ == "__main__":
213+
unittest.main()

0 commit comments

Comments
 (0)