mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[precompile] Pass tensor_to_context to backend. (#165702)
Summary: Fixes a vLLM issue (https://github.com/vllm-project/vllm/issues/27040) where AOT precompile fails on some models that use symbolic shapes in Inductor. Test Plan: pp HF_HUB_DISABLE_XET=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_AOT_COMPILE=1 vllm bench latency --model microsoft/DialoGPT-small --input-len 128 --output-len 256 --num-iters 50 --dtype float16 Reviewers: Subscribers: Tasks: Tags: Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/165702 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
8cb2fb44f2
commit
86ebce1766
@ -247,8 +247,10 @@ def aot_compile_fullgraph(
|
||||
assert backend_input is not None
|
||||
backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment]
|
||||
device_type = _graph_device_type(backend_input.graph_module.graph)
|
||||
tracing_context = TracingContext(backend_input.fake_mode)
|
||||
tracing_context.tensor_to_context = backend_input.tensor_to_context
|
||||
with (
|
||||
torch._guards.tracing(TracingContext(backend_input.fake_mode)),
|
||||
torch._guards.tracing(tracing_context),
|
||||
torch._functorch.config.patch(
|
||||
{
|
||||
"bundled_autograd_cache": True,
|
||||
|
@ -176,6 +176,8 @@ except ModuleNotFoundError:
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from torch.utils.weak import WeakIdKeyDictionary
|
||||
|
||||
from .backends.registry import CompilerFn
|
||||
from .package import CompilePackage
|
||||
from .repro.after_dynamo import WrapBackendDebug
|
||||
@ -909,6 +911,7 @@ class BackendInput:
|
||||
graph_module: torch.fx.GraphModule
|
||||
example_inputs: Any
|
||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
|
||||
tensor_to_context: WeakIdKeyDictionary
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1080,11 +1083,13 @@ def _fullgraph_capture_frame(
|
||||
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
|
||||
) -> torch.fx.GraphModule:
|
||||
nonlocal backend_input
|
||||
fake_mode = TracingContext.get().fake_mode
|
||||
tracing_context = TracingContext.get()
|
||||
fake_mode = tracing_context.fake_mode
|
||||
tensor_to_context = tracing_context.tensor_to_context
|
||||
assert fake_mode is not None
|
||||
assert isinstance(gm.meta["backend_id"], str)
|
||||
backend_input = BackendInput(
|
||||
gm.meta["backend_id"], gm, example_inputs, fake_mode
|
||||
gm.meta["backend_id"], gm, example_inputs, fake_mode, tensor_to_context
|
||||
)
|
||||
return gm
|
||||
|
||||
|
Reference in New Issue
Block a user