# pytorch/torch/_inductor/output_code.py

"""
This provides an abstract class which parametrizes over an "output code" concept
for Inductor. Intuitively, this represents the compiled callable which Inductor
produces which you can call to get optimized code. However, this callable
has some other capabilities:
- It is serializable, so you can save/load this product from disk without
having to do compilation again.
- (When using remote cache) it is addressable, so you can save just a key
which you can use to load this product from remote cache later.
This class is abstract because we have several different implementations of
serialized format:
- Python wrapper (the default)
- AOTInductor (this produces ABI stable binaries which work across PyTorch
versions)
"""
from __future__ import annotations
import dataclasses
from typing import (
    Any,
    Callable,
    Counter,
    Dict,
    List,
    Optional,
    Protocol,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
)
from typing_extensions import TypeAlias
from .runtime.autotune_cache import AutotuneCacheBundler

if TYPE_CHECKING:
    import torch
    from torch._inductor import metrics
    from torch._inductor.cudagraph_utils import BoxedDeviceIndex, CudagraphCachedInfo
    from torch._inductor.graph import GraphLowering

    from .compile_fx import _CompileFxKwargs
    from .triton_bundler import TritonKernelArtifacts


class OutputCode(Protocol):
    def __call__(self, inputs: Sequence[Any]) -> Any:
        ...

_StrideExprStr: TypeAlias = str


def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
    return getattr(gm, "_has_frozen_params", False)


@dataclasses.dataclass
class CompiledFxGraph:
    """
    Class holding a compiled FX graph. This is the object serialized on disk
    to support FxGraph caching.
    """

    current_callable: Optional[Callable[..., Any]]
    cache_key: str
    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
    cache_linemap: Optional[List[Tuple[int, str]]]
    device_types: Set[str]
    device_idxs: Set[int]
    mutated_inputs: Set[str]
    mutated_input_idxs: Set[int]
    # We populate exactly one of the next two fields. In the common case, we store the
    # constant attributes in the cache entry and re-attach them to the module created in
    # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
    # however, we save the mapping from attribute names in the GraphLowering to the
    # original name of the attribute in the GraphModule. When we create the module from
    # the cache entry, we then look up the constants from the current GraphModule. This
    # scheme allows us to support caching with freezing.
    allocated_constant_name: Optional[Dict[str, str]]
    constants: Optional[Dict[str, torch.Tensor]]
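    # Illustrative example of the two modes (attribute names are hypothetical):
    #   - Without freezing: constants={"conv_weight": <Tensor>} and
    #     allocated_constant_name=None; the tensors travel with the cache entry.
    #   - With freezing: constants=None and allocated_constant_name maps the name
    #     Inductor allocated to the original GraphModule attribute, e.g.
    #     {"conv_weight": "_frozen_param0"}; get_constants() below re-reads the
    #     tensor from the live GraphModule instead of the cache entry.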
    torchbind_constants: Dict[str, torch._C.ScriptObject]
    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
    disabled_cudagraphs_reason: Optional[str]
    metrics_deltas: metrics.CachedMetricsDeltas
    counter_deltas: Counter[str]
    # This is a string representation of an expression we serialize
    # with the object so the guards can be evaluated in a different
    # context in order to verify the validity of serving a cached
    # fx graph. The expression must be generated by:
    # ShapeEnv.produce_guards_expression()
    guards_expr: Optional[str]
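    # Rough sketch of the intended round trip (inferred from the comment above; the
    # ShapeEnv helper used for evaluation is an assumption and its name/signature
    # may differ across versions):
    #
    #     guards_expr = shape_env.produce_guards_expression(placeholders)  # at cache-save time
    #     ok = shape_env.evaluate_guards_expression(guards_expr, hints)    # at cache-load time
    #
    # where `placeholders` / `hints` stand for the symbolic inputs and their concrete
    # values in the new context; the cached graph is served only if the check passes.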
    cudagraph_info: Optional[CudagraphCachedInfo]
    fx_kwargs: _CompileFxKwargs
    inputs_to_check: Sequence[int]
    boxed_forward_device_index: Optional[BoxedDeviceIndex]

    _time_taken_ns: Optional[int] = None
    _boxed_call: Optional[bool] = None
    _fx_graph_cache_key: Optional[str] = None
    _triton_bundle: Optional[List[TritonKernelArtifacts]] = None

    def __init__(
        self,
        current_callable: Optional[Callable[..., Any]],
        graph: GraphLowering,
        gm: torch.fx.GraphModule,
        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
        disabled_cudagraphs_reason: Optional[str],
        metrics_deltas: metrics.CachedMetricsDeltas,
        counter_deltas: Counter[str],
    ) -> None:
        self.current_callable = current_callable
        self.cache_key = graph.cache_key
        if graph.cache_path:
            with open(graph.cache_path) as f:
                self.source_code = f.read()
        self.cache_linemap = graph.cache_linemap
        # TODO - ordered set
        self.device_types = set(graph.device_types)
        self.device_idxs = set(graph.device_idxs)
        self.mutated_inputs = set(graph.mutated_inputs)
        self.mutated_input_idxs = set(graph.mutated_input_idxs)
        if has_frozen_params(gm):
            self.allocated_constant_name = graph.allocated_constant_name
            self.constants = None
        else:
            self.allocated_constant_name = None
            self.constants = graph.constants
        self.torchbind_constants = graph.torchbind_constants
        self.output_strides = output_strides
        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
        self.metrics_deltas = metrics_deltas
        self.counter_deltas = counter_deltas
        self.guards_expr = None
        self.cudagraph_info = None
        self.fx_kwargs = {}
        self.inputs_to_check = ()
        self.boxed_forward_device_index = None

    def __call__(self, inputs: Sequence[Any]) -> Any:
        assert self.current_callable is not None
        try:
            return self.current_callable(inputs)
        finally:
            AutotuneCacheBundler.end_compile()

    def get_constants(
        self, gm: Optional[torch.fx.GraphModule]
    ) -> Dict[str, torch.Tensor]:
        """
        Get the constant attributes.
        """
        # Normal case: The constants are stored in the entry.
        if self.constants is not None:
            return self.constants

        # Freezing case: Look up the constants from attributes on the GraphModule
        # using the allocated_constant_name map.
        assert gm is not None
        assert self.allocated_constant_name is not None
        constants = {
            name: getattr(gm, orig_name)
            for name, orig_name in self.allocated_constant_name.items()
        }
        return constants
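

# Static check that CompiledFxGraph satisfies the OutputCode protocol: if it did
# not, type checkers would flag this implicit conversion.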
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
    return h