mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877)
I am going to break apart the arguments passed to the constituents to only pass exactly what is needed, so easy access to the insides is helpful here. This also moves two helper functions to output_code.py as well. Also set _boxed_call at constructor. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/141877 Approved by: https://github.com/jamesjwu, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
fe68f61c59
commit
61534391ba
@ -105,7 +105,6 @@ from torch._inductor.cpp_builder import (
|
||||
normalize_path_separator,
|
||||
)
|
||||
from torch._inductor.cpu_vec_isa import pick_vec_isa
|
||||
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
|
||||
from torch._inductor.runtime.compile_tasks import (
|
||||
_module_to_triton_kernel,
|
||||
_reload_python_module,
|
||||
@ -114,12 +113,9 @@ from torch._inductor.runtime.compile_tasks import (
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
|
||||
from torch._inductor.utils import (
|
||||
ALIGN_BYTES,
|
||||
align_inputs_from_check_idxs,
|
||||
BoxedBool,
|
||||
clear_on_fresh_inductor_cache,
|
||||
is_linux,
|
||||
is_windows,
|
||||
set_tracing_context_output_strides,
|
||||
)
|
||||
from torch._logging import trace_structured
|
||||
from torch._subclasses.fake_tensor import (
|
||||
@ -913,117 +909,6 @@ def compiled_fx_graph_hash(
|
||||
return key, debug_lines
|
||||
|
||||
|
||||
def cudagraph_post_compile(
|
||||
example_inputs: Sequence[InputType],
|
||||
compiled_graph: CompiledFxGraph,
|
||||
cudagraphs: BoxedBool,
|
||||
gm: Optional[torch.fx.GraphModule],
|
||||
) -> None:
|
||||
"""
|
||||
Checks for any reasons not to run cudagraphs and then
|
||||
runs it on compiled_graph.
|
||||
Mutates the `compiled_graph.current_callable` and `cudagraphs`
|
||||
"""
|
||||
assert compiled_graph.current_callable is not None
|
||||
assert compiled_graph.cudagraph_info is not None
|
||||
cached_info = compiled_graph.cudagraph_info
|
||||
cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
|
||||
inputs_to_check = compiled_graph.inputs_to_check
|
||||
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
|
||||
is_inference = compiled_graph.fx_kwargs["is_inference"]
|
||||
is_backward = compiled_graph.fx_kwargs["is_backward"]
|
||||
|
||||
if not cudagraph_fail_reasons:
|
||||
fx_kwargs = compiled_graph.fx_kwargs
|
||||
static_input_idxs = fx_kwargs["static_input_idxs"]
|
||||
|
||||
placeholders = cached_info.placeholders
|
||||
stack_traces = cached_info.stack_traces
|
||||
if not config.triton.cudagraph_trees:
|
||||
# Force specialize all inputs so that CUDA graphs will work
|
||||
for t in example_inputs:
|
||||
if isinstance(t, torch.SymInt):
|
||||
int(t) # guard
|
||||
|
||||
if (
|
||||
boxed_forward_device_index is not None
|
||||
and not is_inference
|
||||
and not is_backward
|
||||
):
|
||||
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
|
||||
|
||||
from .compile_fx import cudagraphify
|
||||
|
||||
current_callable = compiled_graph.current_callable
|
||||
assert current_callable is not None
|
||||
compiled_graph.current_callable = cudagraphify(
|
||||
current_callable,
|
||||
static_input_idxs=static_input_idxs or (),
|
||||
device_index=next(iter(compiled_graph.device_idxs)),
|
||||
stack_traces=stack_traces,
|
||||
is_backward=is_backward,
|
||||
is_inference=is_inference,
|
||||
constants=tuple(compiled_graph.get_constants(gm).values()),
|
||||
placeholders=placeholders,
|
||||
mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
|
||||
)
|
||||
|
||||
else:
|
||||
BoxedBool.disable(cudagraphs)
|
||||
|
||||
# See [Backward Generation Handling]
|
||||
# if cudagraph'd the forward and set the device, we need to let the cudagraph manager
|
||||
# know we are we running the backward even if we will not run it in cudagraphs
|
||||
if is_backward and config.triton.cudagraph_trees:
|
||||
assert boxed_forward_device_index is not None
|
||||
assert boxed_forward_device_index.value is not None
|
||||
compiled_graph_callable = compiled_graph.current_callable
|
||||
|
||||
manager = torch._inductor.cudagraph_trees.get_manager(
|
||||
boxed_forward_device_index.value, create_if_none_exists=False
|
||||
)
|
||||
# should already exist from forward
|
||||
assert manager is not None
|
||||
|
||||
def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
|
||||
manager.set_to_running_backward() # type: ignore[union-attr]
|
||||
return compiled_graph_callable(new_inputs)
|
||||
|
||||
compiled_graph.current_callable = compiled_artifact
|
||||
|
||||
if "cuda" in compiled_graph.device_types:
|
||||
# prefer better disable_cudagraphs_reason bc stack trace
|
||||
# TODO: migrate all disable reasons to stack trace, refactor
|
||||
if compiled_graph.disabled_cudagraphs_reason:
|
||||
log_cudagraph_skip_and_bump_counter(
|
||||
compiled_graph.disabled_cudagraphs_reason
|
||||
)
|
||||
else:
|
||||
log_cudagraph_skip_and_bump_counter(
|
||||
f"skipping cudagraphs due to {cudagraph_fail_reasons}"
|
||||
)
|
||||
|
||||
|
||||
def maybe_realign_inputs(
|
||||
ran_cudagraphs: BoxedBool,
|
||||
compiled_graph: CompiledFxGraph,
|
||||
inputs_to_check: Sequence[int],
|
||||
) -> None:
|
||||
"""
|
||||
Realigns input strides from inputs_to_check if
|
||||
we didn't end up running cudagraphs. Mutates
|
||||
`compiled_graph.current_callable` if cudagraphs
|
||||
was run. Otherwise, does nothing.
|
||||
"""
|
||||
if not ran_cudagraphs:
|
||||
assert compiled_graph.current_callable is not None
|
||||
new_callable = align_inputs_from_check_idxs(
|
||||
compiled_graph.current_callable, inputs_to_check
|
||||
)
|
||||
if new_callable is not compiled_graph.current_callable:
|
||||
compiled_graph.current_callable = new_callable
|
||||
|
||||
|
||||
def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
|
||||
"""
|
||||
Ephemerally increases the NCCL timeout when compiling for a distributed job
|
||||
@ -1271,54 +1156,6 @@ class FxGraphCache:
|
||||
)
|
||||
return graph, cache_info
|
||||
|
||||
@staticmethod
|
||||
def post_compile(
|
||||
compiled_graph: CompiledFxGraph,
|
||||
example_inputs: Sequence[InputType],
|
||||
cudagraphs: BoxedBool,
|
||||
gm: Optional[torch.fx.GraphModule] = None,
|
||||
) -> CompiledFxGraph:
|
||||
"""
|
||||
Run a set of post processing steps after loading from the cache. These involve:
|
||||
- Setting the tracing context output strides
|
||||
- Running cudagraphs if enabled
|
||||
- Realigning inputs
|
||||
|
||||
This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
|
||||
The results of this function are *not* saved in the cache itself.
|
||||
"""
|
||||
set_tracing_context_output_strides(example_inputs, compiled_graph)
|
||||
|
||||
if cudagraphs:
|
||||
# It's possible that cudagraphs is enabled, but was disabled
|
||||
# during a previous compilation we're loading from the cache.
|
||||
# If so, we need to disable it on this new process too.
|
||||
if compiled_graph.disabled_cudagraphs_reason:
|
||||
if "cuda" in compiled_graph.device_types:
|
||||
log_cudagraph_skip_and_bump_counter(
|
||||
f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
|
||||
)
|
||||
else:
|
||||
counters["inductor"]["cudagraph_skips"] += 1
|
||||
BoxedBool.disable(cudagraphs)
|
||||
else:
|
||||
cudagraph_post_compile(
|
||||
example_inputs,
|
||||
compiled_graph,
|
||||
cudagraphs,
|
||||
gm,
|
||||
)
|
||||
inputs_to_check = compiled_graph.inputs_to_check
|
||||
# cudagraphs could have been disabled from the earlier conditions
|
||||
# so we still need to realign inputs if that happens
|
||||
maybe_realign_inputs(
|
||||
cudagraphs,
|
||||
compiled_graph,
|
||||
inputs_to_check,
|
||||
)
|
||||
|
||||
return compiled_graph
|
||||
|
||||
@staticmethod
|
||||
def _save_graph(
|
||||
key: str,
|
||||
|
Reference in New Issue
Block a user