Skip to content

Commit 497e9af

Browse files
committed
Patch AOTAutogradCache._get_shape_env
Signed-off-by: James Wu <[email protected]>
1 parent 6d0df0e commit 497e9af

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

vllm/compilation/compiler_interface.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def compile(
195195
hash_str, file_path = None, None
196196
from torch._inductor.codecache import (FxGraphCache,
197197
compiled_fx_graph_hash)
198-
199198
if torch.__version__.startswith("2.5"):
200199
original_load = FxGraphCache.load
201200
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -280,6 +279,14 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280279
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
281280
_get_shape_env))
282281

282+
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
283+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
284+
if hasattr(AOTAutogradCache, "_get_shape_env"):
285+
stack.enter_context(
286+
patch(
287+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
288+
_get_shape_env))
289+
283290
# for forcing the graph to be cached
284291
stack.enter_context(
285292
patch(
@@ -326,10 +333,17 @@ def load(self,
326333
hash_str = handle[0]
327334

328335
from torch._inductor.codecache import FxGraphCache
336+
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
329337
with ExitStack() as exit_stack:
330338
exit_stack.enter_context(
331339
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
332340
lambda *args, **kwargs: AlwaysHitShapeEnv()))
341+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
342+
if hasattr(AOTAutogradCache, "_get_shape_env"):
343+
exit_stack.enter_context(
344+
patch(
345+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
346+
lambda *args, **kwargs: AlwaysHitShapeEnv()))
333347

334348
# Dynamo metrics context, see method for more details.
335349
exit_stack.enter_context(self.metrics_context())

0 commit comments

Comments
 (0)