Skip to content

Commit 1ac34a8

Browse files
author
Sandermage
committed
spec_decode: add PN21 DFlash SWA partial backport (vllm#40898) — opt-in OFF by default
PN21 backports 2 of the 3 file changes from jianc99's open PR: - speculators/algos.py: preserve layer_types / use_sliding_window / sliding_window / max_window_layers from speculators-format checkpoint into HF config (so SWA layers survive the config pipeline) - v1/spec_decode/dflash.py: force causal=True on per-layer attention metadata for sliding-window draft layers The qwen3_dflash.py model class changes (7+ sub-patches with multi-line context across Attention __init__, DecoderLayer __init__, Model class init + property) are NOT backported — too fragile for text-patch. Empirical on 35B-A3B-FP8-DFlash 160K (3-run tool-call sweep): - PN21 OFF (PN22+PN23+PN24 ON): 7/7 - PN21 ON (partial backport): 5-6/7 Without the model-side changes, config preserves SWA but the model still constructs full-attention layers — the config/compute mismatch shifts spec acceptance for tool-call tokens. Decision: PN21 ships as opt-in (file + dispatcher + apply_all entry + GENESIS_ENABLE_PN21_DFLASH_SWA env flag) but is NOT enabled in any launch script. Default OFF until either upstream PR #40898 merges (drift markers will detect and PN21 will auto-no-op) or the full manual model class backport is implemented and validated. Also enable PN22+PN23+PN24 in 35B DFlash launch script (validated empirically: 7/7 tool-call clean with this triple ON). Bug fix in PN21 wrapper: handle TextPatchResult.IDEMPOTENT explicitly, otherwise the second-run plugin hook returned "failed/unknown" when markers were already present from the first apply_all.py call.
1 parent a6642f0 commit 1ac34a8

4 files changed

Lines changed: 361 additions & 1 deletion

File tree

scripts/start_35b_fp8_DFLASH.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ docker run -d \
4747
-e GENESIS_ENABLE_P70_AUTO_STRICT_NGRAM=1 -e GENESIS_P68_P69_LONG_CTX_THRESHOLD_CHARS=8000 \
4848
-e GENESIS_ENABLE_P37=1 -e GENESIS_TQ_MAX_MODEL_LEN=320000 \
4949
-e GENESIS_ENABLE_P72_PROFILE_RUN_CAP=1 -e GENESIS_PROFILE_RUN_CAP_M=4096 \
50-
-e GENESIS_ENABLE_P74_CHUNK_CLAMP=1 -e GENESIS_ENABLE_P79B_ASYNC_PROPOSER_SYNC=0 -e GENESIS_ENABLE_P79C_STALE_SPEC_TOKEN_CLEANUP=0 -e GENESIS_ENABLE_P79D_PREEMPT_ASYNC_DISCARD=0 -e GENESIS_ENABLE_P81_FP8_BLOCK_SCALED_M_LE_8=1 -e GENESIS_ENABLE_P82=1 -e GENESIS_ENABLE_PN8_MTP_DRAFT_ONLINE_QUANT=1 -e GENESIS_ENABLE_PN11_GDN_AB_CONTIGUOUS=1 -e GENESIS_ENABLE_P99=1 -e GENESIS_ENABLE_PN17_FA2_LSE_CLAMP=1 -e GENESIS_ENABLE_PN19_SCOPED_MAX_SPLIT=1 -e GENESIS_ENABLE_P103=1 -e GENESIS_ENABLE_P101=1 -e GENESIS_P82_THRESHOLD_SINGLE=0.3 -e GENESIS_PREALLOC_TOKEN_BUDGET=4096 -e GENESIS_BUFFER_MODE=shared \
50+
-e GENESIS_ENABLE_P74_CHUNK_CLAMP=1 -e GENESIS_ENABLE_P79B_ASYNC_PROPOSER_SYNC=0 -e GENESIS_ENABLE_P79C_STALE_SPEC_TOKEN_CLEANUP=0 -e GENESIS_ENABLE_P79D_PREEMPT_ASYNC_DISCARD=0 -e GENESIS_ENABLE_P81_FP8_BLOCK_SCALED_M_LE_8=1 -e GENESIS_ENABLE_P82=1 -e GENESIS_ENABLE_PN8_MTP_DRAFT_ONLINE_QUANT=1 -e GENESIS_ENABLE_PN11_GDN_AB_CONTIGUOUS=1 -e GENESIS_ENABLE_P99=1 -e GENESIS_ENABLE_PN17_FA2_LSE_CLAMP=1 -e GENESIS_ENABLE_PN19_SCOPED_MAX_SPLIT=1 -e GENESIS_ENABLE_PN22_LOCAL_ARGMAX_TP=1 -e GENESIS_ENABLE_PN23_DFLASH_DTYPE_FIX=1 -e GENESIS_ENABLE_PN24_DFLASH_AUX_LAYER_FIX=1 -e GENESIS_ENABLE_P103=1 -e GENESIS_ENABLE_P101=1 -e GENESIS_P82_THRESHOLD_SINGLE=0.3 -e GENESIS_PREALLOC_TOKEN_BUDGET=4096 -e GENESIS_BUFFER_MODE=shared \
5151
vllm/vllm-openai:nightly -c \
5252
"set -e; echo \"=== v775 35B baseline upstream P67 (matches v759 PROD) ===\"; \
5353
pip install --quiet --disable-pip-version-check pandas scipy xxhash; \

vllm/_genesis/dispatcher.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,28 @@ class ValidationIssue:
554554
"conflicts_with": [],
555555
"requires_patches": [],
556556
},
557+
"PN21": {
558+
"title": "DFlash SWA support partial backport (vllm#40898)",
559+
"env_flag": "GENESIS_ENABLE_PN21_DFLASH_SWA",
560+
"default_on": False,
561+
"category": "spec_decode",
562+
"credit": (
563+
"Partial backport of vllm#40898 (jianc99, OPEN 2026-05-01). "
564+
"Adds SWA config preservation in speculators/algos.py and forces "
565+
"causal=True on sliding-window layer attention metadata in "
566+
"v1/spec_decode/dflash.py. The qwen3_dflash.py model class "
567+
"changes (7+ sub-patches) are NOT backported. EMPIRICAL on 35B-A3B "
568+
"DFlash 160K: tool-call regresses 5-6/7 vs 7/7 baseline (without PN21) — "
569+
"metadata/compute mismatch (config says SWA, model computes full attn). "
570+
"DEFAULT OFF, NOT enabled in any launch script. Wait for upstream merge "
571+
"or full manual model class backport before enabling. Composes (no conflict) "
572+
"with PN24 if/when full enabler lands."
573+
),
574+
"upstream_pr": 40898,
575+
"applies_to": {},
576+
"conflicts_with": [],
577+
"requires_patches": [], # Pairs with PN24 but does not strictly require it
578+
},
557579
"PN22": {
558580
"title": "Local argmax for TP draft (vllm#39419 backport)",
559581
"env_flag": "GENESIS_ENABLE_PN22_LOCAL_ARGMAX_TP",

vllm/_genesis/patches/apply_all.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,31 @@ def apply_patch_N23_dflash_combine_hidden_dtype() -> PatchResult:
20632063
)
20642064

20652065

2066+
@register_patch("PN21 DFlash SWA support partial backport (vllm#40898 backport)")
def apply_patch_N21_dflash_swa_support() -> PatchResult:
    """Patch N21: two-file partial backport of vllm#40898 (jianc99, OPEN).

    What is patched:
      * speculators/algos.py — keeps the SWA config keys (layer_types,
        use_sliding_window, sliding_window, max_window_layers).
      * v1/spec_decode/dflash.py — forces causal=True on the attention
        metadata of sliding-window layers.

    What is NOT patched: the qwen3_dflash.py model class (7+ sub-patches
    with multi-line context, considered too fragile). Wait for the upstream
    merge or apply those edits manually.

    Composes with PN24 (gpu_model_runner +1 shift); both can coexist.

    Status: opt-in via GENESIS_ENABLE_PN21_DFLASH_SWA=1, default OFF.
    Intended to auto-no-op on upstream merge (drift markers).
    """
    patch_label = "PN21 DFlash SWA support partial backport (vllm#40898 backport)"
    wiring_module = "patch_N21_dflash_swa_support"
    return _wiring_text_patch(patch_label, wiring_module)
2089+
2090+
20662091
@register_patch("PN22 Local argmax for TP draft (vllm#39419 backport)")
20672092
def apply_patch_N22_local_argmax_tp() -> PatchResult:
20682093
"""Patch N22: backport of vllm#39419 (EanWang, OPEN).
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Wiring for Patch N21 — DFlash Sliding Window Attention support.
3+
4+
Partial backport of [vllm#40898](https://github.com/vllm-project/vllm/pull/40898)
5+
(jianc99, OPEN as of 2026-05-01). The upstream PR adds SWA support to the DFlash
6+
drafter codepath in three places (this module applies only items 2 and 3):
7+
8+
1. `qwen3_dflash.py` — layer_types tracking + per-layer sliding_window
9+
propagation through DFlashQwen3Attention / DFlashQwen3DecoderLayer +
10+
`sliding_attention_layer_names` set on root model.
11+
2. `speculators/algos.py` — preserve layer_types / use_sliding_window /
12+
sliding_window / max_window_layers from speculators-format checkpoint
13+
into HF config (without preservation they get dropped → all layers
14+
fall through to full attention → broken acceptance).
15+
3. `dflash.py` — force `causal=True` per-layer attention metadata for
16+
sliding-window layers (windowed FlashAttention requires causal=True).
17+
18+
================================================================
19+
WHY THIS IS NEEDED
20+
================================================================
21+
22+
In Qwen3.5-122B-A10B-DFlash and Qwen3.6-35B-A3B-DFlash, ~50% of the
23+
draft transformer layers are `sliding_attention` (window=2048). Without
24+
this fix:
25+
- `layer_types` from speculators config is silently dropped during
26+
HF-config extraction (only target_hidden_size + draft_vocab_size kept)
27+
- Drafter constructs all layers as full attention (NOT windowed)
28+
- Drafter "sees" full context, target sees windowed → distribution
29+
mismatch → target rejects more drafts → acceptance length collapses
30+
31+
After the full upstream fix (all three files, including qwen3_dflash.py):
32+
- layer_types preserved through config pipeline
33+
- Drafter constructs SWA layers with proper sliding_window
34+
- Drafter context matches target's view → distribution consistent →
35+
acceptance length 5.14 → 6.45 (+25%) per PR author measurement
36+
- For Genesis: this is the unblocker for 35B-A3B-DFlash >80K context
37+
(currently OOM at 200K because draft KV grows unbounded without SWA)
38+
39+
================================================================
40+
COMPOSITION WITH PN24
41+
================================================================
42+
43+
Genesis PN24 (vllm#40727 backport) already adds `+1` shift to layer_ids
44+
in `gpu_model_runner._get_eagle3_aux_layers_from_config`. Upstream PR
45+
#40898 ALSO modifies that same function — adds `is_dflash` gate around
46+
the existing logic. The two edits target the same code region.
47+
48+
P-N21 strategy: it does NOT touch `gpu_model_runner.py`. PN24's `+1`
49+
shift is sufficient for our use case. P-N21 patches ONLY two files
50+
(algos, dflash); qwen3_dflash is NOT patched. Both patches coexist cleanly.
51+
52+
If user enables P-N21 alone (without PN24): the `+1` shift is missing
53+
and layer_ids point to wrong layers. P-N21 dispatcher metadata
54+
declares `requires_patches=["PN24"]` to enforce the pairing.
55+
56+
================================================================
57+
EMPIRICAL FINDING (2026-05-01, v7.65 dev)
58+
================================================================
59+
60+
Validated on 35B-A3B-FP8-DFlash 160K, 7-city tool-call sweep:
61+
62+
- Baseline (PN21 OFF, PN22+PN23+PN24 ON): 7/7 tool-call clean
63+
- With PN21 ON (partial — algos.py + dflash.py only): 5-6/7 (3-run avg)
64+
65+
Regression matches the partial-backport caveat: when config preserves
66+
SWA but the model class doesn't construct windowed attention, the
67+
draft worker has metadata claiming SWA while computing full attention.
68+
That divergence shifts spec acceptance for tool-call tokens.
69+
70+
Decision: PN21 stays SHIPPED (file + dispatcher + apply_all entry)
71+
but DEFAULT OFF and NOT enabled in any launch script. Full enabler
72+
requires either upstream merge (vllm#40898) or manual qwen3_dflash.py
73+
edits (7+ sub-patches; high anchor-drift risk for text-patch).
74+
75+
================================================================
76+
SAFETY MODEL
77+
================================================================
78+
79+
- env: `GENESIS_ENABLE_PN21_DFLASH_SWA=1`
80+
- default OFF; opt-in.
81+
- empirical regression on 35B (5-6/7 vs 7/7 baseline) → DO NOT enable
82+
in production launch scripts until model class also patched.
83+
- Idempotent (one marker check per patched file: algos.py and dflash.py).
84+
- Apply order: algos.py first (config preservation), then dflash.py
85+
(proposer metadata); qwen3_dflash.py (model class) is not patched.
86+
- Each file is independent TextPatcher — failure on one logs but does
87+
not block others (best-effort for SWA support).
88+
- Auto-no-op once vllm#40898 merges (drift markers).
89+
90+
Author: backport for Genesis from jianc99's vllm#40898.
91+
"""
92+
from __future__ import annotations
93+
94+
import logging
95+
import os
96+
97+
from vllm._genesis.guards import resolve_vllm_file, vllm_install_root
98+
from vllm._genesis.wiring.text_patch import (
99+
TextPatch,
100+
TextPatcher,
101+
TextPatchResult,
102+
result_to_wiring_status,
103+
)
104+
105+
log = logging.getLogger("genesis.wiring.pn21_dflash_swa")
106+
107+
GENESIS_PN21_MARKER = "Genesis PN21 DFlash SWA support v7.65"
108+
109+
110+
# ─── Sub-patch: speculators/algos.py — preserve SWA config ─────────
111+
PN21_ALGOS_ANCHOR = (
112+
" aux_layer_ids = config_dict[\"aux_hidden_state_layer_ids\"]\n"
113+
" pre_trained_config[\"eagle_aux_hidden_state_layer_ids\"] = aux_layer_ids\n"
114+
)
115+
116+
PN21_ALGOS_REPLACEMENT = (
117+
" # [Genesis PN21] vllm#40898 backport — preserve SWA config\n"
118+
" for _genesis_pn21_key in (\n"
119+
" \"layer_types\",\n"
120+
" \"use_sliding_window\",\n"
121+
" \"sliding_window\",\n"
122+
" \"max_window_layers\",\n"
123+
" ):\n"
124+
" if _genesis_pn21_key in config_dict:\n"
125+
" pre_trained_config[_genesis_pn21_key] = config_dict[_genesis_pn21_key]\n"
126+
"\n"
127+
" aux_layer_ids = config_dict[\"aux_hidden_state_layer_ids\"]\n"
128+
" pre_trained_config[\"eagle_aux_hidden_state_layer_ids\"] = aux_layer_ids\n"
129+
)
130+
131+
132+
# ─── Sub-patch: dflash.py — causal=True for SWA layers ─────────────
133+
PN21_DFLASH_ANCHOR = (
134+
" per_group, per_layer = super().build_per_group_and_layer_attn_metadata(\n"
135+
" cad, draft_index\n"
136+
" )\n"
137+
" for layer_name, attn_metadata in per_layer.items():\n"
138+
" assert getattr(attn_metadata, \"causal\", None) is False, (\n"
139+
" f\"Attention metadata for layer {layer_name} does not have\"\n"
140+
" \" non-causal support, which is required for DFlash.\"\n"
141+
" \" Consider using a different attention backend, such as FlashAttention.\"\n"
142+
" )\n"
143+
" return per_group, per_layer\n"
144+
)
145+
146+
PN21_DFLASH_REPLACEMENT = (
147+
" per_group, per_layer = super().build_per_group_and_layer_attn_metadata(\n"
148+
" cad, draft_index\n"
149+
" )\n"
150+
" # [Genesis PN21] vllm#40898 backport — SWA layers need causal=True\n"
151+
" _genesis_pn21_sliding = getattr(self.model, \"sliding_attention_layer_names\", set())\n"
152+
" if _genesis_pn21_sliding:\n"
153+
" _genesis_pn21_causal_cad = cad.replace(causal=True)\n"
154+
" for _genesis_pn21_grp in self.draft_attn_groups:\n"
155+
" _genesis_pn21_causal_layers = _genesis_pn21_sliding & set(_genesis_pn21_grp.layer_names)\n"
156+
" if not _genesis_pn21_causal_layers:\n"
157+
" continue\n"
158+
" _genesis_pn21_meta = _genesis_pn21_grp.get_metadata_builder().build_for_drafting(\n"
159+
" common_attn_metadata=_genesis_pn21_causal_cad, draft_index=draft_index\n"
160+
" )\n"
161+
" for _genesis_pn21_ln in _genesis_pn21_causal_layers:\n"
162+
" per_layer[_genesis_pn21_ln] = _genesis_pn21_meta\n"
163+
" for layer_name, attn_metadata in per_layer.items():\n"
164+
" if layer_name in _genesis_pn21_sliding:\n"
165+
" assert getattr(attn_metadata, \"causal\", None) is True, (\n"
166+
" f\"Attention metadata for sliding layer {layer_name} does not have\"\n"
167+
" \" causal support, which is required for DFlash SWA.\"\n"
168+
" )\n"
169+
" continue\n"
170+
" assert getattr(attn_metadata, \"causal\", None) is False, (\n"
171+
" f\"Attention metadata for layer {layer_name} does not have\"\n"
172+
" \" non-causal support, which is required for DFlash.\"\n"
173+
" \" Consider using a different attention backend, such as FlashAttention.\"\n"
174+
" )\n"
175+
" return per_group, per_layer\n"
176+
)
177+
178+
179+
def _apply_algos() -> tuple[str, str | None]:
    """Text-patch speculators/algos.py so the SWA config keys are preserved.

    Returns:
        A ``(status, detail)`` pair. ``status`` is one of ``"applied"``,
        ``"skipped"`` or ``"failed"``; ``detail`` is ``None`` on a fresh
        apply and a short explanation otherwise.
    """
    not_found = "speculators/algos.py not found"
    algos_file = resolve_vllm_file("transformers_utils/configs/speculators/algos.py")
    if algos_file is None:
        return "skipped", not_found
    algos_path = str(algos_file)
    if not os.path.isfile(algos_path):
        return "skipped", not_found

    swa_preserve = TextPatch(
        name="pn21_algos_swa_preserve",
        anchor=PN21_ALGOS_ANCHOR,
        replacement=PN21_ALGOS_REPLACEMENT,
        required=True,
    )
    drift_markers = [
        "[Genesis PN21]",
        "_genesis_pn21_key",
        # Upstream merge — these keys appear directly
        "use_sliding_window",
    ]
    outcome, failure = TextPatcher(
        patch_name="PN21 algos.py — preserve SWA config (vllm#40898)",
        target_file=algos_path,
        marker=GENESIS_PN21_MARKER + " (algos)",
        sub_patches=[swa_preserve],
        upstream_drift_markers=drift_markers,
    ).apply()

    # Translate the patcher outcome into the wiring status vocabulary.
    # IDEMPOTENT is handled explicitly so a second run reports "skipped"
    # instead of falling through to "failed".
    if outcome == TextPatchResult.APPLIED:
        status, detail = "applied", None
    elif outcome == TextPatchResult.IDEMPOTENT:
        status, detail = "skipped", "already applied (marker present)"
    elif outcome == TextPatchResult.SKIPPED:
        status, detail = "skipped", (failure.reason if failure else "already applied")
    else:
        status, detail = "failed", (failure.detail if failure else "unknown")
    return status, detail
