From 7666c8263a11048d8875a3aff0dedab0cf73821d Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 3 Dec 2024 18:06:09 -0800 Subject: [PATCH] [REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877) I am going to break apart the arguments passed to the constituents to only pass exactly what is needed, so easy access to the insides is helpful here. This also moves two helper functions to output_code.py as well. Also set _boxed_call at constructor. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/141877 Approved by: https://github.com/jamesjwu, https://github.com/jansel Co-authored-by: James Wu --- test/functorch/test_aotdispatch.py | 25 +-- .../_aot_autograd/autograd_cache.py | 5 +- torch/_inductor/codecache.py | 163 --------------- torch/_inductor/output_code.py | 186 +++++++++++++++++- 4 files changed, 191 insertions(+), 188 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 64eb0f630a76..cfe30a6e399e 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -44,6 +44,7 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module from torch._higher_order_ops.out_dtype import out_dtype from torch._inductor.codecache import compiled_fx_graph_hash +from torch._inductor.output_code import MockFXGraphCacheOutput from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv @@ -6768,26 +6769,21 @@ class MockFXGraphCache: self.cache[key] = gm def load(self, gm, inputs): - key, _ = compiled_fx_graph_hash(gm, inputs, {}, {}) - if key in self.cache: - gm = make_boxed_func(gm) - gm._fx_graph_cache_key = key - return gm - else: - self.save(key, gm) - gm = make_boxed_func(gm) - gm._fx_graph_cache_key = key - return gm + key, _ = compiled_fx_graph_hash(gm, inputs, {}, []) + if key not in self.cache: + self.cache[key] = gm + gm, _ = self.load_with_key(key, [], inputs, None, None, None) + return gm def load_with_key(self, key, debug_lines, inputs, local, remote_cache, is_backward): gm = self.cache.get(key) if gm is not None: gm = make_boxed_func(gm) + gm = MockFXGraphCacheOutput(gm) + gm._fx_graph_cache_key = key + gm._time_taken_ns = 0 return gm, {} - def post_compile(self, gm, inputs, cudagraphs): - return gm - # The following tests fail in strict caching mode (i.e. they bypass or # cache miss instead of cache hitting). They will be fixed in the PRs above this. @@ -6859,9 +6855,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): with patch( "torch._inductor.codecache.FxGraphCache.load_with_key", new=self.inductor_cache.load_with_key, - ), patch( - "torch._inductor.codecache.FxGraphCache.post_compile", - new=self.inductor_cache.post_compile, ): return super().verify_aot_autograd( f, diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 94ad6335b58a..f2a350bf898a 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -354,8 +354,9 @@ class FXGraphCacheLoadable: payload_fn=lambda: json.dumps(cache_info), ) - FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"]) # type: ignore[arg-type] - result._boxed_call = True + # TODO: How come cudagraphs could be None here? + # TODO: How come gm is None here? + result.post_compile(example_inputs, fx_config["cudagraphs"], None) # type: ignore[arg-type] return result diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 96d9474e39d0..c014cef0f348 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -100,7 +100,6 @@ from torch._inductor.cpp_builder import ( normalize_path_separator, ) from torch._inductor.cpu_vec_isa import pick_vec_isa -from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter from torch._inductor.runtime.compile_tasks import ( _module_to_triton_kernel, _reload_python_module, @@ -109,12 +108,9 @@ from torch._inductor.runtime.compile_tasks import ( from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir from torch._inductor.utils import ( ALIGN_BYTES, - align_inputs_from_check_idxs, - BoxedBool, clear_on_fresh_inductor_cache, is_linux, is_windows, - set_tracing_context_output_strides, ) from torch._logging import trace_structured from torch._subclasses.fake_tensor import ( @@ -908,117 +904,6 @@ def compiled_fx_graph_hash( return key, debug_lines -def cudagraph_post_compile( - example_inputs: Sequence[InputType], - compiled_graph: CompiledFxGraph, - cudagraphs: BoxedBool, - gm: Optional[torch.fx.GraphModule], -) -> None: - """ - Checks for any reasons not to run cudagraphs and then - runs it on compiled_graph. - Mutates the `compiled_graph.current_callable` and `cudagraphs` - """ - assert compiled_graph.current_callable is not None - assert compiled_graph.cudagraph_info is not None - cached_info = compiled_graph.cudagraph_info - cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons - inputs_to_check = compiled_graph.inputs_to_check - boxed_forward_device_index = compiled_graph.boxed_forward_device_index - is_inference = compiled_graph.fx_kwargs["is_inference"] - is_backward = compiled_graph.fx_kwargs["is_backward"] - - if not cudagraph_fail_reasons: - fx_kwargs = compiled_graph.fx_kwargs - static_input_idxs = fx_kwargs["static_input_idxs"] - - placeholders = cached_info.placeholders - stack_traces = cached_info.stack_traces - if not config.triton.cudagraph_trees: - # Force specialize all inputs so that CUDA graphs will work - for t in example_inputs: - if isinstance(t, torch.SymInt): - int(t) # guard - - if ( - boxed_forward_device_index is not None - and not is_inference - and not is_backward - ): - boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) - - from .compile_fx import cudagraphify - - current_callable = compiled_graph.current_callable - assert current_callable is not None - compiled_graph.current_callable = cudagraphify( - current_callable, - static_input_idxs=static_input_idxs or (), - device_index=next(iter(compiled_graph.device_idxs)), - stack_traces=stack_traces, - is_backward=is_backward, - is_inference=is_inference, - constants=tuple(compiled_graph.get_constants(gm).values()), - placeholders=placeholders, - mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), - ) - - else: - BoxedBool.disable(cudagraphs) - - # See [Backward Generation Handling] - # if cudagraph'd the forward and set the device, we need to let the cudagraph manager - # know we are we running the backward even if we will not run it in cudagraphs - if is_backward and config.triton.cudagraph_trees: - assert boxed_forward_device_index is not None - assert boxed_forward_device_index.value is not None - compiled_graph_callable = compiled_graph.current_callable - - manager = torch._inductor.cudagraph_trees.get_manager( - boxed_forward_device_index.value, create_if_none_exists=False - ) - # should already exist from forward - assert manager is not None - - def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]: - manager.set_to_running_backward() # type: ignore[union-attr] - return compiled_graph_callable(new_inputs) - - compiled_graph.current_callable = compiled_artifact - - if "cuda" in compiled_graph.device_types: - # prefer better disable_cudagraphs_reason bc stack trace - # TODO: migrate all disable reasons to stack trace, refactor - if compiled_graph.disabled_cudagraphs_reason: - log_cudagraph_skip_and_bump_counter( - compiled_graph.disabled_cudagraphs_reason - ) - else: - log_cudagraph_skip_and_bump_counter( - f"skipping cudagraphs due to {cudagraph_fail_reasons}" - ) - - -def maybe_realign_inputs( - ran_cudagraphs: BoxedBool, - compiled_graph: CompiledFxGraph, - inputs_to_check: Sequence[int], -) -> None: - """ - Realigns input strides from inputs_to_check if - we didn't end up running cudagraphs. Mutates - `compiled_graph.current_callable` if cudagraphs - was run. Otherwise, does nothing. - """ - if not ran_cudagraphs: - assert compiled_graph.current_callable is not None - new_callable = align_inputs_from_check_idxs( - compiled_graph.current_callable, inputs_to_check - ) - if new_callable is not compiled_graph.current_callable: - compiled_graph.current_callable = new_callable - - def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int: """ Ephemerally increases the NCCL timeout when compiling for a distributed job @@ -1236,54 +1121,6 @@ class FxGraphCache: ) return graph, cache_info - @staticmethod - def post_compile( - compiled_graph: CompiledFxGraph, - example_inputs: Sequence[InputType], - cudagraphs: BoxedBool, - gm: Optional[torch.fx.GraphModule] = None, - ) -> CompiledFxGraph: - """ - Run a set of post processing steps after loading from the cache. These involve: - - Setting the tracing context output strides - - Running cudagraphs if enabled - - Realigning inputs - - This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. - The results of this function are *not* saved in the cache itself. - """ - set_tracing_context_output_strides(example_inputs, compiled_graph) - - if cudagraphs: - # It's possible that cudagraphs is enabled, but was disabled - # during a previous compilation we're loading from the cache. - # If so, we need to disable it on this new process too. - if compiled_graph.disabled_cudagraphs_reason: - if "cuda" in compiled_graph.device_types: - log_cudagraph_skip_and_bump_counter( - f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}" - ) - else: - counters["inductor"]["cudagraph_skips"] += 1 - BoxedBool.disable(cudagraphs) - else: - cudagraph_post_compile( - example_inputs, - compiled_graph, - cudagraphs, - gm, - ) - inputs_to_check = compiled_graph.inputs_to_check - # cudagraphs could have been disabled from the earlier conditions - # so we still need to realign inputs if that happens - maybe_realign_inputs( - cudagraphs, - compiled_graph, - inputs_to_check, - ) - - return compiled_graph - @staticmethod def _save_graph( key: str, diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 10fa9f940284..69187615b6d9 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -50,10 +50,16 @@ from torch._inductor.cudagraph_utils import ( get_placeholder_info, log_cudagraph_skip_and_bump_counter, ) +from torch._inductor.utils import ( + align_inputs_from_check_idxs, + BoxedBool, + InputType, + output_node, + set_tracing_context_output_strides, +) from . import config from .runtime.autotune_cache import AutotuneCacheBundler -from .utils import BoxedBool, InputType, output_node if TYPE_CHECKING: @@ -138,6 +144,117 @@ def complex_memory_overlap(t: torch.Tensor) -> bool: return False +def cudagraph_post_compile( + example_inputs: Sequence[InputType], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, + gm: Optional[torch.fx.GraphModule], +) -> None: + """ + Checks for any reasons not to run cudagraphs and then + runs it on compiled_graph. + Mutates the `compiled_graph.current_callable` and `cudagraphs` + """ + assert compiled_graph.current_callable is not None + assert compiled_graph.cudagraph_info is not None + cached_info = compiled_graph.cudagraph_info + cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons + inputs_to_check = compiled_graph.inputs_to_check + boxed_forward_device_index = compiled_graph.boxed_forward_device_index + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + + if not cudagraph_fail_reasons: + fx_kwargs = compiled_graph.fx_kwargs + static_input_idxs = fx_kwargs["static_input_idxs"] + + placeholders = cached_info.placeholders + stack_traces = cached_info.stack_traces + if not config.triton.cudagraph_trees: + # Force specialize all inputs so that CUDA graphs will work + for t in example_inputs: + if isinstance(t, torch.SymInt): + int(t) # guard + + if ( + boxed_forward_device_index is not None + and not is_inference + and not is_backward + ): + boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) + + from .compile_fx import cudagraphify + + current_callable = compiled_graph.current_callable + assert current_callable is not None + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs or (), + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(compiled_graph.get_constants(gm).values()), + placeholders=placeholders, + mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), + ) + + else: + BoxedBool.disable(cudagraphs) + + # See [Backward Generation Handling] + # if cudagraph'd the forward and set the device, we need to let the cudagraph manager + # know we are we running the backward even if we will not run it in cudagraphs + if is_backward and config.triton.cudagraph_trees: + assert boxed_forward_device_index is not None + assert boxed_forward_device_index.value is not None + compiled_graph_callable = compiled_graph.current_callable + + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_forward_device_index.value, create_if_none_exists=False + ) + # should already exist from forward + assert manager is not None + + def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]: + manager.set_to_running_backward() # type: ignore[union-attr] + return compiled_graph_callable(new_inputs) + + compiled_graph.current_callable = compiled_artifact + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + log_cudagraph_skip_and_bump_counter( + compiled_graph.disabled_cudagraphs_reason + ) + else: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {cudagraph_fail_reasons}" + ) + + +def maybe_realign_inputs( + ran_cudagraphs: BoxedBool, + compiled_graph: CompiledFxGraph, + inputs_to_check: Sequence[int], +) -> None: + """ + Realigns input strides from inputs_to_check if + we didn't end up running cudagraphs. Mutates + `compiled_graph.current_callable` if cudagraphs + was run. Otherwise, does nothing. + """ + if not ran_cudagraphs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable + + @dataclasses.dataclass class CompiledFxGraph(OutputCode): """ @@ -295,6 +412,9 @@ class CompiledFxGraph(OutputCode): # TODO: should this be part of fx_kwargs self.boxed_forward_device_index = boxed_forward_device_index + # aot autograd needs to know to pass in inputs as a list + self._boxed_call = True + def __call__(self, inputs: Sequence[Any]) -> Any: assert self.current_callable is not None try: @@ -308,14 +428,44 @@ class CompiledFxGraph(OutputCode): cudagraphs: BoxedBool, gm: GraphModule, ) -> None: - # TODO: maybe move this here? Not sure. - from torch._inductor.codecache import FxGraphCache + """ + Run a set of post processing steps after loading from the cache. These involve: + - Setting the tracing context output strides + - Running cudagraphs if enabled + - Realigning inputs - FxGraphCache.post_compile(self, example_inputs, cudagraphs, gm) + This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. + The results of this function are *not* saved in the cache itself. + """ + set_tracing_context_output_strides(example_inputs, self) - # aot autograd needs to know to pass in inputs as a list - # TODO: Not sure why this isn't just set by default on CompiledFxGraph - self._boxed_call = True + if cudagraphs: + # It's possible that cudagraphs is enabled, but was disabled + # during a previous compilation we're loading from the cache. + # If so, we need to disable it on this new process too. + if self.disabled_cudagraphs_reason: + if "cuda" in self.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + cudagraph_post_compile( + example_inputs, + self, + cudagraphs, + gm, + ) + inputs_to_check = self.inputs_to_check + # cudagraphs could have been disabled from the earlier conditions + # so we still need to realign inputs if that happens + maybe_realign_inputs( + cudagraphs, + self, + inputs_to_check, + ) def set_triton_bundle(self, triton_bundle: Any) -> None: self._triton_bundle = triton_bundle @@ -428,3 +578,25 @@ class CompiledAOTI(OutputCode): def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode: return h + + +@dataclasses.dataclass +class MockFXGraphCacheOutput(OutputCode): + gm: Any = None + + def __post_init__(self) -> None: + self._boxed_call = True + + def post_compile( + self, + example_inputs: Sequence[InputType], + cudagraphs: BoxedBool, + gm: GraphModule, + ) -> None: + pass + + def __call__(self, inputs: Sequence[Any]) -> Any: + return self.gm(inputs) + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass