Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877)"
This reverts commit 3ab4a28eaa7dc67d5c46c2016bbfe9932b36de06. Reverted https://github.com/pytorch/pytorch/pull/141877 on behalf of https://github.com/huydhn due to Jobs are failing en masse after this lands, so it looks like a land race ([comment](https://github.com/pytorch/pytorch/pull/141877#issuecomment-2513552752))
@@ -44,7 +44,6 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
@@ -6733,19 +6732,26 @@ class MockFXGraphCache:
        self.cache[key] = gm

    def load(self, gm, inputs):
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, [])
        if key not in self.cache:
            self.cache[key] = gm
        gm, _ = self.load_with_key(key, [], inputs, None, None, None)
        return gm
        key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
        if key in self.cache:
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm
        else:
            self.save(key, gm)
            gm = make_boxed_func(gm)
            gm._fx_graph_cache_key = key
            return gm

    def load_with_key(self, key, debug_lines, inputs, local, remote_cache, is_backward):
        gm = self.cache.get(key)
        if gm is not None:
            gm = make_boxed_func(gm)
            gm = MockFXGraphCacheOutput(gm, key)
        return gm, {}

    def post_compile(self, gm, inputs, cudagraphs):
        return gm


# The following tests fail in strict caching mode (i.e. they bypass or
# cache miss instead of cache hitting). They will be fixed in the PRs above this.
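MockFXGraphCache above stands in for the real FX graph cache during these tests. Below is a minimal, self-contained sketch of the same dict-backed pattern; the class and names are illustrative stand-ins, not the real FxGraphCache API, and the real mock additionally hashes the graph with compiled_fx_graph_hash and wraps results with make_boxed_func / MockFXGraphCacheOutput.

from typing import Any, Dict, Optional, Tuple


class TinyGraphCache:
    """Dict-backed stand-in illustrating the save/load_with_key/load shape."""

    def __init__(self) -> None:
        self.cache: Dict[str, Any] = {}

    def save(self, key: str, gm: Any) -> None:
        self.cache[key] = gm

    def load_with_key(self, key: str) -> Tuple[Optional[Any], Dict[str, Any]]:
        # Mirrors the (graph, cache_info) return shape used above.
        return self.cache.get(key), {}

    def load(self, key: str, gm: Any) -> Any:
        if key not in self.cache:
            self.save(key, gm)  # cache miss: store the graph first
        cached, _info = self.load_with_key(key)
        return cached


cache = TinyGraphCache()
assert cache.load("k0", "graph-object") == "graph-object"  # miss, then hit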
@@ -6817,6 +6823,9 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
        with patch(
            "torch._inductor.codecache.FxGraphCache.load_with_key",
            new=self.inductor_cache.load_with_key,
        ), patch(
            "torch._inductor.codecache.FxGraphCache.post_compile",
            new=self.inductor_cache.post_compile,
        ):
            return super().verify_aot_autograd(
                f,
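The patches above swap the real cache entry points for the mock's bound methods for the duration of the test. A self-contained sketch of that unittest.mock patching pattern, using hypothetical classes in place of the torch internals (the test above uses the string-target form of patch; patch.object is used here to keep the sketch standalone):

from unittest.mock import patch


class RealCache:
    @staticmethod
    def load_with_key(key):
        return f"real:{key}"


class FakeCache:
    def load_with_key(self, key):
        return f"fake:{key}"


fake = FakeCache()
# While the context manager is active, lookups on RealCache hit the fake's method.
with patch.object(RealCache, "load_with_key", new=fake.load_with_key):
    assert RealCache.load_with_key("k") == "fake:k"
assert RealCache.load_with_key("k") == "real:k"  # original restored on exit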
@@ -354,9 +354,8 @@ class FXGraphCacheLoadable:
            payload_fn=lambda: json.dumps(cache_info),
        )

        # TODO: How come cudagraphs could be None here?
        # TODO: How come gm is None here?
        result.post_compile(example_inputs, fx_config["cudagraphs"], None)  # type: ignore[arg-type]
        FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"])  # type: ignore[arg-type]
        result._boxed_call = True
        return result
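The two call lines above are the heart of the revert: the inlined version invoked post_compile as a method on the loaded result, while the restored code routes through the FxGraphCache static method. A generic sketch of that delegation shape, with hypothetical names rather than the inductor classes:

from typing import Any, Sequence


class GraphCache:
    @staticmethod
    def post_compile(result: "LoadedGraph", example_inputs: Sequence[Any], cudagraphs: Any) -> "LoadedGraph":
        # Shared post-processing lives in one place on the cache class.
        result.post_processed = True
        return result


class LoadedGraph:
    post_processed = False

    def post_compile(self, example_inputs: Sequence[Any], cudagraphs: Any, gm: Any = None) -> None:
        # Method form: a thin wrapper forwarding to the shared static method.
        GraphCache.post_compile(self, example_inputs, cudagraphs)


g = LoadedGraph()
g.post_compile(example_inputs=[], cudagraphs=None)   # method-style call site
GraphCache.post_compile(g, [], None)                 # static-style call site
assert g.post_processed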
@@ -105,6 +105,7 @@ from torch._inductor.cpp_builder import (
    normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
    _module_to_triton_kernel,
    _reload_python_module,
@@ -113,9 +114,12 @@ from torch._inductor.runtime.compile_tasks import (
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
    ALIGN_BYTES,
    align_inputs_from_check_idxs,
    BoxedBool,
    clear_on_fresh_inductor_cache,
    is_linux,
    is_windows,
    set_tracing_context_output_strides,
)
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
@@ -909,6 +913,117 @@ def compiled_fx_graph_hash(
    return key, debug_lines


def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs cudagraphs on compiled_graph.
    Mutates `compiled_graph.current_callable` and `cudagraphs`.
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    inputs_to_check = compiled_graph.inputs_to_check
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]

        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces
        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )

    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if we cudagraph'd the forward and set the device, we need to let the cudagraph manager
        # know we are running the backward even if we will not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )


def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
    inputs_to_check: Sequence[int],
) -> None:
    """
    Realigns input strides from inputs_to_check if
    we didn't end up running cudagraphs. Mutates
    `compiled_graph.current_callable` if cudagraphs
    was not run. Otherwise, does nothing.
    """
    if not ran_cudagraphs:
        assert compiled_graph.current_callable is not None
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable


def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
    """
    Ephemerally increases the NCCL timeout when compiling for a distributed job
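maybe_realign_inputs in the hunk above only re-wraps the callable when cudagraphs did not run, because the cudagraph path manages its own static input buffers. Below is a simplified, self-contained sketch of the wrap-and-copy idea behind align_inputs_from_check_idxs; the alignment constant and predicate here are assumptions for the sketch, not the inductor implementation.

from typing import Any, Callable, List, Sequence

import torch

ALIGN_BYTES = 64  # assumed alignment for this sketch only


def realign_wrapper(
    fn: Callable[[List[Any]], Any], inputs_to_check: Sequence[int]
) -> Callable[[List[Any]], Any]:
    """Wrap fn so the checked inputs are cloned into fresh storage before the call."""

    def run(new_inputs: List[Any]) -> Any:
        for i in inputs_to_check:
            t = new_inputs[i]
            # Stand-in predicate: copy anything whose storage fails the alignment check.
            if isinstance(t, torch.Tensor) and t.data_ptr() % ALIGN_BYTES != 0:
                new_inputs[i] = t.clone()
        return fn(new_inputs)

    return run


compiled = realign_wrapper(lambda xs: xs[0].sum(), inputs_to_check=[0])
print(compiled([torch.ones(8)]))  # the boxed calling convention passes one list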
@@ -1156,6 +1271,54 @@ class FxGraphCache:
        )
        return graph, cache_info

    @staticmethod
    def post_compile(
        compiled_graph: CompiledFxGraph,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        gm: Optional[torch.fx.GraphModule] = None,
    ) -> CompiledFxGraph:
        """
        Run a set of post processing steps after loading from the cache. These involve:
        - Setting the tracing context output strides
        - Running cudagraphs if enabled
        - Realigning inputs

        This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
        The results of this function are *not* saved in the cache itself.
        """
        set_tracing_context_output_strides(example_inputs, compiled_graph)

        if cudagraphs:
            # It's possible that cudagraphs is enabled, but was disabled
            # during a previous compilation we're loading from the cache.
            # If so, we need to disable it on this new process too.
            if compiled_graph.disabled_cudagraphs_reason:
                if "cuda" in compiled_graph.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                cudagraph_post_compile(
                    example_inputs,
                    compiled_graph,
                    cudagraphs,
                    gm,
                )
        inputs_to_check = compiled_graph.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
        maybe_realign_inputs(
            cudagraphs,
            compiled_graph,
            inputs_to_check,
        )

        return compiled_graph

    @staticmethod
    def _save_graph(
        key: str,
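Both post_compile and cudagraph_post_compile report "cudagraphs got disabled" back to the caller by mutating the cudagraphs argument in place rather than returning a flag. A minimal sketch of that boxed-boolean idea, as a simplified stand-in for torch._inductor.utils.BoxedBool rather than its actual definition:

class BoxedFlag:
    """Simplified stand-in: a mutable bool that callees can flip in place."""

    def __init__(self, value: bool) -> None:
        self.value = value

    def __bool__(self) -> bool:
        return self.value

    @staticmethod
    def disable(obj: "BoxedFlag") -> bool:
        # Mirrors the BoxedBool.disable(cudagraphs) calls above: mutate, don't replace.
        if isinstance(obj, BoxedFlag):
            obj.value = False
            return False
        return obj


def post_compile_step(cudagraphs: BoxedFlag, disabled_reason: str) -> None:
    if disabled_reason:
        BoxedFlag.disable(cudagraphs)  # caller observes the change through its own reference


flag = BoxedFlag(True)
post_compile_step(flag, disabled_reason="example skip reason")
assert not flag  # the caller's flag was disabled in place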
@@ -46,16 +46,10 @@ from torch._inductor.cudagraph_utils import (
    get_placeholder_info,
    log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
    align_inputs_from_check_idxs,
    BoxedBool,
    InputType,
    output_node,
    set_tracing_context_output_strides,
)

from . import config
from .runtime.autotune_cache import AutotuneCacheBundler
from .utils import BoxedBool, InputType, output_node


if TYPE_CHECKING:
@@ -138,117 +132,6 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
    return False


def cudagraph_post_compile(
    example_inputs: Sequence[InputType],
    compiled_graph: CompiledFxGraph,
    cudagraphs: BoxedBool,
    gm: Optional[torch.fx.GraphModule],
) -> None:
    """
    Checks for any reasons not to run cudagraphs and then
    runs cudagraphs on compiled_graph.
    Mutates `compiled_graph.current_callable` and `cudagraphs`.
    """
    assert compiled_graph.current_callable is not None
    assert compiled_graph.cudagraph_info is not None
    cached_info = compiled_graph.cudagraph_info
    cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
    inputs_to_check = compiled_graph.inputs_to_check
    boxed_forward_device_index = compiled_graph.boxed_forward_device_index
    is_inference = compiled_graph.fx_kwargs["is_inference"]
    is_backward = compiled_graph.fx_kwargs["is_backward"]

    if not cudagraph_fail_reasons:
        fx_kwargs = compiled_graph.fx_kwargs
        static_input_idxs = fx_kwargs["static_input_idxs"]

        placeholders = cached_info.placeholders
        stack_traces = cached_info.stack_traces
        if not config.triton.cudagraph_trees:
            # Force specialize all inputs so that CUDA graphs will work
            for t in example_inputs:
                if isinstance(t, torch.SymInt):
                    int(t)  # guard

        if (
            boxed_forward_device_index is not None
            and not is_inference
            and not is_backward
        ):
            boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))

        from .compile_fx import cudagraphify

        current_callable = compiled_graph.current_callable
        assert current_callable is not None
        compiled_graph.current_callable = cudagraphify(
            current_callable,
            static_input_idxs=static_input_idxs or (),
            device_index=next(iter(compiled_graph.device_idxs)),
            stack_traces=stack_traces,
            is_backward=is_backward,
            is_inference=is_inference,
            constants=tuple(compiled_graph.get_constants(gm).values()),
            placeholders=placeholders,
            mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
        )

    else:
        BoxedBool.disable(cudagraphs)

        # See [Backward Generation Handling]
        # if we cudagraph'd the forward and set the device, we need to let the cudagraph manager
        # know we are running the backward even if we will not run it in cudagraphs
        if is_backward and config.triton.cudagraph_trees:
            assert boxed_forward_device_index is not None
            assert boxed_forward_device_index.value is not None
            compiled_graph_callable = compiled_graph.current_callable

            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_forward_device_index.value, create_if_none_exists=False
            )
            # should already exist from forward
            assert manager is not None

            def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
                manager.set_to_running_backward()  # type: ignore[union-attr]
                return compiled_graph_callable(new_inputs)

            compiled_graph.current_callable = compiled_artifact

        if "cuda" in compiled_graph.device_types:
            # prefer better disable_cudagraphs_reason bc stack trace
            # TODO: migrate all disable reasons to stack trace, refactor
            if compiled_graph.disabled_cudagraphs_reason:
                log_cudagraph_skip_and_bump_counter(
                    compiled_graph.disabled_cudagraphs_reason
                )
            else:
                log_cudagraph_skip_and_bump_counter(
                    f"skipping cudagraphs due to {cudagraph_fail_reasons}"
                )


def maybe_realign_inputs(
    ran_cudagraphs: BoxedBool,
    compiled_graph: CompiledFxGraph,
    inputs_to_check: Sequence[int],
) -> None:
    """
    Realigns input strides from inputs_to_check if
    we didn't end up running cudagraphs. Mutates
    `compiled_graph.current_callable` if cudagraphs
    was not run. Otherwise, does nothing.
    """
    if not ran_cudagraphs:
        assert compiled_graph.current_callable is not None
        new_callable = align_inputs_from_check_idxs(
            compiled_graph.current_callable, inputs_to_check
        )
        if new_callable is not compiled_graph.current_callable:
            compiled_graph.current_callable = new_callable


@dataclasses.dataclass
class CompiledFxGraph(OutputCode):
    """
@@ -406,9 +289,6 @@ class CompiledFxGraph(OutputCode):
        # TODO: should this be part of fx_kwargs
        self.boxed_forward_device_index = boxed_forward_device_index

        # aot autograd needs to know to pass in inputs as a list
        self._boxed_call = True

    def __call__(self, inputs: Sequence[Any]) -> Any:
        assert self.current_callable is not None
        try:
@@ -422,44 +302,14 @@ class CompiledFxGraph(OutputCode):
        cudagraphs: BoxedBool,
        gm: GraphModule,
    ) -> None:
        """
        Run a set of post processing steps after loading from the cache. These involve:
        - Setting the tracing context output strides
        - Running cudagraphs if enabled
        - Realigning inputs
        # TODO: maybe move this here? Not sure.
        from torch._inductor.codecache import FxGraphCache

        This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
        The results of this function are *not* saved in the cache itself.
        """
        set_tracing_context_output_strides(example_inputs, self)
        FxGraphCache.post_compile(self, example_inputs, cudagraphs, gm)

        if cudagraphs:
            # It's possible that cudagraphs is enabled, but was disabled
            # during a previous compilation we're loading from the cache.
            # If so, we need to disable it on this new process too.
            if self.disabled_cudagraphs_reason:
                if "cuda" in self.device_types:
                    log_cudagraph_skip_and_bump_counter(
                        f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}"
                    )
                else:
                    counters["inductor"]["cudagraph_skips"] += 1
                BoxedBool.disable(cudagraphs)
            else:
                cudagraph_post_compile(
                    example_inputs,
                    self,
                    cudagraphs,
                    gm,
                )
        inputs_to_check = self.inputs_to_check
        # cudagraphs could have been disabled from the earlier conditions
        # so we still need to realign inputs if that happens
        maybe_realign_inputs(
            cudagraphs,
            self,
            inputs_to_check,
        )
        # aot autograd needs to know to pass in inputs as a list
        # TODO: Not sure why this isn't just set by default on CompiledFxGraph
        self._boxed_call = True

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        self._triton_bundle = triton_bundle
@@ -514,31 +364,3 @@ class CompiledAOTI(OutputCode):

def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
    return h


@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
    gm: Any
    _fx_graph_cache_key: Optional[str]
    # How long it took to compile this OutputCode, end to end
    _time_taken_ns: Optional[int]

    def __init__(self, gm: Any, key: Optional[str]) -> None:
        self.gm = gm
        self._fx_graph_cache_key = key
        self._time_taken_ns = 0
        self._boxed_call = True

    def post_compile(
        self,
        example_inputs: Sequence[InputType],
        cudagraphs: BoxedBool,
        gm: GraphModule,
    ) -> None:
        pass

    def __call__(self, inputs: Sequence[Any]) -> Any:
        return self.gm(inputs)

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass
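MockFXGraphCacheOutput, restored above, is a thin OutputCode wrapper: it records the cache key, marks itself as boxed, and forwards calls to the wrapped graph module. A trimmed, self-contained stand-in showing how such a wrapper is used; the lambda and key below are placeholders, and the class is not the real OutputCode subclass:

from typing import Any, Optional, Sequence


class TinyCacheOutput:
    """Trimmed stand-in mirroring the shape of MockFXGraphCacheOutput above."""

    def __init__(self, gm: Any, key: Optional[str]) -> None:
        self.gm = gm
        self._fx_graph_cache_key = key
        self._time_taken_ns = 0
        self._boxed_call = True  # callers pass all inputs as a single list

    def __call__(self, inputs: Sequence[Any]) -> Any:
        return self.gm(inputs)


out = TinyCacheOutput(lambda xs: sum(xs), key="0123abcd")
assert out([1, 2, 3]) == 6
assert out._fx_graph_cache_key == "0123abcd"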