
Commit d8fbc7c

[feature] support for roberta embedding models (#5730)
1 parent c5e1026 commit d8fbc7c

File tree

3 files changed: +186 -2 lines changed

python/sglang/srt/layers/pooler.py
python/sglang/srt/models/roberta.py
test/srt/models/test_encoder_embedding_models.py

python/sglang/srt/layers/pooler.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 class PoolingType(IntEnum):
     LAST = 0
+    CLS = 1
 
 
 @dataclass
@@ -41,6 +42,11 @@ def forward(
         if self.pooling_type == PoolingType.LAST:
             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
+        elif self.pooling_type == PoolingType.CLS:
+            prompt_lens = forward_batch.extend_seq_lens
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
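
For reference, the new CLS branch selects the first token of each request in the packed (flattened) batch that SGLang uses during extend, while LAST selects the final token. A minimal standalone sketch of the index arithmetic, with made-up sequence lengths (not SGLang code):

import torch

# Three prompts of lengths 4, 2, and 3, concatenated along dim 0 as in a
# packed prefill batch: hidden_states has shape (total_tokens, hidden_size).
prompt_lens = torch.tensor([4, 2, 3])
hidden_states = torch.randn(int(prompt_lens.sum()), 8)

# First-token (CLS) index of each sequence: 0, then the running sums of the
# preceding lengths -> [0, 4, 6].
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]

# Last-token index of each sequence: cumulative lengths minus one -> [3, 5, 8].
last_token_indices = torch.cumsum(prompt_lens, dim=0) - 1

cls_pooled = hidden_states[first_token_flat_indices]  # shape (3, 8)
last_pooled = hidden_states[last_token_indices]       # shape (3, 8)
print(first_token_flat_indices.tolist())  # [0, 4, 6]
print(last_token_indices.tolist())        # [3, 5, 8]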

python/sglang/srt/models/roberta.py

Lines changed: 178 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0

import itertools
from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.bert import BertEncoder

RobertaConfig = None


class RobertaEmbedding(nn.Module):

    def __init__(self, config: RobertaConfig):
        super().__init__()
        self.size = config.hidden_size
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.position_ids = nn.Parameter(
            torch.empty((1, config.max_position_embeddings)),
        )

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        seq_lens: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds=None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        input_shape = input_ids.size()
        inputs_embeds = self.word_embeddings(input_ids)

        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
        pos_list = []
        token_list = []
        offset = 0
        for seq_len in seq_lens:
            pos_list.append(position_ids[offset : offset + seq_len])
            token_list.append(input_ids[offset : offset + seq_len])
            offset += seq_len

        new_pos_list = []
        for positions, tokens in zip(pos_list, token_list):
            # Verify the assumption that incoming positions are
            # always a sequence from 0 to N.
            expected_pos = torch.arange(
                positions.size()[0], dtype=torch.long, device=inputs_embeds.device
            )
            assert torch.equal(positions, expected_pos)
            new_pos_list.append(
                create_position_ids_from_input_ids(tokens, self.padding_idx)
            )
        position_ids = torch.cat(new_pos_list)

        # Position embeddings.
        position_embeddings = self.position_embeddings(position_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=inputs_embeds.device
            )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class XLMRobertaModel(nn.Module):
    def __init__(
        self,
        *,
        config: RobertaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config
        self.embeddings = RobertaEmbedding(config)
        self.encoder = BertEncoder(config=config, quant_config=quant_config, prefix="")
        self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        assert get_embedding

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=positions,
            seq_lens=forward_batch.seq_lens,
        )

        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
        pooler_out = self.pooler(hidden_states, forward_batch)
        return pooler_out

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            name = name.replace("self", "self_attn")
            if "pooler" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# Adapted from transformers
def create_position_ids_from_input_ids(
    input_ids, padding_idx, past_key_values_length=0
):
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (
        torch.cumsum(mask, dim=0).type_as(mask) + past_key_values_length
    ) * mask
    return incremental_indices.long() + padding_idx


EntryClass = [XLMRobertaModel]
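
Unlike BERT's zero-based positions, RoBERTa offsets every position id by padding_idx + 1 and keeps pad positions at padding_idx, which is why RobertaEmbedding.forward discards the incoming 0..N-1 positions and recomputes them per sequence. A small standalone illustration of create_position_ids_from_input_ids (the token ids are made up):

import torch

padding_idx = 1  # RoBERTa's <pad> token id
input_ids = torch.tensor([0, 2840, 232, 328, 2])  # hypothetical ids, no padding

mask = input_ids.ne(padding_idx).int()                          # [1, 1, 1, 1, 1]
incremental = torch.cumsum(mask, dim=0).type_as(mask) * mask    # [1, 2, 3, 4, 5]
position_ids = incremental.long() + padding_idx                 # [2, 3, 4, 5, 6]
print(position_ids.tolist())  # positions start at padding_idx + 1 = 2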

test/srt/models/test_encoder_embedding_models.py

Lines changed: 2 additions & 2 deletions
@@ -25,10 +25,10 @@
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
 
-MODELS = [("BAAI/bge-small-en", 1, 1e-5)]
+MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
 
 ATTENTION_BACKEND = ["torch_native", "triton"]
-BATCH_SIZE = [30]
+BATCH_SIZE = [1, 2]
 TORCH_DTYPES = [torch.float32]
 sgl_to_st_ratio = []
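
With this commit, an XLM-RoBERTa embedding model such as BAAI/bge-m3 should be servable the same way as the existing BERT embedding models. One possible way to exercise it end to end (not part of this commit; flag and route names follow SGLang's existing embedding support and may differ between versions):

# Hypothetical end-to-end check. Assumes the server was started in embedding mode, e.g.:
#   python -m sglang.launch_server --model-path BAAI/bge-m3 --is-embedding
# and that it exposes the OpenAI-compatible /v1/embeddings route on the default port.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/embeddings",
    json={
        "model": "BAAI/bge-m3",
        "input": "SGLang now supports RoBERTa embedding models.",
    },
)
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension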
