|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +"""Wiring for Patch N21 — DFlash Sliding Window Attention support. |
| 3 | +
|
| 4 | +Backport of [vllm#40898](https://github.com/vllm-project/vllm/pull/40898) |
| 5 | +(jianc99, OPEN as of 2026-05-01). Adds proper SWA support to the DFlash |
| 6 | +drafter codepath: |
| 7 | +
|
| 8 | +1. `qwen3_dflash.py` — layer_types tracking + per-layer sliding_window |
| 9 | + propagation through DFlashQwen3Attention / DFlashQwen3DecoderLayer + |
| 10 | + `sliding_attention_layer_names` set on root model. |
| 11 | +2. `speculators/algos.py` — preserve layer_types / use_sliding_window / |
| 12 | + sliding_window / max_window_layers from speculators-format checkpoint |
| 13 | + into HF config (without preservation they get dropped → all layers |
| 14 | + fall through to full attention → broken acceptance). |
| 15 | +3. `dflash.py` — force `causal=True` per-layer attention metadata for |
| 16 | + sliding-window layers (windowed FlashAttention requires causal=True). |
| 17 | +
|
| 18 | +================================================================ |
| 19 | +WHY THIS IS NEEDED |
| 20 | +================================================================ |
| 21 | +
|
| 22 | +In Qwen3.5-122B-A10B-DFlash and Qwen3.6-35B-A3B-DFlash, ~50% of the |
| 23 | +draft transformer layers are `sliding_attention` (window=2048). Without |
| 24 | +this fix: |
| 25 | +- `layer_types` from speculators config is silently dropped during |
| 26 | + HF-config extraction (only target_hidden_size + draft_vocab_size kept) |
| 27 | +- Drafter constructs all layers as full attention (NOT windowed) |
| 28 | +- Drafter "sees" full context, target sees windowed → distribution |
| 29 | + mismatch → target rejects more drafts → acceptance length collapses |
| 30 | +
|
| 31 | +After this fix: |
| 32 | +- layer_types preserved through config pipeline |
| 33 | +- Drafter constructs SWA layers with proper sliding_window |
| 34 | +- Drafter context matches target's view → distribution consistent → |
| 35 | + acceptance length 5.14 → 6.45 (+25%) per PR author measurement |
| 36 | +- For Genesis: this is the unblocker for 35B-A3B-DFlash >80K context |
| 37 | + (currently OOM at 200K because draft KV grows unbounded without SWA) |
| 38 | +
|
| 39 | +================================================================ |
| 40 | +COMPOSITION WITH PN24 |
| 41 | +================================================================ |
| 42 | +
|
| 43 | +Genesis PN24 (vllm#40727 backport) already adds `+1` shift to layer_ids |
| 44 | +in `gpu_model_runner._get_eagle3_aux_layers_from_config`. Upstream PR |
| 45 | +#40898 ALSO modifies that same function — adds `is_dflash` gate around |
| 46 | +the existing logic. The two edits target the same code region. |
| 47 | +
|
| 48 | +P-N21 strategy: it does NOT touch `gpu_model_runner.py`. PN24's `+1` |
| 49 | +shift is sufficient for our use case. P-N21 patches ONLY algos.py and |
| 50 | +dflash.py (qwen3_dflash.py is deferred). Both patches coexist cleanly. |
| 51 | +
|
| 52 | +If user enables P-N21 alone (without PN24): the `+1` shift is missing |
| 53 | +and layer_ids point to wrong layers. P-N21 dispatcher metadata |
| 54 | +declares `requires_patches=["PN24"]` to enforce the pairing. |
| 55 | +
|
| 56 | +================================================================ |
| 57 | +EMPIRICAL FINDING (2026-05-01, v7.65 dev) |
| 58 | +================================================================ |
| 59 | +
|
| 60 | +Validated on 35B-A3B-FP8-DFlash 160K, 7-city tool-call sweep: |
| 61 | +
|
| 62 | +- Baseline (PN21 OFF, PN22+PN23+PN24 ON): 7/7 tool-call clean |
| 63 | +- With PN21 ON (partial — algos.py + dflash.py only): 5-6/7 (3-run avg) |
| 64 | +
|
| 65 | +Regression matches the partial-backport caveat: when config preserves |
| 66 | +SWA but the model class doesn't construct windowed attention, the |
| 67 | +draft worker has metadata claiming SWA while computing full attention. |
| 68 | +That divergence shifts spec acceptance for tool-call tokens. |
| 69 | +
|
| 70 | +Decision: PN21 stays SHIPPED (file + dispatcher + apply_all entry) |
| 71 | +but DEFAULT OFF and NOT enabled in any launch script. Full enabler |
| 72 | +requires either upstream merge (vllm#40898) or manual qwen3_dflash.py |
| 73 | +edits (7+ sub-patches; high anchor-drift risk for text-patch). |
| 74 | +
|
| 75 | +================================================================ |
| 76 | +SAFETY MODEL |
| 77 | +================================================================ |
| 78 | +
|
| 79 | +- env: `GENESIS_ENABLE_PN21_DFLASH_SWA=1` |
| 80 | +- default OFF; opt-in. |
| 81 | +- empirical regression on 35B (5-6/7 vs 7/7 baseline) → DO NOT enable |
| 82 | + in production launch scripts until model class also patched. |
| 83 | +- Idempotent (one marker per patched file, plus upstream drift markers). |
| 84 | +- Apply order: algos.py first (config preservation), then dflash.py |
| 85 | + (proposer metadata); qwen3_dflash.py (model class) is not patched yet. |
| 86 | +- Each file is independent TextPatcher — failure on one logs but does |
| 87 | + not block others (best-effort for SWA support). |
| 88 | +- Auto-no-op once vllm#40898 merges (drift markers). |
| 89 | +
|
| 90 | +Author: backport for Genesis from jianc99's vllm#40898. |
| 91 | +""" |
| 92 | +from __future__ import annotations |
| 93 | + |
| 94 | +import logging |
| 95 | +import os |
| 96 | + |
| 97 | +from vllm._genesis.guards import resolve_vllm_file, vllm_install_root |
| 98 | +from vllm._genesis.wiring.text_patch import ( |
| 99 | + TextPatch, |
| 100 | + TextPatcher, |
| 101 | + TextPatchResult, |
| 102 | + result_to_wiring_status, |
| 103 | +) |
| 104 | + |
# Module logger; the name is a fixed string rather than derived from __name__.
log = logging.getLogger("genesis.wiring.pn21_dflash_swa")

