Files
pytorch/torch/_inductor/compile_fx_async.py
James Wu fe954cdcbf Use correct boxed_forward_device_index when running CompiledFxGraph.post_compile (#148130)
This PR threads through the correct boxed_forward_device_index from graph_kwargs to CompiledFXGraph.post_compile. This allows us to correctly update BoxedDeviceIndex from cache hits.

We don't actually need to save `boxed_forward_device_index` in CompiledFXGraph because its value is part of the cache key, so it always matches the ambient one anyway. On forward with cudagraphs enabled, derive `boxed_forward_device_index`'s value from `device_idxs`.

Testing:

```
python benchmarks/dynamo/cachebench.py --mode training --benchmark torchbench --model BERT_pytorch --device cuda --repeat 1 --dynamic --output="dynamic.json"
```

Now the cache hits properly on FXGraphCache. AOTAutogradCache still has a guard failure; I will look into that as a follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148130
Approved by: https://github.com/eellison
2025-03-23 02:57:58 +00:00

182 lines
6.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing_extensions import final, override
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor.output_code import CompiledFxGraphConstants, OutputCode
from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile
from .output_code import complex_memory_overlap as complex_memory_overlap # noqa: F401
if TYPE_CHECKING:
from collections.abc import Sequence
from concurrent.futures import Future
from torch._inductor.utils import InputType
from torch.fx import GraphModule
from .compile_fx_ext import _OutOfProcessFxCompile, _WireProtocolPickledOutput
@dataclass
class _PostCompileData:
example_inputs: Sequence[InputType]
constants: CompiledFxGraphConstants
graph_kwargs: _CompileFxKwargs
# _AsyncOutputCode handles the actual management of waiting for an
# out-of-process compile to finish and then switching over to it.
@final
class _AsyncOutputCode(OutputCode):
    """OutputCode that runs an interim forward until a future-backed one is ready.

    Every call first checks whether the future has completed. Once it has,
    the result is turned into an ``OutputCode`` via the callback, any deferred
    ``post_compile()`` is replayed on it, and all later calls are routed to it.
    """

    # Interim callable; set to None once we have switched to `_output_code`.
    _eager_forward: Optional[Callable[..., Any]]
    # The OutputCode built from the future's result; None until the switch.
    _output_code: Optional[OutputCode]
    # Pending background result; set to None as soon as it has been consumed.
    _future: Optional[Future[_WireProtocolPickledOutput]]
    # Converts the future's result into an OutputCode.
    _callback: Callable[[_WireProtocolPickledOutput], OutputCode]
    # post_compile() arguments received before the switch, replayed afterwards.
    _post_compile_data: Optional[_PostCompileData] = None
    _boxed_call: bool  # Copied from the forward/output_code

    def __init__(
        self,
        # eager_forward is run until the future is finished.
        eager_forward: Callable[..., Any],
        # this responds with the result of the out-of-process compile when it's
        # ready.
        future: Future[_WireProtocolPickledOutput],
        # this callback gets called to turn the _WireProtocolPickledOutput into
        # an OutputCode
        callback: Callable[[_WireProtocolPickledOutput], OutputCode],
    ) -> None:
        self._eager_forward = eager_forward
        self._boxed_call = getattr(eager_forward, "_boxed_call", False)
        self._output_code = None
        self._future = future
        self._callback = callback

    @override
    def __call__(self, *args: Any) -> Any:
        """Run the interim forward, or the final OutputCode once it is ready.

        If the future has completed, the switch happens first and ``args`` are
        converted to the calling convention of the new OutputCode.
        """
        if self._future is not None and self._future.done():
            args = self._switch_to_compiled_forward(args)

        # Compare against None explicitly (as post_compile() does) so that a
        # callable which happens to be falsy is not mistaken for "switched".
        if (eager_forward := self._eager_forward) is not None:
            _AsyncFxCompile._stat_eager_runs += 1
            return eager_forward(*args)

        _AsyncFxCompile._stat_compiled_runs += 1
        assert self._output_code is not None
        return self._output_code.__call__(*args)

    # Takes and returns the args (converted to the "right" boxed mode)
    def _switch_to_compiled_forward(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
        """Consume the finished future and make its OutputCode the active one.

        Args:
            args: The positional arguments of the in-flight ``__call__``, in the
                calling convention currently recorded in ``_boxed_call``.

        Returns:
            ``args`` re-packed for the new OutputCode's ``_boxed_call`` setting.
        """
        assert self._future is not None
        # TODO: If the future ended in an exception do we want to continue
        # running eager or hit the exception now?
        # As written: the exception propagates out of this call. The future has
        # already been cleared and _eager_forward is untouched, so subsequent
        # calls keep running the interim forward.
        f, self._future = self._future, None
        output_code = self._callback(f.result())

        # Replay a post_compile() that arrived while we were still waiting.
        if (pcd := self._post_compile_data) is not None:
            self._post_compile_data = None
            output_code.post_compile(
                pcd.example_inputs, pcd.constants, pcd.graph_kwargs
            )

        self._output_code = output_code
        self._eager_forward = None

        boxed_call = getattr(output_code, "_boxed_call", False)
        if self._boxed_call != boxed_call:
            if self._boxed_call:
                # Was boxed, now unboxed: spread the single boxed container out
                # into positional arguments (as a tuple, per the return type).
                args = tuple(args[0]) if args else ()
            else:
                # Was unboxed, now boxed: the boxed convention hands the callee
                # one *list* which it is allowed to clear, so never pass the
                # immutable args tuple itself.
                args = (list(args),)
        self._boxed_call = boxed_call
        return args

    @override
    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        constants: CompiledFxGraphConstants,
        graph_kwargs: _CompileFxKwargs,
    ) -> None:
        """Forward post_compile() to the final OutputCode, deferring if needed.

        While the interim forward is still active the arguments are only
        recorded; they are replayed by ``_switch_to_compiled_forward()``.
        """
        if self._eager_forward is not None:
            self._post_compile_data = _PostCompileData(
                example_inputs, constants, graph_kwargs
            )
        else:
            assert self._output_code is not None
            self._output_code.post_compile(example_inputs, constants, graph_kwargs)
# _AsyncFxCompile wraps an FxCompile that does out-of-process compiles: it runs
# "eager" until the compiled artifact is ready and then automatically switches
# over to using the compiled version.
@final
class _AsyncFxCompile(FxCompile):
    """FxCompile that kicks off a background compile and returns immediately.

    The returned OutputCode starts out backed by the ``_InProcessFxCompile``
    result and is swapped to the out-of-process artifact once that is ready.
    """

    _compile: _OutOfProcessFxCompile

    # Debugging counters (class-wide):
    _stat_bg_started: int = 0  # background compiles started
    _stat_bg_finished: int = 0  # background compiles finished
    _stat_eager_runs: int = 0  # calls served by the "eager" forward
    _stat_compiled_runs: int = 0  # calls served by the out-of-process artifact

    def __init__(self, compile: _OutOfProcessFxCompile) -> None:
        self._compile = compile

    @classmethod
    def _reset_stats(cls) -> None:
        """Zero all of the debugging counters."""
        cls._stat_bg_started = cls._stat_bg_finished = 0
        cls._stat_eager_runs = cls._stat_compiled_runs = 0

    @override
    def codegen_and_compile(
        self,
        gm: GraphModule,
        example_inputs: Sequence[InputType],
        inputs_to_check: Sequence[int],
        graph_kwargs: _CompileFxKwargs,
    ) -> OutputCode:
        """Return an OutputCode that upgrades itself when the child compile lands.

        If the request cannot be serialized for the out-of-process compiler,
        the in-process OutputCode is returned unchanged.
        """
        fallback = _InProcessFxCompile().codegen_and_compile(
            gm, example_inputs, inputs_to_check, graph_kwargs
        )

        # Mirrors _SerializedFxCompile.codegen_and_compile(), except that the
        # result is routed asynchronously.
        packed = self._compile.serialize_compile(
            gm, example_inputs, inputs_to_check, graph_kwargs
        )
        if not packed:
            # Not serializable - the in-process OutputCode is all we have.
            return fallback

        wire_inputs, constants = packed

        _AsyncFxCompile._stat_bg_started += 1
        pending = self._compile._send_to_child_async(wire_inputs)

        # Invoked by _switch_to_compiled_forward() once `pending` has a result.
        def on_ready(pickled_output: _WireProtocolPickledOutput) -> OutputCode:
            _AsyncFxCompile._stat_bg_finished += 1
            result = pickled_output.deserialize(constants)
            self._compile._postprocess(result)
            return result.graph

        return _AsyncOutputCode(fallback, pending, on_ready)