
Commit a27bfea

Fix lora batch processing when input lora_path contains None (sgl-project#5930)

Qiaolin-Yu authored and xwu-intel committed
1 parent: c503bc4

File tree: 4 files changed, +60 / -279 lines
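The bug addressed here shows up in mixed batches, where some requests reference a LoRA adapter and others pass None to run on the base model alone. Below is a minimal sketch of that scenario using the SRTRunner test helper that appears elsewhere in this diff; the second prompt and the exact argument values are illustrative, not taken from the commit.

import torch

from sglang.test.runners import SRTRunner

BASE = "meta-llama/Llama-3.1-8B-Instruct"
ADAPTER = "algoprog/fact-generation-llama-3.1-8b-instruct-lora"

prompts = [
    "AI is a field of computer science focused on",
    "Write one sentence about llamas.",  # illustrative second prompt
]
# One request uses the adapter, the other explicitly uses no LoRA (None);
# before this fix, the per-slot metadata setup dereferenced self.loras[None]
# for such entries.
batch_lora_paths = [ADAPTER, None]

with SRTRunner(
    BASE,
    torch_dtype=torch.float16,
    model_type="generation",
    lora_paths=[ADAPTER],
    max_loras_per_batch=2,
) as srt_runner:
    outputs = srt_runner.forward(
        prompts, max_new_tokens=32, lora_paths=batch_lora_paths
    )
    print(outputs.output_strs)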

python/sglang/srt/lora/lora_manager.py

Lines changed: 12 additions & 14 deletions

@@ -153,10 +153,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

@@ -185,13 +181,14 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
                 self.cuda_graph_batch_info.weight_indices[i] = (
                     self.memory_pool.get_buffer_id(lora_path)
                 )
-                lora = self.loras[lora_path]
-                self.cuda_graph_batch_info.lora_ranks[
-                    self.cuda_graph_batch_info.weight_indices[i]
-                ] = lora.config.hf_config["r"]
-                self.cuda_graph_batch_info.scalings[
-                    self.cuda_graph_batch_info.weight_indices[i]
-                ] = lora.scaling
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
             batch_info = self.cuda_graph_batch_info
         else:
             seg_lens = (
@@ -212,9 +209,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
             )
             for i, lora_path in enumerate(forward_batch.lora_paths):
                 weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                lora = self.loras[lora_path]
-                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                scalings[weight_indices[i]] = lora.scaling
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
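The substance of the change above is the if lora_path is not None: guard: batch slots that carry no adapter must not be looked up in self.loras. A simplified, standalone sketch of that pattern follows; the function and parameter names are hypothetical, not the real LoRAManager API, but the rank/scaling bookkeeping mirrors the patched lines.

import torch


def fill_lora_batch_buffers(lora_paths, loras, get_buffer_id, max_loras_per_batch):
    """Fill per-request weight indices and per-slot rank/scaling buffers."""
    bs = len(lora_paths)
    weight_indices = torch.zeros(bs, dtype=torch.int64)
    lora_ranks = torch.zeros(max_loras_per_batch, dtype=torch.int64)
    scalings = torch.zeros(max_loras_per_batch, dtype=torch.float)
    for i, lora_path in enumerate(lora_paths):
        weight_indices[i] = get_buffer_id(lora_path)
        # The fix: only dereference real adapters; a None path means the
        # request runs on the base model and contributes no rank/scaling.
        if lora_path is not None:
            lora = loras[lora_path]
            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
            scalings[weight_indices[i]] = lora.scaling
    return weight_indices, lora_ranks, scalings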

python/sglang/test/runners.py

Lines changed: 3 additions & 3 deletions

@@ -423,9 +423,9 @@ def forward_generation_raw(
                )
                del input_logits

-                if lora_paths is not None and lora_paths[i] is not None:
-                    # Unload the LoRA adapter if it is used
-                    model.unload()
+            if lora_paths is not None and lora_paths[i] is not None:
+                # Unload the LoRA adapter if it is used
+                model.unload()

        return ModelOutput(
            output_strs=output_strs,
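For context, the block above lives in the HuggingFace reference runner, which wraps the base model with a PEFT adapter per prompt and unloads it afterwards. A rough sketch of that per-prompt pattern, written against the public peft.PeftModel API rather than the actual runners.py code:

from peft import PeftModel


def generate_one(base_model, tokenizer, prompt, lora_path, max_new_tokens=32):
    # Wrap the base model with the adapter only when a path is given.
    model = (
        PeftModel.from_pretrained(base_model, lora_path)
        if lora_path is not None
        else base_model
    )
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(base_model.device)
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if lora_path is not None:
        # Unload so the next prompt starts from the clean base model.
        model.unload()
    return text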

test/srt/models/lora/test_lora.py

Lines changed: 39 additions & 260 deletions

@@ -15,33 +15,10 @@
 import multiprocessing as mp
 import unittest

-import torch
+from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, run_lora_test_by_batch

-from sglang.test.runners import HFRunner, SRTRunner
 from sglang.test.test_utils import CustomTestCase

-LORA_SETS = [
-    # {
-    #     "base": "meta-llama/Llama-2-7b-hf",
-    #     "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
-    # },
-    {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
-    # {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]},
-    # {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
-    # {
-    #     "base": "mistralai/Mistral-7B-Instruct-v0.3",
-    #     "loras": [
-    #         "/home/ying/test_lora",
-    #         "/home/ying/test_lora_1",
-    #         "/home/ying/test_lora_2",
-    #         "/home/ying/test_lora_3",
-    #         "/home/ying/test_lora_4",
-    #     ],
-    # },
-    # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]},
-]
-TORCH_DTYPES = [torch.float16]
-
 PROMPTS = [
     """
 ### Instruction:
@@ -51,248 +28,50 @@
 The Transformers are large language models,
 They're used to make predictions on text.
 """,
-    """
-### Instruction:
-Tell me about llamas and alpacas
-### Response:
-Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
-### Question 2:
-What do you know about llamas?
-### Answer:
-""",
+    "AI is a field of computer science focused on",
 ]

-# import json
-#
-# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f:
-#     samples = json.load(f)
-#     for sample in samples[:5]:
-#         assert sample[0]["role"] == "user"
-#         PROMPTS.append(sample[0]["content"][:2000])
+LORA_MODELS_WITH_NONE = [
+    LoRAModelCase(
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        adaptors=[
+            LoRAAdaptor(
+                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+            ),
+            LoRAAdaptor(
+                name=None,
+            ),
+        ],
+        max_loras_per_batch=2,
+    ),
+    LoRAModelCase(
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        adaptors=[
+            LoRAAdaptor(
+                name=None,
+            ),
+            LoRAAdaptor(
+                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+            ),
+        ],
+        max_loras_per_batch=2,
+    ),
+]


 class TestLoRA(CustomTestCase):
-
-    def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
-        print("=================== testing inference =======================")
-        base_path = lora_set["base"]
-        all_lora_paths = lora_set["loras"]
-        batch_lora_paths = [None]
-        i = 0
-        for _ in range(len(prompts) - 1):
-            batch_lora_paths.append(all_lora_paths[i])
-            i = (i + 1) % len(all_lora_paths)
-
-        with SRTRunner(
-            base_path,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-            tp_size=tp_size,
-            lora_paths=all_lora_paths,
-            max_loras_per_batch=3,
-            disable_cuda_graph=True,
-            disable_radix_cache=True,
-        ) as srt_runner:
-            srt_outputs = srt_runner.forward(
-                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
-            )
-            srt_outputs_lora_path_none = srt_runner.forward(
-                prompts,
-                max_new_tokens=max_new_tokens,
-                lora_paths=[None] * len(prompts),
-            )
-
-        with HFRunner(
-            base_path, torch_dtype=torch_dtype, model_type="generation"
-        ) as hf_runner:
-            hf_outputs = hf_runner.forward(
-                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
-            )
-
-        with HFRunner(
-            base_path,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-        ) as hf_runner:
-            hf_no_lora_outputs = hf_runner.forward(
-                prompts, max_new_tokens=max_new_tokens
-            )
-
-        with SRTRunner(
-            base_path,
-            tp_size=tp_size,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-        ) as srt_runner:
-            srt_no_lora_outputs = srt_runner.forward(
-                prompts, max_new_tokens=max_new_tokens
-            )
-
-        for i in range(len(prompts)):
-            # compare input logprobs
-            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
-            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-            hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
-            srt_no_lora_logprobs = torch.Tensor(
-                srt_no_lora_outputs.top_input_logprobs[i]
-            )
-            print(
-                "max input diff between hf_lora and srt_lora",
-                torch.max(abs(hf_logprobs - srt_logprobs)),
-            )
-            print(
-                "max input diff between srt_base and srt_lora",
-                torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
-            )
-            print(
-                "max input diff between srt_base and hf_base",
-                torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
-            )
-            print(
-                "max input diff between hf_lora and hf_base",
-                torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
-            )
-
-            # compare output logprobs
-            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
-            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
-            # print(
-            #     "\noutput logprobs diff",
-            #     [
-            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
-            #         for j in range(max_new_tokens)
-            #     ],
-            # )
-            print(
-                "max output diff between hf_lora and srt_lora",
-                torch.max(abs(hf_logprobs - srt_logprobs)),
-                "\n",
-            )
-
-        # compare output strings
-        print(f"{hf_outputs.output_strs=}")
-        print(f"{srt_outputs.output_strs=}")
-        print(f"{hf_no_lora_outputs.output_strs=}")
-        print(f"{srt_no_lora_outputs.output_strs=}")
-        print(f"{srt_outputs_lora_path_none.output_strs=}")
-        for i in range(len(prompts)):
-            assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[
-                i
-            ].strip(" "), (
-                srt_outputs.output_strs[i].strip(" "),
-                hf_outputs.output_strs[i].strip(" "),
-            )
-            assert (
-                srt_no_lora_outputs.output_strs[i].strip(" ")
-                == hf_no_lora_outputs.output_strs[i]
-            ), (
-                srt_no_lora_outputs.output_strs[i].strip(" "),
-                hf_no_lora_outputs.output_strs[i],
-            )
-            # assert srt_outputs_lora_path_none == srt_no_lora_outputs
-
-    def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
-        print("=================== testing serving =======================")
-        # test batch forward
-        base_path = lora_set["base"]
-        all_lora_paths = lora_set["loras"]
-        batch_lora_paths = [None]
-        i = 0
-        for _ in range(len(prompts) - 1):
-            batch_lora_paths.append(all_lora_paths[i])
-            i = (i + 1) % len(all_lora_paths)
-
-        with SRTRunner(
-            base_path,
-            tp_size=tp_size,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-            lora_paths=all_lora_paths,
-            max_loras_per_batch=3,
-            disable_cuda_graph=True,
-            disable_radix_cache=True,
-        ) as srt_runner:
-            srt_outputs = srt_runner.batch_forward(
-                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
-            )
-
-        with HFRunner(
-            base_path,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-            output_str_only=True,
-        ) as hf_runner:
-            hf_outputs = hf_runner.forward(
-                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
-            )
-
-        # compare output strings
-        print(f"{hf_outputs.output_strs=}")
-        print(f"{srt_outputs.output_strs=}")
-        for i in range(len(prompts)):
-            assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
-                srt_outputs.output_strs[i].strip(" "),
-                hf_outputs.output_strs[i],
-            )
-
-    def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
-        print("=================== testing base inference =======================")
-        base_path = lora_set["base"]
-        all_lora_paths = lora_set["loras"]
-        batch_lora_paths = [None] * len(prompts)
-
-        with SRTRunner(
-            base_path,
-            tp_size=tp_size,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-        ) as srt_runner:
-            srt_no_lora_outputs = srt_runner.forward(
-                prompts, max_new_tokens=max_new_tokens
-            )
-
-        with SRTRunner(
-            base_path,
-            tp_size=tp_size,
-            torch_dtype=torch_dtype,
-            model_type="generation",
-            lora_paths=all_lora_paths,
-        ) as srt_runner:
-            srt_outputs = srt_runner.forward(
-                prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
-            )
-
-        for i in range(len(prompts)):
-            srt_no_lora_logprobs = torch.Tensor(
-                srt_no_lora_outputs.top_input_logprobs[i]
-            )
-            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-            print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
-
-        print(f"{srt_no_lora_outputs.output_strs=}")
-        print(f"{srt_outputs.output_strs=}")
-
-        for i in range(len(prompts)):
-            assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
-                srt_outputs.output_strs[i].strip(" "),
-                hf_outputs.output_strs[i],
-            )
-            assert (
-                srt_no_lora_outputs[i].output_strs.strip(" ")
-                == hf_no_lora_outputs[i].output_strs
-            )
-
-    def test_all(self):
-        for lora_set in LORA_SETS:
-            # self.load_lora_adapter(lora_set, 1)
+    def test_lora_batch_with_none(self):
+        for model_case in LORA_MODELS_WITH_NONE:
+            prompts = PROMPTS
             for torch_dtype in TORCH_DTYPES:
-                tp_size = 1
-                max_new_tokens = 32
-                self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-                # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-                # self.base_inference(
-                #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
-                # )
+                run_lora_test_by_batch(
+                    prompts,
+                    model_case,
+                    torch_dtype,
+                    max_new_tokens=32,
+                    backend="triton",
+                    test_tag="test_lora_batch_with_none",
+                )


 if __name__ == "__main__":
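One plausible way to run just the new test locally, assuming the repo layout shown in this diff (the test module imports utils from its own directory, so it should be launched from there) and a GPU environment with the base model and adapter available:

# From the repository root:
#   cd test/srt/models/lora
#   python3 test_lora.py TestLoRA.test_lora_batch_with_none
#
# Equivalent programmatic invocation:
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_lora.TestLoRA.test_lora_batch_with_none"
)
unittest.TextTestRunner(verbosity=2).run(suite)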