# Base marker text. The per-file patchers in this module append " (algos)" /
# " (dflash)" to it, and is_applied() searches algos.py for this base string.
GENESIS_PN21_MARKER = "Genesis PN21 DFlash SWA support v7.65"
| 108 | + |
| 109 | + |
# ─── Sub-patch: speculators/algos.py — preserve SWA config ─────────
# Anchor handed to TextPatch for algos.py: the two lines that copy
# aux_hidden_state_layer_ids into pre_trained_config.
# NOTE(review): presumably matched verbatim against the target file, so the
# leading whitespace inside these literals must equal the file's real
# indentation — confirm against the installed vLLM source.
PN21_ALGOS_ANCHOR = (
    " aux_layer_ids = config_dict[\"aux_hidden_state_layer_ids\"]\n"
    " pre_trained_config[\"eagle_aux_hidden_state_layer_ids\"] = aux_layer_ids\n"
)

# Replacement text: a loop that copies the four SWA-related keys from
# config_dict into pre_trained_config when they are present, followed by the
# anchor lines themselves so the original statements are kept.
PN21_ALGOS_REPLACEMENT = (
    " # [Genesis PN21] vllm#40898 backport — preserve SWA config\n"
    " for _genesis_pn21_key in (\n"
    " \"layer_types\",\n"
    " \"use_sliding_window\",\n"
    " \"sliding_window\",\n"
    " \"max_window_layers\",\n"
    " ):\n"
    " if _genesis_pn21_key in config_dict:\n"
    " pre_trained_config[_genesis_pn21_key] = config_dict[_genesis_pn21_key]\n"
    "\n"
    " aux_layer_ids = config_dict[\"aux_hidden_state_layer_ids\"]\n"
    " pre_trained_config[\"eagle_aux_hidden_state_layer_ids\"] = aux_layer_ids\n"
)
| 130 | + |
| 131 | + |
# ─── Sub-patch: dflash.py — causal=True for SWA layers ─────────────
# Anchor handed to TextPatch for dflash.py: the super() call that builds the
# per-group / per-layer attention metadata, the assertion that every layer's
# metadata has causal=False, and the final return.
# NOTE(review): presumably matched verbatim against the target file, so the
# leading whitespace inside these literals must equal the file's real
# indentation — confirm against the installed vLLM source.
PN21_DFLASH_ANCHOR = (
    " per_group, per_layer = super().build_per_group_and_layer_attn_metadata(\n"
    " cad, draft_index\n"
    " )\n"
    " for layer_name, attn_metadata in per_layer.items():\n"
    " assert getattr(attn_metadata, \"causal\", None) is False, (\n"
    " f\"Attention metadata for layer {layer_name} does not have\"\n"
    " \" non-causal support, which is required for DFlash.\"\n"
    " \" Consider using a different attention backend, such as FlashAttention.\"\n"
    " )\n"
    " return per_group, per_layer\n"
)

