Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Introduce CompiledAOTI (#141695)
Stacked on https://github.com/pytorch/pytorch/pull/141691

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141695
Approved by: https://github.com/aorenste
ghstack dependencies: #141681, #141683, #141685, #141688, #141689, #141691
Committed by: PyTorch MergeBot
Parent: 2f72635a5c
Commit: 7fafaa9c82
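For orientation before the diffs below: after this change every Inductor compile entry point returns an OutputCode, the new CompiledAOTI dataclass replaces the bare str / List[str] that AOT compilation used to hand back, and compile_fx_aot unwraps the artifact path at the very end. The following sketch is illustrative only; it reuses names from the diff (OutputCode, CompiledAOTI, fx_codegen_and_compile, compile_fx_aot), but the trimmed signatures, the stub codegen branch, and the example path are assumptions, not the real implementation.

# Hedged sketch of the post-change control flow; not PyTorch source.
from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Sequence, Union


class OutputCode(Protocol):
    # Mirrors the protocol members added in torch/_inductor/output_code.py.
    _fx_graph_cache_key: Optional[str]
    _time_taken_ns: Optional[int]

    def set_triton_bundle(self, triton_bundle: Any) -> None: ...


@dataclass
class CompiledAOTI:
    # Path(s) of the AOTInductor-compiled artifact.
    filename: Union[str, List[str]]
    _fx_graph_cache_key: Optional[str] = None
    _time_taken_ns: Optional[int] = None

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass


def fx_codegen_and_compile_sketch(aot_mode: bool) -> OutputCode:
    # Stand-in for fx_codegen_and_compile: in AOT mode the codegen result
    # (a path or list of paths) is now wrapped instead of returned raw.
    if not aot_mode:
        raise NotImplementedError("JIT path elided in this sketch")
    compiled_fn = "/tmp/model.so"  # hypothetical artifact path
    assert isinstance(compiled_fn, (str, list))
    return CompiledAOTI(compiled_fn)


def compile_fx_aot_sketch() -> Union[str, List[str]]:
    # Stand-in for the tail of compile_fx_aot: assert the new type, unwrap.
    compiled_artifacts = fx_codegen_and_compile_sketch(aot_mode=True)
    assert isinstance(compiled_artifacts, CompiledAOTI)
    return compiled_artifacts.filename


print(compile_fx_aot_sketch())  # prints the hypothetical /tmp/model.so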
@@ -39,6 +39,7 @@ from torch._dynamo.debug_utils import (
 )
 from torch._dynamo.trace_rules import is_fbcode
 from torch._dynamo.utils import clone_inputs, counters, same
+from torch._inductor.output_code import OutputCode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     fx_placeholder_targets,
@@ -51,7 +52,6 @@ from .. import config

 if TYPE_CHECKING:
     from torch._inductor.compile_fx import _CompileFxCallable, _CompileFxKwargs
-    from torch._inductor.output_code import CompiledFxGraph
     from torch._inductor.utils import InputType


@@ -83,7 +83,7 @@ def wrap_compiler_debug(
         gm: torch.fx.GraphModule,
         example_inputs: Sequence["InputType"],
         **kwargs: Unpack["_CompileFxKwargs"],
-    ) -> Union["CompiledFxGraph", str]:
+    ) -> OutputCode:
         from torch._subclasses import FakeTensorMode

         compiled_fn = functools.partial(unconfigured_compiler_fn, **kwargs)

@@ -84,8 +84,8 @@ T = TypeVar("T")
 if TYPE_CHECKING:
     from collections.abc import KeysView

-    from .compile_fx import _CompileFxKwargs
-    from .output_code import CompiledFxGraph
+    from .compile_fx import _CompileFxKwargs, CompiledFxGraph
+    from .output_code import OutputCode
     from .remote_cache import JsonDataTy, RemoteCache
     from .utils import InputType

@@ -1322,7 +1322,7 @@ class FxGraphCache:
     @staticmethod
     def _save_graph(
         key: str,
-        compiled_graph: CompiledFxGraph,
+        compiled_graph: OutputCode,
         example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
@@ -1330,6 +1330,11 @@
         """
         Store a serialized CompiledFxGraph on disk.
         """
+        from .compile_fx import CompiledFxGraph
+
+        assert isinstance(
+            compiled_graph, CompiledFxGraph
+        ), f"serialization for {type(compiled_graph)} NYI"
         disk_compiled_graph = copy(compiled_graph)
         # We can't really serialize callables that may be C++/Triton/etc.,
         # so we serialize their PyCodeCache disk cache location instead.

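A note on the _save_graph hunk just above: the parameter type widens to OutputCode, but the body immediately narrows back to CompiledFxGraph with a runtime assert, because serialization is only implemented for that concrete type today. A minimal sketch of the same accept-wide, assert-narrow pattern, using stand-in classes rather than the real CompiledFxGraph and CompiledAOTI:

# Hedged sketch of the narrowing used in FxGraphCache._save_graph; the two
# concrete classes below are stand-ins, not the real Inductor types.
from typing import Any, Protocol


class OutputCode(Protocol):
    def set_triton_bundle(self, triton_bundle: Any) -> None: ...


class SerializableGraph:  # stand-in for CompiledFxGraph
    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass


class AOTArtifact:  # stand-in for CompiledAOTI
    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass


def save_graph(compiled_graph: OutputCode) -> None:
    # Accept anything satisfying the protocol, then narrow: only the
    # FX-graph flavor knows how to be written to the on-disk cache.
    assert isinstance(
        compiled_graph, SerializableGraph
    ), f"serialization for {type(compiled_graph)} NYI"
    # ... copy/serialize logic would follow here ...


save_graph(SerializableGraph())  # passes the guard
# save_graph(AOTArtifact())      # would fail: serialization NYI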
@@ -56,9 +56,11 @@ from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (
+    CompiledAOTI,
     CompiledFxGraph,
     get_expanded_dims,
     index_expanded_dims,
+    OutputCode,
 )
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
@@ -509,7 +511,7 @@ class _CompileFxCallable(Protocol):
         gm: GraphModule,
         example_inputs: Sequence[InputType],
         **kwargs: Unpack[_CompileFxKwargs],
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         ...


@@ -517,7 +519,7 @@ def compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     kwargs.setdefault("cudagraphs", None)
     kwargs.setdefault("static_input_idxs", ())
     kwargs.setdefault("is_backward", False)
@@ -570,7 +572,7 @@ def _compile_fx_inner(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     """
     Inductor API that compiles a single graph.

@@ -630,11 +632,7 @@ def _compile_fx_inner(
         ):
             input._is_inductor_static = True  # type: ignore[attr-defined]

-    # TODO: Remove this short circuit once types are unified here
-    if aot_mode:
-        return fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)  # type: ignore[assignment]
-
-    mb_compiled_graph: Optional[CompiledFxGraph] = None
+    mb_compiled_graph: Optional[OutputCode] = None
     key_info = None
     cache_info = None
     remote_cache = None
@@ -668,11 +666,9 @@ def _compile_fx_inner(
     # determined the input is uncacheable)
     if cache_info is None or cache_info["cache_state"] == "bypass":
         assert mb_compiled_graph is None
-        r = fx_codegen_and_compile(
+        mb_compiled_graph = fx_codegen_and_compile(
             gm, example_inputs, inputs_to_check, **graph_kwargs
         )
-        assert not isinstance(r, str)  # due to aot test
-        mb_compiled_graph = r

     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
@@ -680,19 +676,18 @@ def _compile_fx_inner(
         assert key_info is not None
         TritonBundler.begin_compile()
         try:
-            r = fx_codegen_and_compile(
+            mb_compiled_graph = fx_codegen_and_compile(
                 gm, example_inputs, inputs_to_check, **graph_kwargs
             )
-            assert not isinstance(r, str)  # due to aot test
-            mb_compiled_graph = r
             assert mb_compiled_graph is not None
             mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
             cache_key = key_info[0]
             mb_compiled_graph._fx_graph_cache_key = cache_key
             (
-                mb_compiled_graph._triton_bundle,
+                triton_bundle,
                 triton_bundler_meta,
             ) = TritonBundler.collect()
+            mb_compiled_graph.set_triton_bundle(triton_bundle)
         finally:
             TritonBundler.end_compile()
             if triton_bundler_meta is not None:
@@ -782,7 +777,7 @@ def fx_codegen_and_compile(
     # in explicitly because it's nontrivial to compute
     inputs_to_check: Sequence[int],
     **graph_kwargs: Unpack[_CompileFxKwargs],
-) -> Union[CompiledFxGraph, str]:
+) -> OutputCode:
     # Sorry about the mess, we need graph_kwargs to continue to be able
     # to propagate it further on
     # TODO: _CompileFxKwargs actually has stronger types than in the
@@ -979,6 +974,10 @@ def fx_codegen_and_compile(

         _check_triton_bf16_support(graph)

+        # TODO: The switching between AOT mode and not here is a bit
+        # messy, but it's localized to the block of code below so I'm
+        # not going to touch it for now

+        compiled_fn: Any

         with dynamo_timed(
@@ -1058,8 +1057,10 @@ def fx_codegen_and_compile(
             V.graph.disable_cudagraphs_reason = disable

         if V.aot_compilation is True:
-            return compiled_fn
+            assert isinstance(compiled_fn, (str, list))
+            return CompiledAOTI(compiled_fn)

+        # TODO: Hoist this above V.aot_compilation
         if cudagraphs and not V.graph.disable_cudagraphs_reason:
             from torch._inductor.cudagraph_utils import (
                 check_lowering_disable_cudagraph,
@@ -1069,7 +1070,7 @@
                 check_lowering_disable_cudagraph(V.graph.device_node_mapping)
             )

-        compiled_graph = CompiledFxGraph(
+        return CompiledFxGraph(
             compiled_fn,
             graph,
             gm,
@@ -1085,8 +1086,6 @@
             boxed_forward_device_index,
         )

-        return compiled_graph
-

 def get_input_idxs_to_check(
     inputs: Sequence[InputType],
@@ -1326,11 +1325,9 @@ def compile_fx_aot(
         config_patches=config_patches,
     )

-    assert isinstance(compiled_artifacts, str) or (
-        isinstance(compiled_artifacts, list)
-        and isinstance(compiled_artifacts[0], str)
-    )
-    return compiled_artifacts
+    assert isinstance(compiled_artifacts, CompiledAOTI)
+
+    return compiled_artifacts.filename


 _graph_counter = count(0)
@@ -1487,7 +1484,7 @@ def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]:
 def compile_fx(
     model_: GraphModule,
     example_inputs_: Sequence[InputType],
-    inner_compile: Callable[..., Any] = compile_fx_inner,
+    inner_compile: Callable[..., OutputCode] = compile_fx_inner,
     config_patches: Optional[Dict[str, Any]] = None,
     decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
 ) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
@@ -1631,7 +1628,7 @@ def compile_fx(
         model: GraphModule,
         example_inputs: List[InputType],
         is_inference: bool,
-    ) -> CompiledFxGraph:
+    ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
@@ -1737,7 +1734,7 @@ def compile_fx(
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
         model: GraphModule, example_inputs: List[InputType]
-    ) -> Union[CompiledFxGraph, str]:
+    ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(

@@ -35,6 +35,7 @@ from typing import (
     Set,
     Tuple,
+    TYPE_CHECKING,
     Union,
 )
 from typing_extensions import TypeAlias

@@ -73,6 +74,21 @@ class OutputCode(Protocol):
     ) -> None:
         ...

+    # TODO: Not sure if I really want these to be properties, this is easy
+    # though
+    #
+    # TODO: Remove leading underscores
+
+    # None if the output is not remote cacheable
+    _fx_graph_cache_key: Optional[str]
+
+    # How long it took to compile this OutputCode, end to end
+    _time_taken_ns: Optional[int]
+
+    # TODO: Get rid of this
+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        ...
+

 _StrideExprStr: TypeAlias = str

@@ -300,6 +316,9 @@ class CompiledFxGraph:
         # TODO: Not sure why this isn't just set by default on CompiledFxGraph
         self._boxed_call = True

+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        self._triton_bundle = triton_bundle
+
     def get_constants(
         self, gm: Optional[torch.fx.GraphModule]
     ) -> Dict[str, torch.Tensor]:
@@ -323,3 +342,34 @@ class CompiledFxGraph:

 def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
     return h
+
+
+@dataclasses.dataclass
+class CompiledAOTI:
+    """
+    Class holding an AOTInductor compiled so.
+    """
+
+    filename: Union[str, List[str]]
+
+    # TODO: Figure out if these make sense or not here
+    _fx_graph_cache_key: Optional[str] = None
+    _time_taken_ns: Optional[int] = None
+
+    def __call__(self, inputs: Sequence[Any]) -> Any:
+        raise NotImplementedError("NYI")
+
+    def post_compile(
+        self,
+        example_inputs: Sequence[InputType],
+        cudagraphs: BoxedBool,
+        gm: GraphModule,
+    ) -> None:
+        pass
+
+    def set_triton_bundle(self, triton_bundle: Any) -> None:
+        pass
+
+
+def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
+    return h
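The _typecheck_CompiledAOTI helper at the end of this file mirrors the existing _typecheck_CompiledFxGraph: an identity function whose annotations let the static type checker confirm that the dataclass structurally satisfies the OutputCode protocol. A trimmed-down sketch of that idiom, with a reduced protocol for illustration only:

# Hedged sketch of the protocol-conformance idiom; this OutputCode is a
# reduced stand-in, not the full protocol defined above.
from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Union


class OutputCode(Protocol):
    _fx_graph_cache_key: Optional[str]
    _time_taken_ns: Optional[int]

    def set_triton_bundle(self, triton_bundle: Any) -> None: ...


@dataclass
class CompiledAOTI:
    filename: Union[str, List[str]]
    _fx_graph_cache_key: Optional[str] = None
    _time_taken_ns: Optional[int] = None

    def set_triton_bundle(self, triton_bundle: Any) -> None:
        pass


def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
    # At runtime this is just identity; statically, mypy or pyright will
    # reject the return if CompiledAOTI ever stops satisfying the protocol.
    return h


code: OutputCode = CompiledAOTI("model.so")  # usable through the protocol surface
code.set_triton_bundle(None)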