Introduce CompiledAOTI (#141695)

Stacked on https://github.com/pytorch/pytorch/pull/141691

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141695
Approved by: https://github.com/aorenste
ghstack dependencies: #141681, #141683, #141685, #141688, #141689, #141691
Authored by Edward Z. Yang on 2024-11-28 06:18:38 -08:00
Committed by PyTorch MergeBot
parent 2f72635a5c
commit 7fafaa9c82
4 changed files with 85 additions and 33 deletions
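
At a glance, the interface this stack standardizes on is the OutputCode protocol, which every inductor compile result (including the new CompiledAOTI) has to satisfy. The condensed sketch below uses simplified signatures; the authoritative definitions are the output_code.py hunks at the end of this diff.

from typing import Any, Optional, Protocol, Sequence


class OutputCode(Protocol):
    # None if the output is not remote cacheable
    _fx_graph_cache_key: Optional[str]
    # How long it took to compile this OutputCode, end to end
    _time_taken_ns: Optional[int]

    # Run the compiled artifact on a flat list of inputs.
    def __call__(self, inputs: Sequence[Any]) -> Any:
        ...

    # Post-compile hook (cudagraphs handling etc.); parameter types are
    # simplified here relative to the real protocol.
    def post_compile(
        self, example_inputs: Sequence[Any], cudagraphs: Any, gm: Any
    ) -> None:
        ...

    # Attach the collected Triton kernel bundle; implementations may ignore it.
    def set_triton_bundle(self, triton_bundle: Any) -> None:
        ...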


@@ -39,6 +39,7 @@ from torch._dynamo.debug_utils import (
)
from torch._dynamo.trace_rules import is_fbcode
from torch._dynamo.utils import clone_inputs, counters, same
from torch._inductor.output_code import OutputCode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
fx_placeholder_targets,
@@ -51,7 +52,6 @@ from .. import config
if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs
from torch._inductor.output_code import CompiledFxGraph
from torch._inductor.utils import InputType
@@ -83,7 +83,7 @@ def wrap_compiler_debug(
gm: torch.fx.GraphModule,
example_inputs: Sequence["InputType"],
**kwargs: Unpack["_CompileFxKwargs"],
) -> Union["CompiledFxGraph", str]:
) -> OutputCode:
from torch._subclasses import FakeTensorMode
compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)


@@ -84,8 +84,8 @@ T = TypeVar("T")
if TYPE_CHECKING:
from collections.abc import KeysView
from .compile_fx import _CompileFxKwargs
from .output_code import CompiledFxGraph
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
from .output_code import OutputCode
from .remote_cache import JsonDataTy, RemoteCache
from .utils import InputType
@@ -1322,7 +1322,7 @@ class FxGraphCache:
@staticmethod
def _save_graph(
key: str,
compiled_graph: CompiledFxGraph,
compiled_graph: OutputCode,
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
@@ -1330,6 +1330,11 @@ class FxGraphCache:
"""
Store a serialized CompiledFxGraph on disk.
"""
from .compile_fx import CompiledFxGraph
assert isinstance(
compiled_graph, CompiledFxGraph
), f"serialization for {type(compiled_graph)} NYI"
disk_compiled_graph = copy(compiled_graph)
# We can't really serialize callables that may be C++/Triton/etc.,
# so we serialize their PyCodeCache disk cache location instead.
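
The cache writer now accepts any OutputCode but can only serialize the CompiledFxGraph flavor, hence the isinstance narrowing above; a CompiledAOTI would trip the assert until serialization for it is implemented. A minimal illustration of that narrowing pattern, with a hypothetical save_output_code helper standing in for _save_graph:

import pickle
from typing import Any


def save_output_code(compiled: Any, path: str) -> None:
    # Narrow the wide OutputCode value down to the one implementation we know
    # how to serialize; anything else (e.g. CompiledAOTI) fails loudly for now.
    from torch._inductor.output_code import CompiledFxGraph

    assert isinstance(
        compiled, CompiledFxGraph
    ), f"serialization for {type(compiled)} NYI"
    # The real _save_graph copies the graph and swaps live callables for their
    # PyCodeCache disk locations before writing; this sketch just pickles.
    with open(path, "wb") as f:
        pickle.dump(compiled, f)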


@@ -56,9 +56,11 @@ from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.output_code import (
CompiledAOTI,
CompiledFxGraph,
get_expanded_dims,
index_expanded_dims,
OutputCode,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import (
@@ -509,7 +511,7 @@ class _CompileFxCallable(Protocol):
gm: GraphModule,
example_inputs: Sequence[InputType],
**kwargs: Unpack[_CompileFxKwargs],
) -> Union[CompiledFxGraph, str]:
) -> OutputCode:
...
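
With this change to the protocol, anything supplied as compile_fx's inner_compile has to hand back an OutputCode rather than a bare CompiledFxGraph or path string. A hedged sketch of a conforming wrapper: the delegation to compile_fx_inner matches the API in this diff, while the wrapper and its logging are illustrative.

from typing import Any, Sequence

import torch
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.output_code import OutputCode


def logging_inner_compile(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[Any],
    **kwargs: Any,
) -> OutputCode:
    out = compile_fx_inner(gm, example_inputs, **kwargs)
    # `out` may be a CompiledFxGraph or, on the AOT path, a CompiledAOTI;
    # callers rely only on the shared OutputCode surface.
    print(f"inductor produced {type(out).__name__}")
    return out

Such a wrapper would be passed as compile_fx(model, example_inputs, inner_compile=logging_inner_compile), matching the tightened Callable[..., OutputCode] annotation further down.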
@@ -517,7 +519,7 @@ def compile_fx_inner(
gm: GraphModule,
example_inputs: Sequence[InputType],
**kwargs: Unpack[_CompileFxKwargs],
) -> Union[CompiledFxGraph, str]:
) -> OutputCode:
kwargs.setdefault("cudagraphs", None)
kwargs.setdefault("static_input_idxs", ())
kwargs.setdefault("is_backward", False)
@@ -570,7 +572,7 @@ def _compile_fx_inner(
gm: GraphModule,
example_inputs: Sequence[InputType],
**graph_kwargs: Unpack[_CompileFxKwargs],
) -> Union[CompiledFxGraph, str]:
) -> OutputCode:
"""
Inductor API that compiles a single graph.
@@ -630,11 +632,7 @@ def _compile_fx_inner(
):
input._is_inductor_static = True # type: ignore[attr-defined]
# TODO: Remove this short circuit once types are unified here
if aot_mode:
return fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs) # type: ignore[assignment]
mb_compiled_graph: Optional[CompiledFxGraph] = None
mb_compiled_graph: Optional[OutputCode] = None
key_info = None
cache_info = None
remote_cache = None
@@ -668,11 +666,9 @@ def _compile_fx_inner(
# determined the input is uncacheable)
if cache_info is None or cache_info["cache_state"] == "bypass":
assert mb_compiled_graph is None
r = fx_codegen_and_compile(
mb_compiled_graph = fx_codegen_and_compile(
gm, example_inputs, inputs_to_check, **graph_kwargs
)
assert not isinstance(r, str) # due to aot test
mb_compiled_graph = r
# CACHE MISS: Compile the graph and save to cache
elif cache_info["cache_state"] == "miss":
@@ -680,19 +676,18 @@ def _compile_fx_inner(
assert key_info is not None
TritonBundler.begin_compile()
try:
r = fx_codegen_and_compile(
mb_compiled_graph = fx_codegen_and_compile(
gm, example_inputs, inputs_to_check, **graph_kwargs
)
assert not isinstance(r, str) # due to aot test
mb_compiled_graph = r
assert mb_compiled_graph is not None
mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
cache_key = key_info[0]
mb_compiled_graph._fx_graph_cache_key = cache_key
(
mb_compiled_graph._triton_bundle,
triton_bundle,
triton_bundler_meta,
) = TritonBundler.collect()
mb_compiled_graph.set_triton_bundle(triton_bundle)
finally:
TritonBundler.end_compile()
if triton_bundler_meta is not None:
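
The cache-miss path above now records its bookkeeping purely through the OutputCode surface: timing and cache key as attributes, and the Triton artifacts via the new set_triton_bundle() hook (stored by CompiledFxGraph, a no-op on CompiledAOTI). A trimmed sketch of that shape, assuming the gm/example_inputs/inputs_to_check/cache_key values from the surrounding function and that TritonBundler is importable from torch._inductor.triton_bundler:

import time
from typing import Any, Sequence

from torch._inductor.compile_fx import fx_codegen_and_compile
from torch._inductor.output_code import OutputCode
from torch._inductor.triton_bundler import TritonBundler  # assumed import path


def compile_and_bundle(
    gm: Any,
    example_inputs: Sequence[Any],
    inputs_to_check: Sequence[int],
    cache_key: str,
) -> OutputCode:
    start_time = time.time_ns()
    TritonBundler.begin_compile()
    try:
        compiled = fx_codegen_and_compile(gm, example_inputs, inputs_to_check)
        compiled._time_taken_ns = time.time_ns() - start_time
        compiled._fx_graph_cache_key = cache_key
        triton_bundle, _meta = TritonBundler.collect()
        # CompiledFxGraph stores the bundle; CompiledAOTI.set_triton_bundle is a pass.
        compiled.set_triton_bundle(triton_bundle)
    finally:
        TritonBundler.end_compile()
    return compiled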
@@ -782,7 +777,7 @@ def fx_codegen_and_compile(
# in explicitly because it's nontrivial to compute
inputs_to_check: Sequence[int],
**graph_kwargs: Unpack[_CompileFxKwargs],
) -> Union[CompiledFxGraph, str]:
) -> OutputCode:
# Sorry about the mess, we need graph_kwargs to continue to be able
# to propagate it further on
# TODO: _CompileFxKwargs actually has stronger types than in the
@@ -979,6 +974,10 @@ def fx_codegen_and_compile(
_check_triton_bf16_support(graph)
# TODO: The switching between AOT mode and not here is a bit
# messy, but it's localized to the block of code below so I'm
# not going to touch it for now
compiled_fn: Any
with dynamo_timed(
@@ -1058,8 +1057,10 @@ def fx_codegen_and_compile(
V.graph.disable_cudagraphs_reason = disable
if V.aot_compilation is True:
return compiled_fn
assert isinstance(compiled_fn, (str, list))
return CompiledAOTI(compiled_fn)
# TODO: Hoist this above V.aot_compilation
if cudagraphs and not V.graph.disable_cudagraphs_reason:
from torch._inductor.cudagraph_utils import (
check_lowering_disable_cudagraph,
@@ -1069,7 +1070,7 @@ def fx_codegen_and_compile(
check_lowering_disable_cudagraph(V.graph.device_node_mapping)
)
compiled_graph = CompiledFxGraph(
return CompiledFxGraph(
compiled_fn,
graph,
gm,
@@ -1085,8 +1086,6 @@
boxed_forward_device_index,
)
return compiled_graph
def get_input_idxs_to_check(
inputs: Sequence[InputType],
@@ -1326,11 +1325,9 @@ def compile_fx_aot(
config_patches=config_patches,
)
assert isinstance(compiled_artifacts, str) or (
isinstance(compiled_artifacts, list)
and isinstance(compiled_artifacts[0], str)
)
return compiled_artifacts
assert isinstance(compiled_artifacts, CompiledAOTI)
return compiled_artifacts.filename
_graph_counter = count(0)
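
Taken together with the fx_codegen_and_compile hunk above, the AOT path is now a symmetric wrap/unwrap: the compiler wraps the generated artifact path(s) in CompiledAOTI, and compile_fx_aot is the only place that reaches back in for the raw filename, so its external return contract is unchanged. A condensed sketch of that round trip (simplified; the real functions carry far more state):

from typing import List, Union

from torch._inductor.output_code import CompiledAOTI


def wrap_aot_result(so_path: Union[str, List[str]]) -> CompiledAOTI:
    # fx_codegen_and_compile side: stop leaking a bare str/list on the AOT
    # path and hand back a uniform OutputCode instead.
    assert isinstance(so_path, (str, list))
    return CompiledAOTI(so_path)


def unwrap_aot_result(out: object) -> Union[str, List[str]]:
    # compile_fx_aot side: narrow back to CompiledAOTI and return the
    # artifact path(s) callers have always received.
    assert isinstance(out, CompiledAOTI)
    return out.filename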
@@ -1487,7 +1484,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
def compile_fx(
model_: GraphModule,
example_inputs_: Sequence[InputType],
inner_compile: Callable[..., Any] = compile_fx_inner,
inner_compile: Callable[..., OutputCode] = compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
@@ -1631,7 +1628,7 @@ def compile_fx(
model: GraphModule,
example_inputs: List[InputType],
is_inference: bool,
) -> CompiledFxGraph:
) -> OutputCode:
with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
if is_inference:
# partition_fn won't be called
@@ -1737,7 +1734,7 @@ def compile_fx(
@compile_time_strobelight_meta(phase_name="backward")
def bw_compiler(
model: GraphModule, example_inputs: List[InputType]
) -> Union[CompiledFxGraph, str]:
) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock
with dynamo_utils.dynamo_timed(


@@ -35,6 +35,7 @@ from typing import (
Set,
Tuple,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeAlias
@@ -73,6 +74,21 @@ class OutputCode(Protocol):
) -> None:
...
# TODO: Not sure if I really want these to be properties, this is easy
# though
#
# TODO: Remove leading underscores
# None if the output is not remote cacheable
_fx_graph_cache_key: Optional[str]
# How long it took to compile this OutputCode, end to end
_time_taken_ns: Optional[int]
# TODO: Get rid of this
def set_triton_bundle(self, triton_bundle: Any) -> None:
...
_StrideExprStr: TypeAlias = str
@@ -300,6 +316,9 @@ class CompiledFxGraph:
# TODO: Not sure why this isn't just set by default on CompiledFxGraph
self._boxed_call = True
def set_triton_bundle(self, triton_bundle: Any) -> None:
self._triton_bundle = triton_bundle
def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
@@ -323,3 +342,34 @@ class CompiledFxGraph:
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
return h
@dataclasses.dataclass
class CompiledAOTI:
"""
Class holding an AOTInductor compiled so.
"""
filename: Union[str, List[str]]
# TODO: Figure out if these make sense or not here
_fx_graph_cache_key: Optional[str] = None
_time_taken_ns: Optional[int] = None
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")
def post_compile(
self,
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
gm: GraphModule,
) -> None:
pass
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
return h
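
_typecheck_CompiledAOTI mirrors _typecheck_CompiledFxGraph above: because OutputCode is a structural Protocol, returning a concrete instance from a function annotated as returning OutputCode makes mypy verify conformance at type-check time, at zero runtime cost. A small illustration of the idiom with hypothetical names (Renderable and PngImage are not from this diff):

from typing import Protocol


class Renderable(Protocol):
    def render(self) -> str:
        ...


class PngImage:
    def render(self) -> str:
        return "<png bytes>"


def _typecheck_PngImage(h: PngImage) -> Renderable:
    # If PngImage ever stopped matching Renderable (say render() were
    # renamed), mypy would flag this return statement even though the
    # function is never called at runtime.
    return h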