Skip to content

Commit f836658

Browse files
authored
[Spec][Ngram] 4/N: Remove max_match_window_size and min_match_window_size, matching all suffixes of the Trie (#21225)
1 parent 269589a commit f836658

13 files changed

Lines changed: 46 additions & 134 deletions

File tree

docs/advanced_features/server_arguments.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,10 @@ Please consult the documentation below and [server_args.py](https://github.com/s
295295
## Ngram speculative decoding
296296
| Argument | Description | Defaults | Options |
297297
| --- | --- | --- | --- |
298-
| `--speculative-ngram-min-match-window-size` | The minimum window size for pattern matching in ngram speculative decoding. | `1` | Type: int |
299-
| `--speculative-ngram-max-match-window-size` | The maximum window size for pattern matching in ngram speculative decoding. | `12` | Type: int |
300298
| `--speculative-ngram-min-bfs-breadth` | The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `1` | Type: int |
301299
| `--speculative-ngram-max-bfs-breadth` | The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding. | `10` | Type: int |
302300
| `--speculative-ngram-match-type` | Ngram tree-building mode. `BFS` selects recency-based expansion and `PROB` selects frequency-based expansion. This setting is forwarded to the ngram cache implementation. | `BFS` | `BFS`, `PROB` |
303-
| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` | Type: int |
301+
| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` | Type: int |
304302
| `--speculative-ngram-capacity` | The cache capacity for ngram speculative decoding. | `10000000` | Type: int |
305303
306304
## Multi-layer Eagle speculative decoding

docs/advanced_features/speculative_decoding.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,11 @@ Enable it with:
387387

388388
| Parameter | Description | Default |
389389
|---|---|---|
390-
| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `--speculative-ngram-max-match-window-size`. | `12` (with default ngram settings) |
391-
| `--speculative-ngram-min-match-window-size` | Minimum matching window size. | `1` |
392-
| `--speculative-ngram-max-match-window-size` | Maximum matching window size. | `12` |
390+
| `--speculative-num-draft-tokens` | Number of draft tokens verified per step. If omitted, defaults to `min(--speculative-ngram-max-trie-depth, 12)`. | `12` (with default ngram settings) |
393391
| `--speculative-ngram-min-bfs-breadth` | Minimum BFS breadth. | `1` |
394392
| `--speculative-ngram-max-bfs-breadth` | Maximum BFS breadth. | `10` |
395393
| `--speculative-ngram-match-type` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion. | `"BFS"` |
396-
| `--speculative-ngram-max-trie-depth` | The max trie depth for ngram speculative decoding. | `18` |
394+
| `--speculative-ngram-max-trie-depth` | Maximum suffix length stored and matched by the ngram trie. | `18` |
397395
| `--speculative-ngram-capacity` | Cache capacity (number of entries). | `10,000,000` |
398396

399397
Notes:
@@ -408,7 +406,6 @@ python3 -m sglang.launch_server \
408406
--model Qwen/Qwen2.5-7B-Instruct \
409407
--speculative-algorithm NGRAM \
410408
--speculative-num-draft-tokens 16 \
411-
--speculative-ngram-max-match-window-size 12 \
412409
--speculative-ngram-max-bfs-breadth 10 \
413410
--mem-fraction-static 0.7 \
414411
--cuda-graph-max-bs 8 \
@@ -464,12 +461,10 @@ Below is a comprehensive list of all speculative decoding parameters available i
464461

465462
| Parameter | Type | Default | Description |
466463
|---|---|---|---|
467-
| `--speculative-ngram-min-match-window-size` | `int` | `1` | Minimum ngram matching window |
468-
| `--speculative-ngram-max-match-window-size` | `int` | `12` | Maximum ngram matching window |
469464
| `--speculative-ngram-min-bfs-breadth` | `int` | `1` | Minimum BFS breadth |
470465
| `--speculative-ngram-max-bfs-breadth` | `int` | `10` | Maximum BFS breadth |
471466
| `--speculative-ngram-match-type` | `str` | `"BFS"` | Ngram tree-building mode: `"BFS"` for recency-based expansion or `"PROB"` for frequency-based expansion |
472-
| `--speculative-ngram-max-trie-depth` | `int` | `18` | Max trie depth for ngram speculative decoding |
467+
| `--speculative-ngram-max-trie-depth` | `int` | `18` | Maximum suffix length stored and matched by the ngram trie |
473468
| `--speculative-ngram-capacity` | `int` | `10,000,000` | Cache capacity |
474469

475470
### Environment variables

python/sglang/srt/server_args.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,6 @@ class ServerArgs:
506506
speculative_draft_model_quantization: Optional[str] = None
507507

508508
# Speculative decoding (ngram)
509-
speculative_ngram_min_match_window_size: int = 1
510-
speculative_ngram_max_match_window_size: int = 12
511509
speculative_ngram_min_bfs_breadth: int = 1
512510
speculative_ngram_max_bfs_breadth: int = 10
513511
speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
@@ -3108,8 +3106,10 @@ def _handle_speculative_decoding(self):
31083106
self.enable_mixed_chunk = False
31093107
self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth
31103108
if self.speculative_num_draft_tokens is None:
3111-
self.speculative_num_draft_tokens = (
3112-
self.speculative_ngram_max_match_window_size
3109+
self.speculative_num_draft_tokens = 12
3110+
logger.warning(
3111+
"speculative_num_draft_tokens is set to 12 by default for ngram speculative decoding. "
3112+
"You can override this by explicitly setting --speculative-num-draft-tokens."
31133113
)
31143114
logger.warning(
31153115
"The overlap scheduler and mixed chunked prefill are disabled because of "
@@ -4851,18 +4851,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
48514851
)
48524852

48534853
# Speculative decoding (ngram)
4854-
parser.add_argument(
4855-
"--speculative-ngram-min-match-window-size",
4856-
type=int,
4857-
default=ServerArgs.speculative_ngram_min_match_window_size,
4858-
help="The minimum window size for pattern matching in ngram speculative decoding.",
4859-
)
4860-
parser.add_argument(
4861-
"--speculative-ngram-max-match-window-size",
4862-
type=int,
4863-
default=ServerArgs.speculative_ngram_max_match_window_size,
4864-
help="The maximum window size for pattern matching in ngram speculative decoding.",
4865-
)
48664854
parser.add_argument(
48674855
"--speculative-ngram-min-bfs-breadth",
48684856
type=int,

python/sglang/srt/speculative/cpp_ngram/ngram.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
1313
throw std::runtime_error(
1414
"param_.max_trie_depth must be greater than 1, current value: " + std::to_string(param_.max_trie_depth));
1515
}
16-
if (!(param_.min_match_window_size > 0)) {
17-
throw std::runtime_error(
18-
"min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
19-
}
20-
if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
21-
throw std::runtime_error(
22-
"min_match_window_size must be less than or equal to "
23-
"max_match_window_size, current min_match_window_size: " +
24-
std::to_string(param_.min_match_window_size) +
25-
", max_match_window_size: " + std::to_string(param_.max_match_window_size));
26-
}
27-
if (!(param_.max_match_window_size < param_.max_trie_depth)) {
28-
throw std::runtime_error(
29-
"max_match_window_size must be less than max_trie_depth, current "
30-
"max_match_window_size: " +
31-
std::to_string(param_.max_match_window_size) + ", max_trie_depth: " + std::to_string(param_.max_trie_depth));
32-
}
3316
if (!(param_.min_bfs_breadth > 0)) {
3417
throw std::runtime_error(
3518
"min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
@@ -53,20 +36,6 @@ Ngram::Ngram(size_t capacity, const Param& param) : param_(param) {
5336
}
5437
}
5538
}
56-
for (auto config : param_.batch_min_match_window_size) {
57-
if (config != std::numeric_limits<decltype(config)>::max()) {
58-
if (!(config >= param_.min_match_window_size)) {
59-
throw std::runtime_error(
60-
"batch_min_match_window_size config value " + std::to_string(config) +
61-
" must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
62-
}
63-
if (!(config <= param_.max_match_window_size)) {
64-
throw std::runtime_error(
65-
"batch_min_match_window_size config value " + std::to_string(config) +
66-
" must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
67-
}
68-
}
69-
}
7039

7140
trie_ = std::make_unique<Trie>(capacity, param_);
7241

python/sglang/srt/speculative/cpp_ngram/ngram_corpus.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ class NgramCorpus:
2626
def __init__(
2727
self,
2828
max_trie_depth=18,
29-
min_match_window_size=1,
30-
max_match_window_size=10,
3129
min_bfs_breadth=1,
3230
max_bfs_breadth=8,
3331
draft_token_num=8,
@@ -36,8 +34,6 @@ def __init__(
3634
):
3735
param = ngram_corpus_cpp.Param()
3836
param.max_trie_depth = max_trie_depth
39-
param.min_match_window_size = min_match_window_size
40-
param.max_match_window_size = max_match_window_size
4137
param.min_bfs_breadth = min_bfs_breadth
4238
param.max_bfs_breadth = max_bfs_breadth
4339
param.draft_token_num = draft_token_num

python/sglang/srt/speculative/cpp_ngram/ngram_corpus_binding.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,12 @@ PYBIND11_MODULE(ngram_corpus_cpp, m) {
2121
.def_readwrite("enable_router_mode", &Param::enable_router_mode)
2222
.def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
2323
.def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
24-
.def_readwrite("min_match_window_size", &Param::min_match_window_size)
25-
.def_readwrite("max_match_window_size", &Param::max_match_window_size)
2624
.def_readwrite("max_trie_depth", &Param::max_trie_depth)
2725
.def_readwrite("draft_token_num", &Param::draft_token_num)
2826
.def_readwrite("match_type", &Param::match_type)
29-
.def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
3027
.def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
3128
.def("get_draft_token_num", &Param::get_draft_token_num, "")
32-
.def("get_min_match_window_size", &Param::get_min_match_window_size, "")
3329
.def("parse", &Param::parse, "")
34-
.def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
3530
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
3631
.def("detail", &Param::detail, "");
3732

python/sglang/srt/speculative/cpp_ngram/param.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,10 @@ struct Param {
1717
bool enable_router_mode;
1818
size_t min_bfs_breadth;
1919
size_t max_bfs_breadth;
20-
size_t min_match_window_size;
21-
size_t max_match_window_size;
2220
size_t max_trie_depth;
2321
size_t draft_token_num;
2422
std::string match_type;
2523

26-
std::vector<size_t> batch_min_match_window_size;
2724
std::vector<size_t> batch_draft_token_num;
2825

2926
size_t get_draft_token_num(size_t batch_size) const {
@@ -36,16 +33,6 @@ struct Param {
3633
return draft_token_num - 1;
3734
}
3835

39-
size_t get_min_match_window_size(size_t batch_size) const {
40-
if (batch_size < batch_min_match_window_size.size()) {
41-
if (batch_min_match_window_size[batch_size] !=
42-
std::numeric_limits<decltype(batch_min_match_window_size)::value_type>::max()) {
43-
return batch_min_match_window_size[batch_size];
44-
}
45-
}
46-
return min_match_window_size;
47-
}
48-
4936
std::vector<size_t> parse(const std::string& value) {
5037
// 0-1|10,2-3|20,
5138
std::vector<size_t> result;
@@ -96,10 +83,6 @@ struct Param {
9683
return result;
9784
}
9885

99-
void resetBatchMinMatchWindowSize(const std::string& value) {
100-
batch_min_match_window_size = parse(value);
101-
}
102-
10386
void resetBatchReturnTokenNum(const std::string& value) {
10487
batch_draft_token_num = parse(value);
10588
}
@@ -108,13 +91,8 @@ struct Param {
10891
std::stringstream ss;
10992
ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
11093
<< ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
111-
<< ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
11294
<< ", max_trie_depth = " << max_trie_depth << ", draft_token_num = " << draft_token_num
11395
<< ", match_type = " << match_type;
114-
ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
115-
for (int i = 0; i < batch_min_match_window_size.size(); ++i) {
116-
ss << i << "|" << batch_min_match_window_size[i] << ",";
117-
}
11896
ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
11997
for (int i = 0; i < batch_draft_token_num.size(); ++i) {
12098
ss << i << "|" << batch_draft_token_num[i] << ",";

python/sglang/srt/speculative/cpp_ngram/trie.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Trie::Trie(size_t capacity, const Param& param) : param_(param) {
1919
}
2020

2121
void Trie::insert(const int32_t* tokens, size_t len) {
22-
for (size_t i = 0; i + param_.min_match_window_size < len; ++i) {
22+
for (size_t i = 0; i < len; ++i) {
2323
auto start = tokens + i;
2424
auto end = start + std::min(len - i, param_.max_trie_depth);
2525

@@ -100,14 +100,13 @@ void Trie::reset() {
100100
root_ = getNode();
101101
}
102102

103-
std::vector<std::pair<TrieNode*, int32_t>>
104-
Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const {
103+
std::vector<std::pair<TrieNode*, int32_t>> Trie::match(const int32_t* context, size_t len) const {
105104
std::vector<std::pair<TrieNode*, int32_t>> result;
106-
result.reserve(max_window - min_window);
107-
for (int32_t match_window_size = std::min(len, max_window); match_window_size >= static_cast<int32_t>(min_window);
108-
--match_window_size) {
109-
auto start = context + len - match_window_size;
110-
auto end = start + match_window_size;
105+
const auto max_match_depth = std::min(len, param_.max_trie_depth);
106+
result.reserve(max_match_depth);
107+
for (size_t match_depth = max_match_depth; match_depth > 0; --match_depth) {
108+
auto start = context + len - match_depth;
109+
auto end = start + match_depth;
111110
auto cursor = root_;
112111
while (start != end) {
113112
auto iter = cursor->child.find(*start);
@@ -118,27 +117,27 @@ Trie::match(const int32_t* context, size_t len, size_t min_window, size_t max_wi
118117
++start;
119118
cursor = iter->second;
120119
}
121-
if (cursor) {
122-
result.emplace_back(std::make_pair(cursor, match_window_size));
120+
if (cursor != nullptr && !cursor->child.empty()) {
121+
result.emplace_back(cursor, static_cast<int32_t>(match_depth));
123122
}
124123
}
125124
return result;
126125
}
127126

128127
Result Trie::buildRecency(
129128
const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const {
130-
auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size);
129+
auto anchors = match(context, len);
131130

132-
double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) /
133-
(param.max_match_window_size - param.min_match_window_size + 1);
131+
const auto max_match_depth = std::max<int32_t>(1, static_cast<int32_t>(param.max_trie_depth - 1));
132+
double bfs_breadth_scale = double(param.max_bfs_breadth - param.min_bfs_breadth) / max_match_depth;
134133

135134
std::vector<Node> tree(draft_token_num + 1);
136135
int root = 0;
137136
int cursor = 1;
138137

139138
for (auto [node, depth] : anchors) {
140139
std::queue<std::tuple<int32_t, double, const TrieNode*>> queue;
141-
queue.push({root, (param.max_match_window_size - depth) * bfs_breadth_scale + param.min_bfs_breadth, node});
140+
queue.push({root, (max_match_depth - depth) * bfs_breadth_scale + param.min_bfs_breadth, node});
142141
while (queue.size() && cursor <= static_cast<int>(draft_token_num)) {
143142
auto front = queue.front();
144143
queue.pop();
@@ -168,7 +167,7 @@ Result Trie::buildRecency(
168167

169168
Result Trie::buildFrequency(
170169
const int32_t* context, size_t len, int32_t last_token, size_t draft_token_num, const Param& param) const {
171-
auto anchors = match(context, len, param.min_match_window_size, param.max_match_window_size);
170+
auto anchors = match(context, len);
172171

173172
struct CompareByLastDouble {
174173
bool operator()(

python/sglang/srt/speculative/cpp_ngram/trie.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ class Trie {
4949
void reset();
5050

5151
private:
52-
std::vector<std::pair<TrieNode*, int32_t>>
53-
match(const int32_t* context, size_t len, size_t min_window, size_t max_window) const;
52+
std::vector<std::pair<TrieNode*, int32_t>> match(const int32_t* context, size_t len) const;
5453

5554
TrieNode* getNode() {
5655
auto node = node_pool_[--free_node_count_];

python/sglang/srt/speculative/ngram_worker.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,13 @@ def __init__(
4141
self.page_size = server_args.page_size
4242
self.draft_token_num: int = server_args.speculative_num_draft_tokens
4343
self.max_trie_depth: int = server_args.speculative_ngram_max_trie_depth
44-
self.max_match_window_size: int = (
45-
server_args.speculative_ngram_max_match_window_size
46-
)
4744

4845
self.max_batch_size = target_worker.max_running_requests
4946
self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"
5047

5148
self._init_preallocated_tensors()
5249

5350
self.ngram_corpus = NgramCorpus(
54-
min_match_window_size=server_args.speculative_ngram_min_match_window_size,
55-
max_match_window_size=server_args.speculative_ngram_max_match_window_size,
5651
min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
5752
max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
5853
match_type=server_args.speculative_ngram_match_type,
@@ -131,7 +126,7 @@ def _prepare_draft_tokens(
131126
batch_tokens = []
132127
for req in batch.reqs:
133128
check_token = self._efficient_concat_last_n(
134-
req.origin_input_ids, req.output_ids, self.max_match_window_size
129+
req.origin_input_ids, req.output_ids, self.max_trie_depth
135130
)
136131
batch_tokens.append(check_token)
137132
req_drafts, mask = self.ngram_corpus.batch_get(batch_tokens)

0 commit comments

Comments
 (0)