From a6e72e1e4fb450c80f15e09b9f09d5754635724e Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 25 Apr 2025 23:28:20 -0400 Subject: [PATCH] [Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142) Signed-off-by: James Wu --- vllm/compilation/compiler_interface.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 833be28926..bc9e421a66 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -195,7 +195,6 @@ class InductorAdaptor(CompilerInterface): hash_str, file_path = None, None from torch._inductor.codecache import (FxGraphCache, compiled_fx_graph_hash) - if torch.__version__.startswith("2.5"): original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" @@ -280,6 +279,16 @@ class InductorAdaptor(CompilerInterface): patch("torch._inductor.codecache.FxGraphCache._get_shape_env", _get_shape_env)) + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) + + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + _get_shape_env)) + # for forcing the graph to be cached stack.enter_context( patch( @@ -325,11 +334,19 @@ class InductorAdaptor(CompilerInterface): assert isinstance(handle[1], str) hash_str = handle[0] + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache) from torch._inductor.codecache import FxGraphCache with ExitStack() as exit_stack: exit_stack.enter_context( patch("torch._inductor.codecache.FxGraphCache._get_shape_env", lambda *args, **kwargs: AlwaysHitShapeEnv())) + # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache + if hasattr(AOTAutogradCache, "_get_shape_env"): + exit_stack.enter_context( + patch( + "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv())) # Dynamo metrics context, see method for more details. exit_stack.enter_context(self.metrics_context())