Introduce CompiledAOTI (#141695)

Stacked on https://github.com/pytorch/pytorch/pull/141691

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141695
Approved by: https://github.com/aorenste
ghstack dependencies: #141681, #141683, #141685, #141688, #141689, #141691
This commit is contained in:
Edward Z. Yang
2024-11-28 06:18:38 -08:00
committed by PyTorch MergeBot
parent 2f72635a5c
commit 7fafaa9c82
4 changed files with 85 additions and 33 deletions

View File

@ -84,8 +84,8 @@ T = TypeVar("T")
if TYPE_CHECKING:
from collections.abc import KeysView
from .compile_fx import _CompileFxKwargs
from .output_code import CompiledFxGraph
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
from .output_code import OutputCode
from .remote_cache import JsonDataTy, RemoteCache
from .utils import InputType
@ -1322,7 +1322,7 @@ class FxGraphCache:
@staticmethod
def _save_graph(
key: str,
compiled_graph: CompiledFxGraph,
compiled_graph: OutputCode,
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
@ -1330,6 +1330,11 @@ class FxGraphCache:
"""
Store a serialized CompiledFxGraph on disk.
"""
from .compile_fx import CompiledFxGraph
assert isinstance(
compiled_graph, CompiledFxGraph
), f"serialization for {type(compiled_graph)} NYI"
disk_compiled_graph = copy(compiled_graph)
# We can't really serialize callables that may be C++/Triton/etc.,
# so we serialize their PyCodeCache disk cache location instead.