# Replacement text: the same code with two changes.
#   1. For every draft attention group that contains a layer listed in
#      self.model.sliding_attention_layer_names, metadata is rebuilt from a
#      causal=True copy of `cad` and stored for those layers.
#   2. The assertion is split: sliding layers must report causal=True, all
#      other layers must still report causal=False.
# When the model exposes no such attribute, the getattr default (an empty set)
# makes the inserted step a no-op and the original assertion applies to all
# layers.
PN21_DFLASH_REPLACEMENT = (
    " per_group, per_layer = super().build_per_group_and_layer_attn_metadata(\n"
    " cad, draft_index\n"
    " )\n"
    " # [Genesis PN21] vllm#40898 backport — SWA layers need causal=True\n"
    " _genesis_pn21_sliding = getattr(self.model, \"sliding_attention_layer_names\", set())\n"
    " if _genesis_pn21_sliding:\n"
    " _genesis_pn21_causal_cad = cad.replace(causal=True)\n"
    " for _genesis_pn21_grp in self.draft_attn_groups:\n"
    " _genesis_pn21_causal_layers = _genesis_pn21_sliding & set(_genesis_pn21_grp.layer_names)\n"
    " if not _genesis_pn21_causal_layers:\n"
    " continue\n"
    " _genesis_pn21_meta = _genesis_pn21_grp.get_metadata_builder().build_for_drafting(\n"
    " common_attn_metadata=_genesis_pn21_causal_cad, draft_index=draft_index\n"
    " )\n"
    " for _genesis_pn21_ln in _genesis_pn21_causal_layers:\n"
    " per_layer[_genesis_pn21_ln] = _genesis_pn21_meta\n"
    " for layer_name, attn_metadata in per_layer.items():\n"
    " if layer_name in _genesis_pn21_sliding:\n"
    " assert getattr(attn_metadata, \"causal\", None) is True, (\n"
    " f\"Attention metadata for sliding layer {layer_name} does not have\"\n"
    " \" causal support, which is required for DFlash SWA.\"\n"
    " )\n"
    " continue\n"
    " assert getattr(attn_metadata, \"causal\", None) is False, (\n"
    " f\"Attention metadata for layer {layer_name} does not have\"\n"
    " \" non-causal support, which is required for DFlash.\"\n"
    " \" Consider using a different attention backend, such as FlashAttention.\"\n"
    " )\n"
    " return per_group, per_layer\n"
)
| 177 | + |
| 178 | + |
def _apply_algos() -> tuple[str, str | None]:
    """Patch speculators/algos.py so the SWA config keys are carried over.

    Returns:
        A ``(status, detail)`` pair. ``status`` is ``"applied"``,
        ``"skipped"`` or ``"failed"``. ``detail`` is ``None`` after a fresh
        apply; otherwise it is the explanatory text selected below.
    """
    algos_path = resolve_vllm_file("transformers_utils/configs/speculators/algos.py")
    file_present = algos_path is not None and os.path.isfile(str(algos_path))
    if not file_present:
        return "skipped", "speculators/algos.py not found"

    swa_preserve = TextPatch(
        name="pn21_algos_swa_preserve",
        anchor=PN21_ALGOS_ANCHOR,
        replacement=PN21_ALGOS_REPLACEMENT,
        required=True,
    )
    drift_markers = [
        "[Genesis PN21]",
        "_genesis_pn21_key",
        # Upstream merge — these keys appear directly
        "use_sliding_window",
    ]
    outcome, failure = TextPatcher(
        patch_name="PN21 algos.py — preserve SWA config (vllm#40898)",
        target_file=str(algos_path),
        marker=GENESIS_PN21_MARKER + " (algos)",
        sub_patches=[swa_preserve],
        upstream_drift_markers=drift_markers,
    ).apply()

    # NOTE(review): result_to_wiring_status is imported at the top of this
    # module but never called in the visible code; this hand-written mapping
    # may duplicate it — confirm before consolidating.
    if outcome == TextPatchResult.APPLIED:
        status, detail = "applied", None
    elif outcome == TextPatchResult.IDEMPOTENT:
        status, detail = "skipped", "already applied (marker present)"
    elif outcome == TextPatchResult.SKIPPED:
        status = "skipped"
        detail = failure.reason if failure else "already applied"
    else:
        # Any other result is reported as a failure.
        status = "failed"
        detail = failure.detail if failure else "unknown"
    return status, detail
