@@ -195,7 +195,6 @@ def compile(
         hash_str, file_path = None, None
         from torch._inductor.codecache import (FxGraphCache,
                                                compiled_fx_graph_hash)
-
         if torch.__version__.startswith("2.5"):
             original_load = FxGraphCache.load
             original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -280,6 +279,14 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       _get_shape_env))
 
+            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        _get_shape_env))
+
             # for forcing the graph to be cached
             stack.enter_context(
                 patch(
@@ -326,10 +333,17 @@ def load(self,
         hash_str = handle[0]
 
         from torch._inductor.codecache import FxGraphCache
+        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
         with ExitStack() as exit_stack:
             exit_stack.enter_context(
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       lambda *args, **kwargs: AlwaysHitShapeEnv()))
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                exit_stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        lambda *args, **kwargs: AlwaysHitShapeEnv()))
 
             # Dynamo metrics context, see method for more details.
             exit_stack.enter_context(self.metrics_context())
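The pattern this diff relies on in both hunks is entering `unittest.mock.patch` through an `ExitStack` only when a `hasattr` check confirms the target attribute exists, so the same code tolerates torch builds with and without `AOTAutogradCache._get_shape_env`. A minimal, self-contained sketch of that pattern, where `Cache` and `lookup` are hypothetical stand-ins rather than vLLM or torch APIs:

```python
from contextlib import ExitStack
from unittest.mock import patch


class Cache:
    """Hypothetical stand-in for FxGraphCache / AOTAutogradCache."""

    @staticmethod
    def _get_shape_env():
        return "real shape env"


def lookup():
    # Resolves through the class attribute, so it observes the patch.
    return Cache._get_shape_env()


with ExitStack() as stack:
    # Guard with hasattr so the code also runs against versions that
    # do not define the attribute; nothing is entered in that case.
    if hasattr(Cache, "_get_shape_env"):
        stack.enter_context(
            patch(f"{__name__}.Cache._get_shape_env",
                  lambda *args, **kwargs: "always-hit shape env"))
    print(lookup())  # -> always-hit shape env

print(lookup())  # -> real shape env (ExitStack reverted the patch)
```

Because `ExitStack` unwinds every context it entered, the patch is reverted automatically at the end of the `with` block, and on versions where the guard skips the patch there is simply nothing to unwind.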