[REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877)

I am going to break apart the arguments passed to the constituent steps
so that each one receives exactly what it needs, so having easy access
to the internals is helpful here. A rough sketch of that direction
follows below.
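
As a hedged illustration of the intended shape (GraphSketch, realign_inputs, and load_graph are simplified stand-ins, not the real torch._inductor types or the actual diff below): once the catch-all method is inlined at its sole call site, each constituent step can be handed only the fields it needs.

```python
# Sketch only; names are simplified stand-ins for the real PyTorch types.
from typing import Any, Callable, List, Sequence


class GraphSketch:
    def __init__(
        self, fn: Callable[[List[Any]], Any], inputs_to_check: Sequence[int]
    ) -> None:
        self.current_callable = fn
        self.inputs_to_check = tuple(inputs_to_check)


def realign_inputs(
    fn: Callable[[List[Any]], Any], idxs: Sequence[int]
) -> Callable[[List[Any]], Any]:
    # A stand-in constituent step: it needs only a callable and the index
    # list, not the whole graph object.
    def wrapper(args: List[Any]) -> Any:
        _ = [args[i] for i in idxs]  # placeholder for the real stride checks
        return fn(args)

    return wrapper


# Before (sketch): Cache.post_compile(graph, example_inputs, cudagraphs)
# took whole objects and fanned them out internally. After inlining at the
# sole call site, each step receives exactly the fields it needs:
def load_graph(graph: GraphSketch) -> GraphSketch:
    graph.current_callable = realign_inputs(
        graph.current_callable, graph.inputs_to_check
    )
    return graph


print(load_graph(GraphSketch(lambda xs: sum(xs), (0,))).current_callable([1, 2, 3]))
```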

This also moves two helper functions to output_code.py.

Additionally, _boxed_call is now set in the constructor.
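
For context, here is a minimal sketch of the boxed-call convention that flag opts into; CompiledFn and invoke are illustrative stand-ins, not inductor's real classes. A boxed callable takes a single list of arguments, and setting the flag at construction means call sites no longer patch it onto the callable after the fact.

```python
# Sketch only; CompiledFn and invoke are hypothetical stand-ins.
from typing import Any, Callable, List


class CompiledFn:
    def __init__(self, fn: Callable[[List[Any]], Any]) -> None:
        self.fn = fn
        # Set once at construction, instead of callers doing
        # `compiled._boxed_call = True` after the fact.
        self._boxed_call = True

    def __call__(self, args: List[Any]) -> Any:
        return self.fn(args)


def invoke(fn: Any, *args: Any) -> Any:
    # Callers dispatch on the flag: boxed callables receive one list.
    if getattr(fn, "_boxed_call", False):
        return fn(list(args))
    return fn(*args)


print(invoke(CompiledFn(lambda xs: sum(xs)), 1, 2, 3))  # prints 6
```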

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141877
Approved by: https://github.com/jamesjwu, https://github.com/jansel
Author: Edward Z. Yang
Date: 2024-12-02 07:46:42 -08:00
Committed by: PyTorch MergeBot
Parent: fe68f61c59
Commit: 61534391ba
3 changed files with 160 additions and 172 deletions


@@ -105,7 +105,6 @@ from torch._inductor.cpp_builder import (
    normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
    _module_to_triton_kernel,
    _reload_python_module,
@@ -114,12 +113,9 @@ from torch._inductor.runtime.compile_tasks import (
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
    ALIGN_BYTES,
    align_inputs_from_check_idxs,
    BoxedBool,
    clear_on_fresh_inductor_cache,
    is_linux,
    is_windows,
    set_tracing_context_output_strides,
)
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
@@ -913,117 +909,6 @@ def compiled_fx_graph_hash(
    return key, debug_lines

def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs it on compiled_graph.
    Mutates `compiled_graph.current_callable` and `cudagraphs`.
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    inputs_to_check = compiled_graph.inputs_to_check
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]
        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces

        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )
    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if cudagraph'd the forward and set the device, we need to let the
        # cudagraph manager know we are running the backward even if we will
        # not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )

def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
    inputs_to_check: Sequence[int],
) -> None:
    """
    Realigns input strides from inputs_to_check if
    we didn't end up running cudagraphs. Mutates
    `compiled_graph.current_callable` if cudagraphs
    was not run; otherwise does nothing.
    """
    if not ran_cudagraphs:
        assert compiled_graph.current_callable is not None
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable


def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
    """
    Ephemerally increases the NCCL timeout when compiling for a distributed job
@@ -1271,54 +1156,6 @@ class FxGraphCache:
        )
        return graph, cache_info

    @staticmethod
    def post_compile(
        compiled_graph: CompiledFxGraph,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        gm: Optional[torch.fx.GraphModule] = None,
    ) -> CompiledFxGraph:
        """
        Run a set of post processing steps after loading from the cache. These involve:
         - Setting the tracing context output strides
         - Running cudagraphs if enabled
         - Realigning inputs

        This runs whether or not we have a cache hit, and always runs
        directly after we get a CompiledFxGraph. The results of this
        function are *not* saved in the cache itself.
        """
        set_tracing_context_output_strides(example_inputs, compiled_graph)

        if cudagraphs:
            # It's possible that cudagraphs is enabled, but was disabled
            # during a previous compilation we're loading from the cache.
            # If so, we need to disable it on this new process too.
            if compiled_graph.disabled_cudagraphs_reason:
                if "cuda" in compiled_graph.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                cudagraph_post_compile(
                    example_inputs,
                    compiled_graph,
                    cudagraphs,
                    gm,
                )
        inputs_to_check = compiled_graph.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
        maybe_realign_inputs(
            cudagraphs,
            compiled_graph,
            inputs_to_check,
        )
        return compiled_graph

    @staticmethod
    def _save_graph(
        key: str,