Code motion CompiledFxGraph to a dedicated file (#141654)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141654
Approved by: https://github.com/aorenste, https://github.com/jansel
ghstack dependencies: #141491, #141492, #141574
commit dbbebee9d7 (parent a7ca6a9113)
Author: Edward Z. Yang
Date: 2024-11-27 06:04:14 -08:00
Committed by: PyTorch MergeBot
7 changed files with 199 additions and 140 deletions
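The hunks below move CompiledFxGraph (along with the _StrideExprStr alias and the has_frozen_params helper) out of torch._inductor.codecache and into a new torch._inductor.output_code module. Downstream files only update their imports, usually under TYPE_CHECKING since the name is needed purely for annotations. A minimal sketch of that migration for a caller follows; the describe helper is hypothetical, only the import paths come from this diff.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # New home of the class after this PR:
    from torch._inductor.output_code import CompiledFxGraph
    # Old import, removed by this PR:
    # from torch._inductor.codecache import CompiledFxGraph


def describe(graph: "CompiledFxGraph") -> str:
    # Hypothetical helper: it only needs the class for its annotation,
    # which is why the TYPE_CHECKING-guarded import is enough.
    return f"compiled graph with cache key {graph.cache_key}"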

@@ -430,7 +430,7 @@ class TestFxGraphCache(TestCase):
self.reset()
with mock.patch(
"torch._inductor.codecache.has_frozen_params", return_value=True
"torch._inductor.output_code.has_frozen_params", return_value=True
):
# A call to fn1 should miss in the cache since we do not consider
# the constant values.

@@ -50,8 +50,8 @@ from .. import config
if TYPE_CHECKING:
from torch._inductor.codecache import CompiledFxGraph
from torch._inductor.compile_fx import _CompileFxCallableEx, _CompileFxKwargsEx
from torch._inductor.output_code import CompiledFxGraph
from torch._inductor.utils import InputType

@@ -23,7 +23,6 @@ from torch._inductor.codecache import (
_ident,
add_ephemeral_timeout_increase_for_distributed,
BypassFxGraphCache,
CompiledFxGraph,
create_cache,
extract_tensor_metadata_for_cache_key,
FxGraphCache,
@@ -51,6 +50,7 @@ from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noq
if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxKwargs
from torch._inductor.output_code import CompiledFxGraph
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
from torch._inductor.utils import BoxedBool
from torch.fx.node import Node

@@ -36,22 +36,25 @@ from typing import (
Any,
Callable,
cast,
Counter,
Dict,
Generator,
List,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import TypeAlias
import torch
# WARNING: Do not directly import has_frozen_params, it is monkeypatched in
# python test/inductor/test_codecache.py
# TestFxGraphCache.test_constant_handling_device_cpu
# TODO: Why are we monkeypatching it......
import torch._inductor.output_code as output_code
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import (
@@ -72,7 +75,7 @@ from torch._utils_internal import log_cache_bypass
from .remote_cache import create_cache
from .runtime import autotune_cache
from .runtime.autotune_cache import AutotuneCacheBundler
from .triton_bundler import TritonBundler, TritonKernelArtifacts
from .triton_bundler import TritonBundler
T = TypeVar("T")
@@ -82,6 +85,7 @@ if TYPE_CHECKING:
from collections.abc import KeysView
from .compile_fx import _CompileFxKwargs
from .output_code import CompiledFxGraph
from .remote_cache import JsonDataTy, RemoteCache
from .utils import InputType
@@ -101,11 +105,7 @@ from torch._inductor.cpp_builder import (
normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
CudagraphCachedInfo,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
_module_to_triton_kernel,
_reload_python_module,
@@ -885,10 +885,6 @@ class FxGraphHashDetails:
return custom_pass.uuid()
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
return getattr(gm, "_has_frozen_params", False)
def compiled_fx_graph_hash(
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
@@ -901,7 +897,7 @@ def compiled_fx_graph_hash(
# To support caching when the graph has frozen params, we ignore the tensor values
# of non-inlined constants since they won't be included in the cache entry. Without
# freezing, we want to include the values of any constant attribute.
include_non_inlined = not has_frozen_params(gm)
include_non_inlined = not output_code.has_frozen_params(gm)
details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
@@ -1399,7 +1395,9 @@ class FxGraphCache:
raise BypassFxGraphCache("Unsupported post grad custom pass")
# Freezing can embed constants that wouldn't be static across runs.
if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
if output_code.has_frozen_params(
gm
) and not torch._utils_internal.justknobs_check(
"pytorch/inductor:allow_freezing_with_caching"
):
raise BypassFxGraphCache("Skipping graph with frozen constants")
@@ -1683,121 +1681,6 @@ class FxGraphCache:
pass
_StrideExprStr: TypeAlias = str
@dataclasses.dataclass
class CompiledFxGraph:
"""
Class holding a compiled FX graph. This is the object serialized on disk
to support FxGraph caching.
"""
current_callable: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[Tuple[int, str]]]
device_types: Set[str]
device_idxs: Set[int]
mutated_inputs: Set[str]
mutated_input_idxs: Set[int]
# We populate exactly one of the next two fields. In the common case, we store the
# constant attributes in the cache entry and re-attach them to the module created in
# PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
# however, we save the mapping from attribute names in the GraphLowering to the
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]
# This is a string representation of an expression we serialize
# with the object so the guards can be evaluated in a different
# context in order to verify the validity of serving a cached
# fx graph. The expression must be generated by:
# ShapeEnv.produce_guards_expression()
guards_expr: Optional[str]
cudagraph_info: Optional[CudagraphCachedInfo]
fx_kwargs: _CompileFxKwargs
inputs_to_check: Sequence[int]
boxed_forward_device_index: Optional[BoxedDeviceIndex]
_time_taken_ns: Optional[int] = None
_boxed_call: Optional[bool] = None
_fx_graph_cache_key: Optional[str] = None
_triton_bundle: Optional[List[TritonKernelArtifacts]] = None
def __init__(
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
counter_deltas: Counter[str],
) -> None:
self.current_callable = current_callable
self.cache_key = graph.cache_key
if graph.cache_path:
with open(graph.cache_path) as f:
self.source_code = f.read()
self.cache_linemap = graph.cache_linemap
# TODO - ordered set
self.device_types = set(graph.device_types)
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
if has_frozen_params(gm):
self.allocated_constant_name = graph.allocated_constant_name
self.constants = None
else:
self.allocated_constant_name = None
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants
self.output_strides = output_strides
self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
self.metrics_deltas = metrics_deltas
self.counter_deltas = counter_deltas
self.guards_expr = None
self.cudagraph_info = None
self.fx_kwargs = {}
self.inputs_to_check = ()
self.boxed_forward_device_index = None
def __call__(self, inputs: Sequence[Any]) -> Any:
assert self.current_callable is not None
try:
return self.current_callable(inputs)
finally:
AutotuneCacheBundler.end_compile()
def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
"""
Get the constant attributes.
"""
# Normal case: The constants are stored in the entry.
if self.constants is not None:
return self.constants
# Freezing case: Look up the constants from attributes on the GraphModule using
# the allocated_constant_name map.
assert gm is not None
assert self.allocated_constant_name is not None
constants = {
name: getattr(gm, orig_name)
for name, orig_name in self.allocated_constant_name.items()
}
return constants
def run_command_and_check(cmd_: str) -> None:
with dynamo_timed("run_command_and_check", log_pt2_compile_event=True):
cmd = shlex.split(cmd_)
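The codecache.py hunks above replace the local has_frozen_params definition with calls to output_code.has_frozen_params(gm), importing the module rather than the function so that the test suite's mock.patch("torch._inductor.output_code.has_frozen_params", ...) is actually observed (see the WARNING comment in the import block). Below is a self-contained sketch of that early-binding vs. late-binding distinction, using a hypothetical mylib module rather than PyTorch.

import sys
from types import ModuleType
from unittest import mock

# Hypothetical stand-in module; only the patching behaviour is the point here.
mylib = ModuleType("mylib")
mylib.has_frozen_params = lambda gm: False  # "real" implementation
sys.modules["mylib"] = mylib

# Early binding (importing the function directly): the name is resolved once,
# so later patches of the module attribute are invisible to this reference.
direct_ref = mylib.has_frozen_params


# Late binding (what codecache.py does after this PR): the attribute is looked
# up through the module on every call, so patches are visible.
def via_module(gm):
    return mylib.has_frozen_params(gm)


with mock.patch("mylib.has_frozen_params", return_value=True):
    print(direct_ref(None))   # False: the patch never reaches the early binding
    print(via_module(None))   # True: the module lookup sees the patched attribute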

@@ -51,12 +51,7 @@ from torch._dynamo.utils import (
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import (
_StrideExprStr,
code_hash,
CompiledFxGraph,
FxGraphCache,
)
from torch._inductor.codecache import code_hash, FxGraphCache
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
CudagraphCachedInfo,
@@ -65,6 +60,7 @@ from torch._inductor.cudagraph_utils import (
PlaceholderInfo,
)
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.output_code import CompiledFxGraph
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import (
BoxedBool,
@@ -110,6 +106,7 @@ from .virtualized import V
if TYPE_CHECKING:
from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
from .ir import ExternKernelNode

@@ -0,0 +1,179 @@
"""
This provides an abstract class which parametrizes over an "output code" concept
for Inductor. Intuitively, this represents the compiled callable which Inductor
produces which you can call to get optimized code. However, this callable
has some other capabilities:
- It is serializable, so you can save/load this product from disk without
having to do compilation again.
- (When using remote cache) it is addressable, so you can save just a key
which you can use to load this product from remote cache later.
This class is abstract because we have several different implementations of
serialized format:
- Python wrapper (the default)
- AOTInductor (this produces ABI stable binaries which work across PyTorch
versions)
"""
from __future__ import annotations
import dataclasses
from typing import (
Any,
Callable,
Counter,
Dict,
List,
Optional,
Protocol,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
)
from typing_extensions import TypeAlias
from .runtime.autotune_cache import AutotuneCacheBundler
if TYPE_CHECKING:
import torch
from torch._inductor import metrics
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, CudagraphCachedInfo
from torch._inductor.graph import GraphLowering
from .compile_fx import _CompileFxKwargs
from .triton_bundler import TritonKernelArtifacts
class OutputCode(Protocol):
def __call__(self, inputs: Sequence[Any]) -> Any:
...
_StrideExprStr: TypeAlias = str
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
return getattr(gm, "_has_frozen_params", False)
@dataclasses.dataclass
class CompiledFxGraph:
"""
Class holding a compiled FX graph. This is the object serialized on disk
to support FxGraph caching.
"""
current_callable: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[Tuple[int, str]]]
device_types: Set[str]
device_idxs: Set[int]
mutated_inputs: Set[str]
mutated_input_idxs: Set[int]
# We populate exactly one of the next two fields. In the common case, we store the
# constant attributes in the cache entry and re-attach them to the module created in
# PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
# however, we save the mapping from attribute names in the GraphLowering to the
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]
# This is a string representation of an expression we serialize
# with the object so the guards can be evaluated in a different
# context in order to verify the validity of serving a cached
# fx graph. The expression must be generated by:
# ShapeEnv.produce_guards_expression()
guards_expr: Optional[str]
cudagraph_info: Optional[CudagraphCachedInfo]
fx_kwargs: _CompileFxKwargs
inputs_to_check: Sequence[int]
boxed_forward_device_index: Optional[BoxedDeviceIndex]
_time_taken_ns: Optional[int] = None
_boxed_call: Optional[bool] = None
_fx_graph_cache_key: Optional[str] = None
_triton_bundle: Optional[List[TritonKernelArtifacts]] = None
def __init__(
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
counter_deltas: Counter[str],
) -> None:
self.current_callable = current_callable
self.cache_key = graph.cache_key
if graph.cache_path:
with open(graph.cache_path) as f:
self.source_code = f.read()
self.cache_linemap = graph.cache_linemap
# TODO - ordered set
self.device_types = set(graph.device_types)
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
if has_frozen_params(gm):
self.allocated_constant_name = graph.allocated_constant_name
self.constants = None
else:
self.allocated_constant_name = None
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants
self.output_strides = output_strides
self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
self.metrics_deltas = metrics_deltas
self.counter_deltas = counter_deltas
self.guards_expr = None
self.cudagraph_info = None
self.fx_kwargs = {}
self.inputs_to_check = ()
self.boxed_forward_device_index = None
def __call__(self, inputs: Sequence[Any]) -> Any:
assert self.current_callable is not None
try:
return self.current_callable(inputs)
finally:
AutotuneCacheBundler.end_compile()
def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
"""
Get the constant attributes.
"""
# Normal case: The constants are stored in the entry.
if self.constants is not None:
return self.constants
# Freezing case: Look up the constants from attributes on the GraphModule using
# the allocated_constant_name map.
assert gm is not None
assert self.allocated_constant_name is not None
constants = {
name: getattr(gm, orig_name)
for name, orig_name in self.allocated_constant_name.items()
}
return constants
def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
return h
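As the comments in CompiledFxGraph above explain, exactly one of constants and allocated_constant_name is populated: without freezing the tensor values travel with the cache entry, while with frozen parameters only a name mapping is cached and get_constants resolves it against the caller's GraphModule. The following toy sketch illustrates that resolution scheme; FakeGraphModule and the string values are made up, only the two-branch shape of get_constants mirrors the class in this diff.

from typing import Dict, Optional


class FakeGraphModule:
    # Stand-in for a torch.fx.GraphModule whose frozen parameters live as attributes.
    _frozen_param0 = "weight-tensor"
    _frozen_param1 = "bias-tensor"


def get_constants(
    constants: Optional[Dict[str, str]],
    allocated_constant_name: Optional[Dict[str, str]],
    gm: Optional[FakeGraphModule],
) -> Dict[str, str]:
    # Normal case: the values were serialized into the cache entry.
    if constants is not None:
        return constants
    # Freezing case: map the GraphLowering-allocated names back to attributes
    # on the live GraphModule.
    assert gm is not None and allocated_constant_name is not None
    return {name: getattr(gm, orig) for name, orig in allocated_constant_name.items()}


mapping = {"constant0": "_frozen_param0", "constant1": "_frozen_param1"}
print(get_constants(None, mapping, FakeGraphModule()))
# -> {'constant0': 'weight-tensor', 'constant1': 'bias-tensor'}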

@@ -1454,8 +1454,8 @@ def run_and_get_triton_code(fn, *args, **kwargs):
def run_and_get_graph_lowering(fn, *args, **kwargs):
from torch._inductor.codecache import CompiledFxGraph
from torch._inductor.graph import GraphLowering
from torch._inductor.output_code import CompiledFxGraph
real_init = CompiledFxGraph.__init__
graph_lowerings = []
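The final hunk only re-points the CompiledFxGraph import; the surrounding run_and_get_graph_lowering helper saves the real __init__ and accumulates GraphLowering objects. Since the snippet is truncated here, the sketch below only illustrates the general save-and-wrap __init__ pattern it begins; Widget and the wrapper body are assumptions, not the actual helper.

class Widget:
    def __init__(self, graph):
        self.graph = graph


real_init = Widget.__init__
captured = []


def recording_init(self, graph):
    captured.append(graph)   # record the argument of interest
    real_init(self, graph)   # then defer to the original constructor


Widget.__init__ = recording_init
Widget("graph-lowering-0")
Widget("graph-lowering-1")
Widget.__init__ = real_init  # restore the original once done observing
print(captured)  # ['graph-lowering-0', 'graph-lowering-1']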