@@ -249,3 +249,110 @@ def prepare(self, topk_ids: torch.Tensor) -> ExpertWeightResult:
249249 w1_scale = self ._buf_w13_scale ,
250250 w2_scale = self ._buf_w2_scale ,
251251 )
252+
253+
class LFRUCachedWeightProvider(CachedWeightProvider):
    """GPU LFRU (Least Frequently + Recently Used) cache for MoE experts.

    Extends CachedWeightProvider with frequency-weighted eviction.
    Standard LRU lets early layers monopolize the cache because they
    execute first every forward pass. LFRU tracks an access-frequency
    counter per expert and evicts the cached expert with the lowest
    score, where ``score = frequency / (1 + steps_since_last_access)``.

    Experts requested by the current ``prepare()`` call are never chosen
    as eviction victims: their slots are referenced by the remapped ids
    returned from that same call, so overwriting them would make two
    different experts resolve to the same weights.

    On GPT-OSS-20B benchmarks, LFRU improved deep-layer (18-23) hit rate
    from 0-8% (LRU) to 52-94%. With 128 experts per layer (Gemma 4,
    Nemotron), the improvement is expected to be even larger.

    Reference: vllm-project/vllm#37190 (e1n00r)
    """

    def __init__(self, *args, decay: float = 0.95, **kwargs) -> None:
        """Initialise the base cache and the LFRU bookkeeping.

        Args:
            *args: Forwarded unchanged to ``CachedWeightProvider``.
            decay: Multiplier applied to an expert's previous frequency
                counter each time that expert is accessed, before 1.0 is
                added for the new access.
            **kwargs: Forwarded unchanged to ``CachedWeightProvider``.
        """
        super().__init__(*args, **kwargs)
        # Frequency counter per expert; decayed on each access to that
        # expert (not on every step).
        self._freq: dict[int, float] = {}
        # Monotonic counter of prepare() calls, used for recency scoring.
        self._step: int = 0
        # Value of ``_step`` at each expert's most recent access.
        self._last_access: dict[int, int] = {}
        # Decay factor: controls how fast old frequency is forgotten.
        self._decay = decay

    def _score(self, expert_id: int) -> float:
        """Return the retention score of ``expert_id``.

        Lower means more likely to be evicted. Experts that were never
        accessed score 0.0.
        """
        freq = self._freq.get(expert_id, 0.0)
        age = self._step - self._last_access.get(expert_id, 0)
        # High frequency and small age both push the score up (keep).
        return freq / (1.0 + age)

    def _pick_lfru_victim(self, pinned: set[int]) -> int:
        """Return the cached expert id to evict.

        Args:
            pinned: Expert ids required by the current ``prepare()`` call;
                these are excluded from the candidates.

        Returns:
            The non-pinned cached expert with the lowest score; ties go to
            the entry that comes first in ``_lru`` iteration order.
        """
        candidates = [eid for eid in self._lru if eid not in pinned]
        if not candidates:
            # Not expected while the batch is truncated to ``capacity`` and
            # the cache is full; keep the original plain-LRU fallback
            # rather than failing.
            return next(iter(self._lru))
        return min(candidates, key=self._score)

    @torch.compiler.disable
    def prepare(self, topk_ids: torch.Tensor) -> ExpertWeightResult:
        """Make the experts referenced by ``topk_ids`` resident and remap ids.

        Args:
            topk_ids: Tensor of expert ids selected by the router.

        Returns:
            An ``ExpertWeightResult`` holding the cache buffers and
            ``topk_ids`` translated through ``_mapping`` to cache slots,
            in the dtype of the input.
        """
        self._step += 1

        unique_ids = topk_ids.unique().tolist()
        if len(unique_ids) > self.capacity:
            if not self._overflow_warned:
                logger.warning(
                    "LFRUCachedWeightProvider.prepare() called with %d unique "
                    "experts but capacity is only %d. Truncating to last %d.",
                    len(unique_ids), self.capacity, self.capacity,
                )
                self._overflow_warned = True
            unique_ids = unique_ids[-self.capacity:]

        # Everything this call needs must stay resident until it returns.
        pinned = set(unique_ids)

        for expert_id in unique_ids:
            # Update frequency (decay the old count, then add this access).
            self._freq[expert_id] = (
                self._freq.get(expert_id, 0.0) * self._decay + 1.0
            )
            self._last_access[expert_id] = self._step

            if expert_id in self._lru:
                # Cache hit: refresh the LRU order used by the fallback.
                self._lru.move_to_end(expert_id)
                self.hits += 1
                continue

            # Cache miss: take a free slot, otherwise evict by LFRU score.
            if self._free_slots:
                slot = self._free_slots.pop()
            else:
                slot = self._lru.pop(self._pick_lfru_victim(pinned))

            # Copy the expert's weights into the chosen slot.
            self._buf_w13[slot].copy_(self._cpu_w13[expert_id])
            self._buf_w2[slot].copy_(self._cpu_w2[expert_id])
            if self._buf_w13_scale is not None:
                assert self._cpu_w13_scale is not None
                assert self._cpu_w2_scale is not None
                assert self._buf_w2_scale is not None
                self._buf_w13_scale[slot].copy_(self._cpu_w13_scale[expert_id])
                self._buf_w2_scale[slot].copy_(self._cpu_w2_scale[expert_id])

            self._lru[expert_id] = slot
            self._mapping[expert_id] = slot
            self.misses += 1

        # Emit hit-rate statistics at most once every 60 seconds.
        now = time.monotonic()
        if now - self._last_log_time >= 60.0:
            self._last_log_time = now
            total = self.hits + self.misses
            if total > 0:
                logger.debug(
                    "Expert LFRU cache: %d hits, %d misses (%.1f%% hit rate)",
                    self.hits, self.misses, 100.0 * self.hits / total,
                )

        remapped_ids = self._mapping[topk_ids.long()].to(dtype=topk_ids.dtype)

        return ExpertWeightResult(
            w1=self._buf_w13,
            w2=self._buf_w2,
            topk_ids=remapped_ids,
            w1_scale=self._buf_w13_scale,
            w2_scale=self._buf_w2_scale,
        )
0 commit comments