[Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142)

commit a6e72e1e4f
parent 5e83a7277f
Author:    James Wu <jjwu@meta.com>
Committer: GitHub
Date:      2025-04-25 23:28:20 -04:00

Signed-off-by: James Wu <jjwu@meta.com>

@@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface):
         hash_str, file_path = None, None
         from torch._inductor.codecache import (FxGraphCache,
                                                compiled_fx_graph_hash)
         if torch.__version__.startswith("2.5"):
             original_load = FxGraphCache.load
             original_load_name = "torch._inductor.codecache.FxGraphCache.load"
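
The pre-existing code above gates on the release string, while the lines added in the next hunk gate on feature detection instead. A minimal illustration of the difference, assuming a torch build where the private autograd-cache module used by this commit is importable (it is on the 2.7+/main builds the patch targets):

    import torch

    # Brittle: only matches releases whose version string starts with
    # "2.5", and cannot see features that land on main before a
    # version bump.
    uses_legacy_load_hook = torch.__version__.startswith("2.5")

    # Robust: probe for the hook itself. The private import path below
    # is the one this commit uses; being private, it can move between
    # torch releases.
    try:
        from torch._functorch._aot_autograd.autograd_cache import (
            AOTAutogradCache)
        has_aot_shape_env_hook = hasattr(AOTAutogradCache, "_get_shape_env")
    except ImportError:
        has_aot_shape_env_hook = False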
@@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface):
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       _get_shape_env))
+            from torch._functorch._aot_autograd.autograd_cache import (
+                AOTAutogradCache)
+
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        _get_shape_env))
+
             # for forcing the graph to be cached
             stack.enter_context(
                 patch(
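
The block added above follows a standard stacking pattern: every unittest.mock.patch context is entered through one shared ExitStack so all the patches unwind together, and the optional target is guarded with hasattr. A self-contained sketch of that pattern; FakeFxCache, FakeAOTCache, and always_hit_env are hypothetical stand-ins, not vLLM or PyTorch names:

    from contextlib import ExitStack
    from unittest.mock import patch


    class FakeFxCache:
        @staticmethod
        def _get_shape_env():
            return "real shape env"


    class FakeAOTCache:
        # Mimics torch 2.8+, where AOTAutogradCache also exposes the
        # hook; on older torch this attribute would simply be absent.
        @staticmethod
        def _get_shape_env():
            return "real shape env"


    def always_hit_env(*args, **kwargs):
        return "always-hit shape env"


    with ExitStack() as stack:
        stack.enter_context(
            patch.object(FakeFxCache, "_get_shape_env", always_hit_env))
        # The hasattr() guard makes the second patch a no-op on
        # versions that predate the hook, as in the hunk above.
        if hasattr(FakeAOTCache, "_get_shape_env"):
            stack.enter_context(
                patch.object(FakeAOTCache, "_get_shape_env", always_hit_env))
        assert FakeFxCache._get_shape_env() == "always-hit shape env"
        assert FakeAOTCache._get_shape_env() == "always-hit shape env"

    # Leaving the ExitStack restores both original methods.
    assert FakeFxCache._get_shape_env() == "real shape env"

The sketch uses patch.object where the diff uses the dotted-string form; both perform the same attribute replacement, and the string form is only required when the target must be named by import path.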
@@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface):
         assert isinstance(handle[1], str)
         hash_str = handle[0]
+        from torch._functorch._aot_autograd.autograd_cache import (
+            AOTAutogradCache)
         from torch._inductor.codecache import FxGraphCache
         with ExitStack() as exit_stack:
             exit_stack.enter_context(
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       lambda *args, **kwargs: AlwaysHitShapeEnv()))
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                exit_stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        lambda *args, **kwargs: AlwaysHitShapeEnv()))
             # Dynamo metrics context, see method for more details.
             exit_stack.enter_context(self.metrics_context())
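
Both hunks install lambda *args, **kwargs: AlwaysHitShapeEnv() as the _get_shape_env hook, and the "for forcing the graph to be cached" comment above states the intent: present a shape environment whose guards always pass, so neither cache can miss on a shape-guard check. A hedged sketch of such a stand-in follows; the method set is an assumption about the minimal surface a cache lookup touches, not a copy of vLLM's actual AlwaysHitShapeEnv:

    from typing import Any


    class AlwaysHitShapeEnvSketch:
        """Stand-in shape env whose guard checks always succeed, so a
        cached FX graph is treated as valid for any input shapes."""

        def __init__(self) -> None:
            self.guards: list[Any] = []  # no symbolic-shape guards recorded

        def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> bool:
            return True  # every lookup-time guard check passes

        def produce_guards_expression(self, *args: Any, **kwargs: Any) -> str:
            return ""  # nothing to re-verify when an entry is written

Returning a fresh instance of such an object from both FxGraphCache._get_shape_env and AOTAutogradCache._get_shape_env is what lets the adaptor force hits on its own pre-compiled artifacts across the torch versions this commit covers.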