Revert "[REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877)"

This reverts commit 3ab4a28eaa7dc67d5c46c2016bbfe9932b36de06.

Reverted https://github.com/pytorch/pytorch/pull/141877 on behalf of https://github.com/huydhn due to Jobs are failing en masse after this landed, so it looks like a land race ([comment](https://github.com/pytorch/pytorch/pull/141877#issuecomment-2513552752))
PyTorch MergeBot
2024-12-03 04:57:58 +00:00
parent 38bbe37187
commit 2999dbfd21
4 changed files with 188 additions and 195 deletions


@@ -44,7 +44,6 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
@@ -6733,19 +6732,26 @@ class MockFXGraphCache:
        self.cache[key] = gm

    def load(self, gm, inputs):
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, [])
        if key not in self.cache:
            self.cache[key] = gm
        gm, _ = self.load_with_key(key, [], inputs, None, None, None)
        return gm
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
        if key in self.cache:
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm
        else:
            self.save(key, gm)
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm

    def load_with_key(self, key, debug_lines, inputs, local, remote_cache, is_backward):
        gm = self.cache.get(key)
        if gm is not None:
            gm = make_boxed_func(gm)
            gm = MockFXGraphCacheOutput(gm, key)
        return gm, {}

    def post_compile(self, gm, inputs, cudagraphs):
        return gm


# The following tests fail in strict caching mode (i.e. they bypass or
# cache miss instead of cache hitting). They will be fixed in the PRs above this.
@@ -6817,6 +6823,9 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
        with patch(
            "torch._inductor.codecache.FxGraphCache.load_with_key",
            new=self.inductor_cache.load_with_key,
        ), patch(
            "torch._inductor.codecache.FxGraphCache.post_compile",
            new=self.inductor_cache.post_compile,
        ):
            return super().verify_aot_autograd(
                f,
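
For reference, the patching pattern restored above can also be exercised outside the test harness. A minimal, hypothetical sketch (the function name and wiring are illustrative, not part of this PR) of routing Inductor's cache entry points through a mock object:

# Hypothetical sketch: route FxGraphCache's entry points through a mock cache,
# mirroring the patches in TestAOTAutogradWithCache above. The mock is assumed
# to expose load_with_key/post_compile with compatible signatures.
from unittest.mock import patch

import torch


def run_with_mock_cache(mock_cache, fn, *example_inputs):
    with patch(
        "torch._inductor.codecache.FxGraphCache.load_with_key",
        new=mock_cache.load_with_key,
    ), patch(
        "torch._inductor.codecache.FxGraphCache.post_compile",
        new=mock_cache.post_compile,
    ):
        compiled = torch.compile(fn, backend="inductor")
        return compiled(*example_inputs)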


@@ -354,9 +354,8 @@ class FXGraphCacheLoadable:
                payload_fn=lambda: json.dumps(cache_info),
            )
        # TODO: How come cudagraphs could be None here?
        # TODO: How come gm is None here?
        result.post_compile(example_inputs, fx_config["cudagraphs"], None)  # type: ignore[arg-type]
        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])  # type: ignore[arg-type]
        result._boxed_call = True
        return result
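
In other words, this hunk swaps the call site back from the OutputCode method to the FxGraphCache staticmethod. A hedged sketch of the restored flow (the helper name `finish_cache_load` is illustrative; only the `FxGraphCache.post_compile` call and `_boxed_call` assignment come from the diff):

# Illustrative sketch of the restored call shape; not the actual
# FXGraphCacheLoadable code. `result` is a CompiledFxGraph pulled from the
# AOTAutograd cache, `fx_config` the dict of fx kwargs used at compile time.
from torch._inductor.codecache import FxGraphCache


def finish_cache_load(result, example_inputs, fx_config):
    FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])
    result._boxed_call = True  # AOTAutograd passes inputs as a boxed list
    return result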


@@ -105,6 +105,7 @@ from torch._inductor.cpp_builder import (
    normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
    _module_to_triton_kernel,
    _reload_python_module,
@@ -113,9 +114,12 @@ from torch._inductor.runtime.compile_tasks import (
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
    ALIGN_BYTES,
    align_inputs_from_check_idxs,
    BoxedBool,
    clear_on_fresh_inductor_cache,
    is_linux,
    is_windows,
    set_tracing_context_output_strides,
)
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
@@ -909,6 +913,117 @@ def compiled_fx_graph_hash(
    return key, debug_lines


def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs it on compiled_graph.
    Mutates the `compiled_graph.current_callable` and `cudagraphs`
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    inputs_to_check = compiled_graph.inputs_to_check
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]
        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces
        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )
    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
        # know we are we running the backward even if we will not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )


def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
    inputs_to_check: Sequence[int],
) -> None:
    """
    Realigns input strides from inputs_to_check if
    we didn't end up running cudagraphs. Mutates
    `compiled_graph.current_callable` if cudagraphs
    was run. Otherwise, does nothing.
    """
    if not ran_cudagraphs:
        assert compiled_graph.current_callable is not None
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable


def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
    """
    Ephemerally increases the NCCL timeout when compiling for a distributed job
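
Note that `cudagraph_post_compile` reports "cudagraphs got disabled" to its caller by mutating the `BoxedBool` it received rather than by returning anything, which is why `maybe_realign_inputs` can consult the same object afterwards. A self-contained sketch of that boxed-flag pattern (the `BoxedFlag` class below is a simplified stand-in, not torch's `BoxedBool`):

# Simplified stand-in for the boxed-flag pattern used above; not torch's BoxedBool.
class BoxedFlag:
    def __init__(self, value: bool) -> None:
        self.value = value

    def __bool__(self) -> bool:
        return self.value

    @staticmethod
    def disable(flag: "BoxedFlag") -> bool:
        # Flip the shared flag in place so every holder of the box sees it.
        flag.value = False
        return False


def try_cudagraphs(cudagraphs: BoxedFlag, fail_reasons: list) -> None:
    # Mirrors the control flow of cudagraph_post_compile: on failure, disable
    # the shared flag instead of raising or returning a status.
    if fail_reasons:
        BoxedFlag.disable(cudagraphs)


cudagraphs = BoxedFlag(True)
try_cudagraphs(cudagraphs, ["incompatible op"])
assert not cudagraphs  # the caller (and input realignment) observes the change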
@@ -1156,6 +1271,54 @@ class FxGraphCache:
        )
        return graph, cache_info

    @staticmethod
    def post_compile(
        compiled_graph: CompiledFxGraph,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        gm: Optional[torch.fx.GraphModule] = None,
    ) -> CompiledFxGraph:
        """
        Run a set of post processing steps after loading from the cache. These involve:
        - Setting the tracing context output strides
        - Running cudagraphs if enabled
        - Realigning inputs

        This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
        The results of this function are *not* saved in the cache itself.
        """
        set_tracing_context_output_strides(example_inputs, compiled_graph)

        if cudagraphs:
            # It's possible that cudagraphs is enabled, but was disabled
            # during a previous compilation we're loading from the cache.
            # If so, we need to disable it on this new process too.
            if compiled_graph.disabled_cudagraphs_reason:
                if "cuda" in compiled_graph.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                cudagraph_post_compile(
                    example_inputs,
                    compiled_graph,
                    cudagraphs,
                    gm,
                )
        inputs_to_check = compiled_graph.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
        maybe_realign_inputs(
            cudagraphs,
            compiled_graph,
            inputs_to_check,
        )

        return compiled_graph

    @staticmethod
    def _save_graph(
        key: str,
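
Taken together, the restored staticmethod gives callers a single entry point to run after pulling a graph out of the cache, whether it was a hit or a miss. A hedged usage sketch under that assumption (the helper and its call site are illustrative, not code from this PR):

# Hedged usage sketch of the restored entry point; not code from this PR.
from typing import Sequence

from torch._inductor import config
from torch._inductor.codecache import FxGraphCache
from torch._inductor.output_code import CompiledFxGraph
from torch._inductor.utils import BoxedBool, InputType


def finalize_cached_graph(
    compiled_graph: CompiledFxGraph,
    example_inputs: Sequence[InputType],
) -> CompiledFxGraph:
    # The BoxedBool is shared, mutable state: post_compile may disable it in
    # place if cudagraphs turn out to be unusable for this particular graph.
    cudagraphs = BoxedBool(config.triton.cudagraphs)
    return FxGraphCache.post_compile(compiled_graph, example_inputs, cudagraphs)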


@@ -46,16 +46,10 @@ from torch._inductor.cudagraph_utils import (
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
    align_inputs_from_check_idxs,
    BoxedBool,
    InputType,
    output_node,
    set_tracing_context_output_strides,
)
from . import config
from .runtime.autotune_cache import AutotuneCacheBundler
from .utils import BoxedBool, InputType, output_node


if TYPE_CHECKING:
@@ -138,117 +132,6 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
    return False


def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs it on compiled_graph.
    Mutates the `compiled_graph.current_callable` and `cudagraphs`
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    inputs_to_check = compiled_graph.inputs_to_check
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]
        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces
        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )
    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if cudagraph'd the forward and set the device, we need to let the cudagraph manager
        # know we are we running the backward even if we will not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )


def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
    inputs_to_check: Sequence[int],
) -> None:
    """
    Realigns input strides from inputs_to_check if
    we didn't end up running cudagraphs. Mutates
    `compiled_graph.current_callable` if cudagraphs
    was run. Otherwise, does nothing.
    """
    if not ran_cudagraphs:
        assert compiled_graph.current_callable is not None
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable


@dataclasses.dataclass
class CompiledFxGraph(OutputCode):
    """
@@ -406,9 +289,6 @@ class CompiledFxGraph(OutputCode):
        # TODO: should this be part of fx_kwargs
        self.boxed_forward_device_index = boxed_forward_device_index

        # aot autograd needs to know to pass in inputs as a list
        self._boxed_call = True

    def __call__(self, inputs: Sequence[Any]) -> Any:
        assert self.current_callable is not None
        try:
@@ -422,44 +302,14 @@ class CompiledFxGraph(OutputCode):
        cudagraphs: BoxedBool,
        gm: GraphModule,
    ) -> None:
        """
        Run a set of post processing steps after loading from the cache. These involve:
        - Setting the tracing context output strides
        - Running cudagraphs if enabled
        - Realigning inputs
        # TODO: maybe move this here? Not sure.
        from torch._inductor.codecache import FxGraphCache

        This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
        The results of this function are *not* saved in the cache itself.
        """
        set_tracing_context_output_strides(example_inputs, self)
        FxGraphCache.post_compile(self, example_inputs, cudagraphs, gm)

        if cudagraphs:
            # It's possible that cudagraphs is enabled, but was disabled
            # during a previous compilation we're loading from the cache.
            # If so, we need to disable it on this new process too.
            if self.disabled_cudagraphs_reason:
                if "cuda" in self.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                cudagraph_post_compile(
                    example_inputs,
                    self,
                    cudagraphs,
                    gm,
                )
        inputs_to_check = self.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
        maybe_realign_inputs(
            cudagraphs,
            self,
            inputs_to_check,
        )

        # aot autograd needs to know to pass in inputs as a list
        # TODO: Not sure why this isn't just set by default on CompiledFxGraph
        self._boxed_call = True

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        self._triton_bundle = triton_bundle
@@ -514,31 +364,3 @@ class CompiledAOTI(OutputCode):
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
    return h


@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
    gm: Any
    _fx_graph_cache_key: Optional[str]

    # How long it took to compile this OutputCode, end to end
    _time_taken_ns: Optional[int]

    def __init__(self, gm: Any, key: Optional[str]) -> None:
        self.gm = gm
        self._fx_graph_cache_key = key
        self._time_taken_ns = 0
        self._boxed_call = True

    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        gm: GraphModule,
    ) -> None:
        pass

    def __call__(self, inputs: Sequence[Any]) -> Any:
        return self.gm(inputs)

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass
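
The removed mock also documents the small surface an OutputCode implementation needs to provide: a boxed `__call__`, `post_compile`, and `set_triton_bundle`. A self-contained toy sketch of that shape (illustrative only; the real protocol lives in torch/_inductor/output_code.py and may differ):

# Standalone toy illustrating the OutputCode-style surface the removed mock
# filled in; not the real torch._inductor.output_code.OutputCode.
from typing import Any, Callable, Optional, Sequence


class ToyOutputCode:
    def __init__(self, fn: Callable[[Sequence[Any]], Any], key: Optional[str]) -> None:
        self.fn = fn
        self._fx_graph_cache_key = key
        self._boxed_call = True  # callers pass inputs as one boxed list

    def __call__(self, inputs: Sequence[Any]) -> Any:
        return self.fn(inputs)

    def post_compile(self, example_inputs: Sequence[Any], cudagraphs: Any, gm: Any) -> None:
        # A real implementation would set output strides, run cudagraphs, and
        # realign inputs here; the removed mock deliberately did nothing.
        pass

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass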