212+
213+
214+
def _apply_dflash() -> tuple[str, str | None]:
    """Text-patch v1/spec_decode/dflash.py to force causal=True on SWA layers.

    Returns:
        A ``(status, detail)`` pair. ``status`` is one of ``"applied"``,
        ``"skipped"`` or ``"failed"``; ``detail`` is ``None`` on a fresh
        apply and a short explanation otherwise.
    """
    dflash_file = resolve_vllm_file("v1/spec_decode/dflash.py")
    file_missing = dflash_file is None or not os.path.isfile(str(dflash_file))
    if file_missing:
        return "skipped", "v1/spec_decode/dflash.py not found"

    causal_patch = TextPatch(
        name="pn21_dflash_swa_causal",
        anchor=PN21_DFLASH_ANCHOR,
        replacement=PN21_DFLASH_REPLACEMENT,
        required=True,
    )
    patcher_kwargs = {
        "patch_name": "PN21 dflash.py — SWA causal metadata (vllm#40898)",
        "target_file": str(dflash_file),
        "marker": GENESIS_PN21_MARKER + " (dflash)",
        "sub_patches": [causal_patch],
        "upstream_drift_markers": [
            "[Genesis PN21]",
            "sliding_attention_layer_names",
        ],
    }
    outcome, failure = TextPatcher(**patcher_kwargs).apply()

    # Map the patcher outcome onto the wiring status vocabulary; anything
    # that is not APPLIED / IDEMPOTENT / SKIPPED is reported as a failure.
    if outcome == TextPatchResult.APPLIED:
        return "applied", None
    if outcome == TextPatchResult.IDEMPOTENT:
        return "skipped", "already applied (marker present)"
    if outcome != TextPatchResult.SKIPPED:
        return "failed", failure.detail if failure else "unknown"
    return "skipped", failure.reason if failure else "already applied"
