[precompile] Pass tensor_to_context to backend. (#165702)

Summary:

Fixes a vLLM issue (https://github.com/vllm-project/vllm/issues/27040) where
AOT precompile fails on some models that use symbolic shapes in Inductor.
At capture time, TracingContext records a tensor_to_context mapping from
input tensors to their SymbolicContext; previously the backend compiled under
a fresh TracingContext whose mapping was empty, so that symbolic-shape
information was lost. This change threads the mapping through BackendInput
and installs it on the TracingContext used for backend compilation.
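
To illustrate the mechanism, here is a minimal sketch (not the PR's code
verbatim; run_backend and compile_fn are hypothetical names) of carrying the
capture-time mapping into the context the backend compiles under:

    import torch
    from torch._guards import TracingContext, tracing

    def run_backend(backend_input, compile_fn):
        # Build the context the backend compiles under.
        ctx = TracingContext(backend_input.fake_mode)
        # Without the next line the backend sees an empty mapping and loses
        # the SymbolicContext recorded for each captured input tensor.
        ctx.tensor_to_context = backend_input.tensor_to_context
        with tracing(ctx):
            return compile_fn(
                backend_input.graph_module, backend_input.example_inputs
            )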

Test Plan:
HF_HUB_DISABLE_XET=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_AOT_COMPILE=1 vllm bench latency --model microsoft/DialoGPT-small --input-len 128 --output-len 256 --num-iters 50 --dtype float16

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165702
Approved by: https://github.com/tugsbayasgalan
Author: Zhengxu Chen
Date: 2025-10-17 21:52:01 +00:00
Committed by: PyTorch MergeBot
Parent: 8cb2fb44f2
Commit: 86ebce1766

2 changed files with 10 additions and 3 deletions

@@ -247,8 +247,10 @@ def aot_compile_fullgraph(
     assert backend_input is not None
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     device_type = _graph_device_type(backend_input.graph_module.graph)
+    tracing_context = TracingContext(backend_input.fake_mode)
+    tracing_context.tensor_to_context = backend_input.tensor_to_context
     with (
-        torch._guards.tracing(TracingContext(backend_input.fake_mode)),
+        torch._guards.tracing(tracing_context),
         torch._functorch.config.patch(
             {
                 "bundled_autograd_cache": True,

@@ -176,6 +176,8 @@ except ModuleNotFoundError:

 if typing.TYPE_CHECKING:
+    from torch.utils.weak import WeakIdKeyDictionary
+
     from .backends.registry import CompilerFn
     from .package import CompilePackage
     from .repro.after_dynamo import WrapBackendDebug

@@ -909,6 +911,7 @@ class BackendInput:
     graph_module: torch.fx.GraphModule
     example_inputs: Any
     fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
+    tensor_to_context: WeakIdKeyDictionary


 @dataclass
@@ -1080,11 +1083,13 @@ def _fullgraph_capture_frame(
        gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
    ) -> torch.fx.GraphModule:
        nonlocal backend_input
-       fake_mode = TracingContext.get().fake_mode
+       tracing_context = TracingContext.get()
+       fake_mode = tracing_context.fake_mode
+       tensor_to_context = tracing_context.tensor_to_context
        assert fake_mode is not None
        assert isinstance(gm.meta["backend_id"], str)
        backend_input = BackendInput(
-           gm.meta["backend_id"], gm, example_inputs, fake_mode
+           gm.meta["backend_id"], gm, example_inputs, fake_mode, tensor_to_context
        )
        return gm
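
As a design note, tensor_to_context is a WeakIdKeyDictionary: keys are
compared by object identity and held only weakly, so the mapping does not
keep captured tensors alive. A small self-contained sketch of that behavior
(illustrative, not from the PR):

    import torch
    from torch.utils.weak import WeakIdKeyDictionary

    d = WeakIdKeyDictionary()
    t = torch.randn(2)
    d[t] = "per-tensor context"  # keyed by the tensor's identity
    assert len(d) == 1
    del t  # drop the last strong reference to the key
    # On CPython, refcounting collects the tensor right away and the
    # weak-key callback removes the stale entry.
    assert len(d) == 0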