"""
|
|
This provides an abstract class which parametrizes over an "output code" concept
|
|
for Inductor. Intuitively, this represents the compiled callable which Inductor
|
|
produces which you can call to get optimized code. However, this callable
|
|
has some other capabilities:
|
|
|
|
- It is serializable, so you can save/load this product from disk without
|
|
having to do compilation again.
|
|
|
|
- (When using remote cache) it is addressable, so you can save just a key
|
|
which you can use to load this product from remote cache later.
|
|
|
|
This class is abstract because we have several different implementations of
|
|
serialized format:
|
|
|
|
- Python wrapper (the default)
|
|
|
|
- AOTInductor (this produces ABI stable binaries which work across PyTorch
|
|
versions)
|
|
|
|
"""

from __future__ import annotations

import dataclasses
from typing import (
    Any,
    Callable,
    Counter,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
)

from typing_extensions import TypeAlias

from .runtime.autotune_cache import AutotuneCacheBundler


if TYPE_CHECKING:
    import torch
    from torch._inductor import metrics
    from torch._inductor.cudagraph_utils import BoxedDeviceIndex, CudagraphCachedInfo
    from torch._inductor.graph import GraphLowering

    from .compile_fx import _CompileFxKwargs
    from .triton_bundler import TritonKernelArtifacts


class OutputCode(Protocol):
    def __call__(self, inputs: Sequence[Any]) -> Any:
        ...
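
# Illustrative only, not part of the module: because OutputCode is a structural
# Protocol, any object with a matching __call__ signature satisfies it without
# inheriting from it. A minimal sketch with a hypothetical implementation:
#
#     class _EchoOutputCode:
#         def __call__(self, inputs: Sequence[Any]) -> Any:
#             return list(inputs)
#
#     code: OutputCode = _EchoOutputCode()  # accepted by a static type checker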


_StrideExprStr: TypeAlias = str


def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    return getattr(gm, "_has_frozen_params", False)


@dataclasses.dataclass
class CompiledFxGraph:
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    current_callable: Optional[Callable[..., Any]]
    cache_key: str
    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
    cache_linemap: Optional[List[Tuple[int, str]]]
    device_types: Set[str]
    device_idxs: Set[int]
    mutated_inputs: Set[str]
    mutated_input_idxs: Set[int]
    # We populate exactly one of the next two fields. In the common case, we store the
    # constant attributes in the cache entry and re-attach them to the module created in
    # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
    # however, we save the mapping from attribute names in the GraphLowering to the
    # original name of the attribute in the GraphModule. When we create the module from
    # the cache entry, we then look up the constants from the current GraphModule. This
    # scheme allows us to support caching with freezing. (An illustrative sketch of the
    # two cases follows get_constants below.)
    allocated_constant_name: Optional[Dict[str, str]]
    constants: Optional[Dict[str, torch.Tensor]]
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
    disabled_cudagraphs_reason: Optional[str]
    metrics_deltas: metrics.CachedMetricsDeltas
    counter_deltas: Counter[str]
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str]

    cudagraph_info: Optional[CudagraphCachedInfo]
    fx_kwargs: _CompileFxKwargs
    inputs_to_check: Sequence[int]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]

    _time_taken_ns: Optional[int] = None
    _boxed_call: Optional[bool] = None
    _fx_graph_cache_key: Optional[str] = None
    _triton_bundle: Optional[List[TritonKernelArtifacts]] = None

    def __init__(
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
        gm: torch.fx.GraphModule,
        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
        counter_deltas: Counter[str],
    ) -> None:
        self.current_callable = current_callable
        self.cache_key = graph.cache_key
        if graph.cache_path:
            with open(graph.cache_path) as f:
                self.source_code = f.read()
        self.cache_linemap = graph.cache_linemap
        # TODO - ordered set
        self.device_types = set(graph.device_types)
        self.device_idxs = set(graph.device_idxs)
        self.mutated_inputs = set(graph.mutated_inputs)
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
        if has_frozen_params(gm):
            self.allocated_constant_name = graph.allocated_constant_name
            self.constants = None
        else:
            self.allocated_constant_name = None
            self.constants = graph.constants
        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
        self.metrics_deltas = metrics_deltas
        self.counter_deltas = counter_deltas
        self.guards_expr = None
        self.cudagraph_info = None
        self.fx_kwargs = {}
        self.inputs_to_check = ()
        self.boxed_forward_device_index = None

    def __call__(self, inputs: Sequence[Any]) -> Any:
        assert self.current_callable is not None
        try:
            return self.current_callable(inputs)
        finally:
            AutotuneCacheBundler.end_compile()
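
    # Illustrative only: the compiled graph is invoked with a single sequence of
    # inputs rather than positional arguments. A minimal sketch, assuming `compiled`
    # is a CompiledFxGraph whose current_callable has been populated:
    #
    #     example_inputs = [input_tensor_a, input_tensor_b]  # hypothetical tensors
    #     output = compiled(example_inputs)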

    def get_constants(
        self, gm: Optional[torch.fx.GraphModule]
    ) -> Dict[str, torch.Tensor]:
        """
        Get the constant attributes.
        """
        # Normal case: The constants are stored in the entry.
        if self.constants is not None:
            return self.constants

        # Freezing case: Look up the constants from attributes on the GraphModule using
        # the allocated_constant_name map.
        assert gm is not None
        assert self.allocated_constant_name is not None
        constants = {
            name: getattr(gm, orig_name)
            for name, orig_name in self.allocated_constant_name.items()
        }
        return constants
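

# Illustrative only, not part of the module: the two branches of get_constants mirror
# the two ways a cache entry stores constants (see the field comment above). A
# hypothetical sketch, where `entry` is a CompiledFxGraph and `gm` a GraphModule:
#
#     # Common case: constant tensors were stored directly in the entry.
#     entry.constants = {"buf0": some_tensor}
#     entry.get_constants(None)                    # -> {"buf0": some_tensor}
#
#     # Freezing case: only a name mapping is stored; values come from the live gm.
#     entry.constants = None
#     entry.allocated_constant_name = {"buf0": "frozen_param0"}
#     entry.get_constants(gm)                      # -> {"buf0": gm.frozen_param0}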


def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
    # Static-only check that CompiledFxGraph satisfies the OutputCode protocol;
    # a type checker will flag this return statement if the protocol ever breaks.
    return h