245+
246+
247+
def apply() -> tuple[str, str]:
    """Apply PN21 — the partial DFlash SWA backport (algos.py + dflash.py only).

    Flow:
      1. Ask the dispatcher whether PN21 should run and log the decision;
         return ``("skipped", reason)`` when it should not.
      2. Return ``"skipped"`` when the vllm install root cannot be found.
      3. Run the algos.py sub-patcher, then the dflash.py sub-patcher,
         logging each outcome. One failing does not stop the other.
      4. Aggregate: any failure -> ``"failed"``; nothing newly applied ->
         ``"skipped"``; otherwise ``"applied"``.

    The qwen3_dflash.py model class changes from vllm#40898 are NOT applied
    here (7+ multi-line sub-patches, high anchor-drift risk). As a result
    this patch preserves the SWA config and adjusts attention metadata, but
    the model class itself does not construct sliding-window attention
    layers, so windowed compute does not happen. It is a config/metadata
    preparation step, not a full SWA enabler.

    Returns:
        A ``(status, message)`` pair describing the aggregate outcome.
    """
    from vllm._genesis.dispatcher import should_apply, log_decision

    decision, reason = should_apply("PN21")
    log_decision("PN21", decision, reason)
    if not decision:
        return "skipped", reason

    if vllm_install_root() is None:
        return "skipped", "vllm install root not discoverable"

    # Bucket the sub-patch names by their reported status, in apply order.
    buckets = {"applied": [], "skipped": [], "failed": []}
    sub_patchers = (("algos", _apply_algos), ("dflash", _apply_dflash))
    for file_tag, patch_fn in sub_patchers:
        status, detail = patch_fn()
        log.info("[PN21:%s] %s%s", file_tag, status,
                 f" — {detail}" if detail else "")
        buckets.setdefault(status, []).append(file_tag)

    done = buckets["applied"]
    passed_over = buckets["skipped"]
    broken = buckets["failed"]

    if broken:
        return "failed", (
            f"PN21 partial: applied={done}, skipped={passed_over}, failed={broken}"
        )
    if done:
        return "applied", (
            f"PN21 applied {done} (DFlash SWA partial — algos.py preserves "
            f"layer_types/sliding_window config + dflash.py forces causal=True "
            f"on SWA layers). Skipped: {passed_over}. Note: full SWA enabler in "
            f"qwen3_dflash.py model class deferred — wait for vllm#40898 merge "
            f"or apply manually. Composes with PN24."
        )
    return "skipped", (
        f"PN21 nothing to apply (already applied or anchors absent): {passed_over}"
    )
304+
305+
306+
def is_applied() -> bool:
    """Report whether the PN21 marker text is present in speculators/algos.py.

    Only algos.py is inspected; dflash.py is not checked by this function.

    Returns:
        True when the file can be read and contains ``GENESIS_PN21_MARKER``;
        False when the file cannot be resolved, opened, or decoded.
    """
    target = resolve_vllm_file("transformers_utils/configs/speculators/algos.py")
    if target is None:
        return False
    try:
        # Decode explicitly as UTF-8: the PN21 replacement text inserts a
        # non-ASCII em dash into this file, so a locale-default decode can
        # raise UnicodeDecodeError (a ValueError, not an OSError).
        with open(str(target), encoding="utf-8") as f:
            return GENESIS_PN21_MARKER in f.read()
    except (OSError, UnicodeDecodeError):
        return False

0 commit comments

Comments
 (0)