Skip to content

Commit f746d07

Browse files
caio vicentinoclaude
authored andcommitted
Add LFRU (frequency-weighted LRU) expert cache eviction policy
Standard LRU lets early layers monopolize the cache because they execute first every forward pass. LFRU tracks per-expert access frequency (decayed) and evicts the expert with lowest score = freq / (1 + recency). On GPT-OSS-20B: deep-layer hit rate improved from 0-8% to 52-94%. Critical for models with 128 experts/layer (Gemma 4, Nemotron). LFRUCachedWeightProvider is a drop-in replacement for CachedWeightProvider. Ref: vllm-project#37190 (e1n00r LFRU findings) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d4eb8d5 commit f746d07

1 file changed

Lines changed: 107 additions & 0 deletions

File tree

vllm/model_executor/layers/fused_moe/expert_weight_provider.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,110 @@ def prepare(self, topk_ids: torch.Tensor) -> ExpertWeightResult:
249249
w1_scale=self._buf_w13_scale,
250250
w2_scale=self._buf_w2_scale,
251251
)
252+
253+
254+
class LFRUCachedWeightProvider(CachedWeightProvider):
255+
"""GPU LFRU (Least Frequently + Recently Used) cache for MoE experts.
256+
257+
Extends CachedWeightProvider with frequency-weighted eviction.
258+
Standard LRU lets early layers monopolize the cache because they
259+
execute first every forward pass. LFRU tracks access frequency per
260+
expert and evicts the one with lowest score = frequency * recency.
261+
262+
On GPT-OSS-20B benchmarks, LFRU improved deep-layer (18-23) hit rate
263+
from 0-8% (LRU) to 52-94%. With 128 experts per layer (Gemma 4,
264+
Nemotron), the improvement is expected to be even larger.
265+
266+
Reference: vllm-project/vllm#37190 (e1n00r)
267+
"""
268+
269+
def __init__(self, *args, decay: float = 0.95, **kwargs) -> None:
270+
super().__init__(*args, **kwargs)
271+
# Frequency counter per expert (decayed over time)
272+
self._freq: dict[int, float] = {}
273+
# Monotonic step counter for recency scoring
274+
self._step: int = 0
275+
# Last access step per expert
276+
self._last_access: dict[int, int] = {}
277+
# Decay factor: controls how fast old frequency decays
278+
self._decay = decay
279+
280+
def _score(self, expert_id: int) -> float:
281+
"""Compute eviction score: lower = more likely to evict."""
282+
freq = self._freq.get(expert_id, 0.0)
283+
recency = self._step - self._last_access.get(expert_id, 0)
284+
# Combine frequency and recency: high freq + recent = high score (keep)
285+
return freq / (1.0 + recency)
286+
287+
@torch.compiler.disable
288+
def prepare(self, topk_ids: torch.Tensor) -> ExpertWeightResult:
289+
self._step += 1
290+
291+
unique_ids = topk_ids.unique().tolist()
292+
if len(unique_ids) > self.capacity:
293+
if not self._overflow_warned:
294+
logger.warning(
295+
"LFRUCachedWeightProvider.prepare() called with %d unique "
296+
"experts but capacity is only %d. Truncating to last %d.",
297+
len(unique_ids), self.capacity, self.capacity,
298+
)
299+
self._overflow_warned = True
300+
unique_ids = unique_ids[-self.capacity:]
301+
302+
for expert_id in unique_ids:
303+
# Update frequency (decayed)
304+
self._freq[expert_id] = self._freq.get(expert_id, 0.0) * self._decay + 1.0
305+
self._last_access[expert_id] = self._step
306+
307+
if expert_id in self._lru:
308+
# Cache hit
309+
self._lru.move_to_end(expert_id)
310+
self.hits += 1
311+
else:
312+
# Cache miss
313+
if self._free_slots:
314+
slot = self._free_slots.pop()
315+
else:
316+
# Evict expert with lowest LFRU score
317+
min_score = float("inf")
318+
victim_id = next(iter(self._lru)) # fallback to LRU
319+
for eid in self._lru:
320+
s = self._score(eid)
321+
if s < min_score:
322+
min_score = s
323+
victim_id = eid
324+
slot = self._lru.pop(victim_id)
325+
326+
# Copy from CPU to GPU
327+
self._buf_w13[slot].copy_(self._cpu_w13[expert_id])
328+
self._buf_w2[slot].copy_(self._cpu_w2[expert_id])
329+
if self._buf_w13_scale is not None:
330+
assert self._cpu_w13_scale is not None
331+
assert self._cpu_w2_scale is not None
332+
assert self._buf_w2_scale is not None
333+
self._buf_w13_scale[slot].copy_(self._cpu_w13_scale[expert_id])
334+
self._buf_w2_scale[slot].copy_(self._cpu_w2_scale[expert_id])
335+
336+
self._lru[expert_id] = slot
337+
self._mapping[expert_id] = slot
338+
self.misses += 1
339+
340+
now = time.monotonic()
341+
if now - self._last_log_time >= 60.0:
342+
self._last_log_time = now
343+
total = self.hits + self.misses
344+
if total > 0:
345+
logger.debug(
346+
"Expert LFRU cache: %d hits, %d misses (%.1f%% hit rate)",
347+
self.hits, self.misses, 100.0 * self.hits / total,
348+
)
349+
350+
remapped_ids = self._mapping[topk_ids.long()].to(dtype=topk_ids.dtype)
351+
352+
return ExpertWeightResult(
353+
w1=self._buf_w13,
354+
w2=self._buf_w2,
355+
topk_ids=remapped_ids,
356+
w1_scale=self._buf_w13_scale,
357+
w2_scale=self._buf_w2_scale,
358+
)

0 commit comments

Comments
 (0)