| 212 | + |
| 213 | + |
def _apply_dflash() -> tuple[str, str | None]:
    """Patch v1/spec_decode/dflash.py with the SWA causal-metadata change.

    Returns:
        A ``(status, detail)`` pair. ``status`` is ``"applied"``,
        ``"skipped"`` or ``"failed"``. ``detail`` is ``None`` after a fresh
        apply; otherwise it is the explanatory text selected below.
    """
    dflash_path = resolve_vllm_file("v1/spec_decode/dflash.py")
    file_present = dflash_path is not None and os.path.isfile(str(dflash_path))
    if not file_present:
        return "skipped", "v1/spec_decode/dflash.py not found"

    swa_causal = TextPatch(
        name="pn21_dflash_swa_causal",
        anchor=PN21_DFLASH_ANCHOR,
        replacement=PN21_DFLASH_REPLACEMENT,
        required=True,
    )
    drift_markers = [
        "[Genesis PN21]",
        "sliding_attention_layer_names",
    ]
    outcome, failure = TextPatcher(
        patch_name="PN21 dflash.py — SWA causal metadata (vllm#40898)",
        target_file=str(dflash_path),
        marker=GENESIS_PN21_MARKER + " (dflash)",
        sub_patches=[swa_causal],
        upstream_drift_markers=drift_markers,
    ).apply()

    if outcome == TextPatchResult.APPLIED:
        status, detail = "applied", None
    elif outcome == TextPatchResult.IDEMPOTENT:
        status, detail = "skipped", "already applied (marker present)"
    elif outcome == TextPatchResult.SKIPPED:
        status = "skipped"
        detail = failure.reason if failure else "already applied"
    else:
        # Any other result is reported as a failure.
        status = "failed"
        detail = failure.detail if failure else "unknown"
    return status, detail
| 245 | + |
| 246 | + |
def apply() -> tuple[str, str]:
    """Apply the PN21 partial backport (algos.py and dflash.py only).

    The dispatcher is asked first (``should_apply("PN21")``) and its decision
    is logged; a negative decision or an undiscoverable vLLM install root
    ends the call with ``"skipped"``. Otherwise the two sub-patches run in
    order — algos.py, then dflash.py — and each outcome is logged.

    The qwen3_dflash.py model-class changes from vllm#40898 are deliberately
    not part of this patch: they would need 7+ multi-line sub-patches
    (Attention ``__init__`` signature and body, DecoderLayer ``__init__`` and
    body, Model class init and property), and the anchor-drift risk was
    judged too high. As a result this patch preserves the SWA config and
    adjusts the per-layer metadata code path, but the model class still does
    not construct sliding-window attention layers, so no windowed compute
    happens. The module header records a tool-call regression measured with
    this partial state and keeps the patch default OFF.

    Returns:
        ``(status, message)`` where ``status`` is ``"failed"`` if any
        sub-patch failed, ``"skipped"`` if none was freshly applied, and
        ``"applied"`` otherwise.
    """
    # Imported inside the function, exactly as the original does.
    from vllm._genesis.dispatcher import log_decision, should_apply

    decision, reason = should_apply("PN21")
    log_decision("PN21", decision, reason)
    if not decision:
        return "skipped", reason

    if vllm_install_root() is None:
        return "skipped", "vllm install root not discoverable"

    applied: list[str] = []
    skipped: list[str] = []
    failed: list[str] = []
    by_status = {"applied": applied, "skipped": skipped, "failed": failed}

    for name, fn in (("algos", _apply_algos), ("dflash", _apply_dflash)):
        status, detail = fn()
        suffix = f" — {detail}" if detail else ""
        log.info("[PN21:%s] %s%s", name, status, suffix)
        bucket = by_status.get(status)
        if bucket is not None:
            bucket.append(name)

    if failed:
        return "failed", (
            f"PN21 partial: applied={applied}, skipped={skipped}, failed={failed}"
        )
    if not applied:
        return "skipped", (
            f"PN21 nothing to apply (already applied or anchors absent): {skipped}"
        )
    return "applied", (
        f"PN21 applied {applied} (DFlash SWA partial — algos.py preserves "
        "layer_types/sliding_window config + dflash.py forces causal=True "
        f"on SWA layers). Skipped: {skipped}. Note: full SWA enabler in "
        "qwen3_dflash.py model class deferred — wait for vllm#40898 merge "
        "or apply manually. Composes with PN24."
    )
| 304 | + |
| 305 | + |
def is_applied() -> bool:
    """Report whether the PN21 marker is present in speculators/algos.py.

    Only algos.py is inspected; dflash.py is not checked here.

    Returns:
        True when the file resolves, can be read, and contains
        ``GENESIS_PN21_MARKER``; False otherwise.
    """
    target = resolve_vllm_file("transformers_utils/configs/speculators/algos.py")
    if target is None:
        return False
    try:
        # Decode explicitly as UTF-8: the PN21 replacement text contains a
        # non-ASCII dash, so the locale default encoding could raise
        # UnicodeDecodeError (not an OSError) on a patched file.
        # errors="replace" keeps the ASCII marker searchable even if some
        # other byte in the file is undecodable.
        with open(str(target), encoding="utf-8", errors="replace") as f:
            return GENESIS_PN21_MARKER in f.read()
    except OSError:
        return False
0 commit comments