|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +from collections import OrderedDict |
3 | 4 | from collections.abc import Collection, Iterable |
4 | 5 | from typing import Literal |
5 | 6 |
|
@@ -198,3 +199,106 @@ def take_events(self) -> Iterable[OffloadingEvent]: |
198 | 199 | if self.events is not None: |
199 | 200 | yield from self.events |
200 | 201 | self.events.clear() |
| 202 | + |
| 203 | + |
| 204 | +class FilterReusedOffloadingManager(OffloadingManager): |
| 205 | + """An :class:`OffloadingManager` decorator that skips storing blocks |
| 206 | + whose reuse frequency is below *store_threshold*. |
| 207 | +
|
| 208 | + All methods are delegated to the *backing* manager. Two methods are |
| 209 | + intercepted: |
| 210 | +
|
| 211 | + * ``prepare_store`` — filters out keys that have not yet |
| 212 | + * ``lookup`` — records the visited key in an internal LRU |
| 213 | + counter, then delegates to the backing manager. |
| 214 | + crossed the threshold *before* calling the backing |
| 215 | + ``prepare_store``. |
| 216 | +
|
| 217 | + Args: |
| 218 | + backing: The underlying ``OffloadingManager`` to delegate to. |
| 219 | + store_threshold: A block must be seen at least this many times in |
| 220 | + ``lookup()`` before it is eligible for offloading. Must be >= 2 |
| 221 | + (a value of 1 would be equivalent to no filtering). |
| 222 | + max_tracker_size: Maximum entries in the internal tracker's LRU table. |
| 223 | + """ |
| 224 | + |
| 225 | + def __init__( |
| 226 | + self, |
| 227 | + backing: OffloadingManager, |
| 228 | + store_threshold: int = 2, |
| 229 | + max_tracker_size: int = 64_000, |
| 230 | + ): |
| 231 | + if store_threshold < 2: |
| 232 | + raise ValueError( |
| 233 | + "FilterReusedOffloadingManager store_threshold must be >= 2, " |
| 234 | + f"got {store_threshold}" |
| 235 | + ) |
| 236 | + if max_tracker_size < 1: |
| 237 | + raise ValueError( |
| 238 | + "FilterReusedOffloadingManager max_tracker_size must be >= 1, " |
| 239 | + f"got {max_tracker_size}" |
| 240 | + ) |
| 241 | + self._backing = backing |
| 242 | + self.store_threshold = store_threshold |
| 243 | + self.max_tracker_size = max_tracker_size |
| 244 | + # Ordered so we can evict the LRU entry in O(1). |
| 245 | + self.counts: OrderedDict[OffloadKey, int] = OrderedDict() |
| 246 | + |
| 247 | + # ------------------------------------------------------------------ |
| 248 | + # Intercepted methods |
| 249 | + # ------------------------------------------------------------------ |
| 250 | + |
| 251 | + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: |
| 252 | + """Record the key, then delegate lookup to backing manager.""" |
| 253 | + if key in self.counts: |
| 254 | + self.counts.move_to_end(key) |
| 255 | + self.counts[key] += 1 |
| 256 | + else: |
| 257 | + if len(self.counts) >= self.max_tracker_size: |
| 258 | + self.counts.popitem(last=False) # evict LRU |
| 259 | + self.counts[key] = 1 |
| 260 | + return self._backing.lookup(key, req_context) |
| 261 | + |
| 262 | + def prepare_store( |
| 263 | + self, keys: Collection[OffloadKey], req_context: ReqContext |
| 264 | + ) -> PrepareStoreOutput | None: |
| 265 | + """Filter out blocks below threshold, then delegate to backing. |
| 266 | +
|
| 267 | + Filtering is evaluated *before* calling the backing manager's |
| 268 | + ``prepare_store`` so that blocks that would be skipped do not |
| 269 | + consume any CPU offload capacity. |
| 270 | + """ |
| 271 | + eligible = [ |
| 272 | + key for key in keys if self.counts.get(key, 0) >= self.store_threshold |
| 273 | + ] |
| 274 | + |
| 275 | + # Passing an empty list is intentional and safe — CPUOffloadingManager |
| 276 | + # handles it correctly, returning a PrepareStoreOutput with empty lists. |
| 277 | + # Delegate to the backing manager with only the eligible keys. |
| 278 | + return self._backing.prepare_store(eligible, req_context) |
| 279 | + |
| 280 | + # ------------------------------------------------------------------ |
| 281 | + # Delegated methods |
| 282 | + # ------------------------------------------------------------------ |
| 283 | + |
| 284 | + def prepare_load( |
| 285 | + self, keys: Collection[OffloadKey], req_context: ReqContext |
| 286 | + ) -> LoadStoreSpec: |
| 287 | + return self._backing.prepare_load(keys, req_context) |
| 288 | + |
| 289 | + def touch(self, keys: Collection[OffloadKey]) -> None: |
| 290 | + return self._backing.touch(keys) |
| 291 | + |
| 292 | + def complete_load(self, keys: Collection[OffloadKey]) -> None: |
| 293 | + return self._backing.complete_load(keys) |
| 294 | + |
| 295 | + def complete_store( |
| 296 | + self, keys: Collection[OffloadKey], success: bool = True |
| 297 | + ) -> None: |
| 298 | + return self._backing.complete_store(keys, success) |
| 299 | + |
| 300 | + def take_events(self) -> Iterable[OffloadingEvent]: |
| 301 | + return self._backing.take_events() |
| 302 | + |
| 303 | + def request_finished(self, req_id: str) -> bool: |
| 304 | + return self._backing.request_finished(req_id) |
0 commit comments