Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Code motion CompiledFxGraph to a dedicated file (#141654)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141654
Approved by: https://github.com/aorenste, https://github.com/jansel
ghstack dependencies: #141491, #141492, #141574
Committed by: PyTorch MergeBot
Parent: a7ca6a9113
Commit: dbbebee9d7
@@ -430,7 +430,7 @@ class TestFxGraphCache(TestCase):
         self.reset()

         with mock.patch(
-            "torch._inductor.codecache.has_frozen_params", return_value=True
+            "torch._inductor.output_code.has_frozen_params", return_value=True
         ):
             # A call to fn1 should miss in the cache since we do not consider
             # the constant values.
@@ -50,8 +50,8 @@ from .. import config


 if TYPE_CHECKING:
-    from torch._inductor.codecache import CompiledFxGraph
     from torch._inductor.compile_fx import _CompileFxCallableEx, _CompileFxKwargsEx
+    from torch._inductor.output_code import CompiledFxGraph
     from torch._inductor.utils import InputType


@@ -23,7 +23,6 @@ from torch._inductor.codecache import (
     _ident,
     add_ephemeral_timeout_increase_for_distributed,
     BypassFxGraphCache,
-    CompiledFxGraph,
     create_cache,
     extract_tensor_metadata_for_cache_key,
     FxGraphCache,
@@ -51,6 +50,7 @@ from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta  # noq

 if TYPE_CHECKING:
     from torch._inductor.compile_fx import _CompileFxKwargs
+    from torch._inductor.output_code import CompiledFxGraph
     from torch._inductor.remote_cache import JsonDataTy, RemoteCache
     from torch._inductor.utils import BoxedBool
     from torch.fx.node import Node
@@ -36,22 +36,25 @@ from typing import (
     Any,
     Callable,
     cast,
     Counter,
     Dict,
     Generator,
     List,
     NoReturn,
     Optional,
     Sequence,
     Set,
     Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
 )
 from typing_extensions import TypeAlias

 import torch

+# WARNING: Do not directly import has_frozen_params, it is monkeypatched in
+# python test/inductor/test_codecache.py
+# TestFxGraphCache.test_constant_handling_device_cpu
+# TODO: Why are we monkeypatching it......
+import torch._inductor.output_code as output_code
 import torch.distributed as dist
 from torch import SymInt, Tensor
 from torch._dynamo.utils import (
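Note on the WARNING comment in the hunk above: this is why codecache now does "import torch._inductor.output_code as output_code" rather than importing the function name. mock.patch replaces the attribute on the module object, so only call sites that look the attribute up at call time observe the patch; a name bound via "from module import has_frozen_params" keeps pointing at the original function. A stdlib-only toy demonstration of that behavior (the toy_output_code module and every name in it are invented for illustration, not part of this PR):

# Toy illustration: why call sites go through the module attribute when a
# function is meant to be monkeypatched in tests.
import sys
import types
from unittest import mock

# Build a stand-in module with a has_frozen_params that mirrors the one in the
# diff; the module name is hypothetical.
toy = types.ModuleType("toy_output_code")
toy.has_frozen_params = lambda gm: getattr(gm, "_has_frozen_params", False)
sys.modules["toy_output_code"] = toy

import toy_output_code as output_code          # module import (what codecache does)
from toy_output_code import has_frozen_params  # direct name import (what the comment warns against)

gm = types.SimpleNamespace()  # stands in for a torch.fx.GraphModule

with mock.patch("toy_output_code.has_frozen_params", return_value=True):
    print(output_code.has_frozen_params(gm))  # True: attribute looked up at call time sees the mock
    print(has_frozen_params(gm))              # False: the local name still binds the original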
@@ -72,7 +75,7 @@ from torch._utils_internal import log_cache_bypass
 from .remote_cache import create_cache
 from .runtime import autotune_cache
 from .runtime.autotune_cache import AutotuneCacheBundler
-from .triton_bundler import TritonBundler, TritonKernelArtifacts
+from .triton_bundler import TritonBundler


 T = TypeVar("T")
@@ -82,6 +85,7 @@ if TYPE_CHECKING:
     from collections.abc import KeysView

     from .compile_fx import _CompileFxKwargs
+    from .output_code import CompiledFxGraph
     from .remote_cache import JsonDataTy, RemoteCache
     from .utils import InputType

@@ -101,11 +105,7 @@ from torch._inductor.cpp_builder import (
     normalize_path_separator,
 )
 from torch._inductor.cpu_vec_isa import pick_vec_isa
-from torch._inductor.cudagraph_utils import (
-    BoxedDeviceIndex,
-    CudagraphCachedInfo,
-    log_cudagraph_skip_and_bump_counter,
-)
+from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
 from torch._inductor.runtime.compile_tasks import (
     _module_to_triton_kernel,
     _reload_python_module,
@@ -885,10 +885,6 @@ class FxGraphHashDetails:
         return custom_pass.uuid()


-def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
-    return getattr(gm, "_has_frozen_params", False)
-
-
 def compiled_fx_graph_hash(
     gm: torch.fx.GraphModule,
     example_inputs: Sequence[InputType],
@@ -901,7 +897,7 @@ def compiled_fx_graph_hash(
     # To support caching when the graph has frozen params, we ignore the tensor values
     # of non-inlined constants since they won't be included in the cache entry. Without
     # freezing, we want to include the values of any constant attribute.
-    include_non_inlined = not has_frozen_params(gm)
+    include_non_inlined = not output_code.has_frozen_params(gm)

     details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
     has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
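For context on the call site above (and the frozen-params bypass check in the next hunk): has_frozen_params, now defined in torch/_inductor/output_code.py, is a plain attribute probe, and the hashing code only needs its boolean to decide whether constant tensor values participate in the cache key. A small illustrative sketch, assuming only what the getattr in the diff shows; setting the flag by hand here stands in for whatever the freezing pass does:

# Illustration only, not part of the diff.
import torch

def double(x):
    return x * 2

gm = torch.fx.symbolic_trace(double)
print(getattr(gm, "_has_frozen_params", False))  # False: tracing alone sets no flag

gm._has_frozen_params = True  # manual stand-in for the freezing path (assumption)
include_non_inlined = not getattr(gm, "_has_frozen_params", False)
print(include_non_inlined)    # False: constant tensor values stay out of the cache key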
@@ -1399,7 +1395,9 @@ class FxGraphCache:
             raise BypassFxGraphCache("Unsupported post grad custom pass")

         # Freezing can embed constants that wouldn't be static across runs.
-        if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
+        if output_code.has_frozen_params(
+            gm
+        ) and not torch._utils_internal.justknobs_check(
             "pytorch/inductor:allow_freezing_with_caching"
         ):
             raise BypassFxGraphCache("Skipping graph with frozen constants")
@@ -1683,121 +1681,6 @@ class FxGraphCache:
         pass


-_StrideExprStr: TypeAlias = str
-
-
-@dataclasses.dataclass
-class CompiledFxGraph:
-    """
-    Class holding a compiled FX graph. This is the object serialized on disk
-    to support FxGraph caching.
-    """
-
-    current_callable: Optional[Callable[..., Any]]
-    cache_key: str
-    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
-    cache_linemap: Optional[List[Tuple[int, str]]]
-    device_types: Set[str]
-    device_idxs: Set[int]
-    mutated_inputs: Set[str]
-    mutated_input_idxs: Set[int]
-    # We populate exactly one of the next two fields. In the common case, we store the
-    # constant attirbutes in the cache entry and re-attach them to the module created in
-    # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
-    # however, we save the mapping from attribute names in the GraphLowering to the
-    # original name of the attribute in the GraphModule. When we create the module from
-    # the cache entry, we then look up the constants from the current GraphModule. This
-    # scheme allows us to support caching with freezing.
-    allocated_constant_name: Optional[Dict[str, str]]
-    constants: Optional[Dict[str, torch.Tensor]]
-    torchbind_constants: Dict[str, torch._C.ScriptObject]
-    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
-    disabled_cudagraphs_reason: Optional[str]
-    metrics_deltas: metrics.CachedMetricsDeltas
-    counter_deltas: Counter[str]
-    # This is a string representation of an expression we serialize
-    # with the object so the guards can be evaluated in a different
-    # context in order to verify the validity of serving a cached
-    # fx graph. The expression must be generated by:
-    # ShapeEnv.produce_guards_expression()
-    guards_expr: Optional[str]
-
-    cudagraph_info: Optional[CudagraphCachedInfo]
-    fx_kwargs: _CompileFxKwargs
-    inputs_to_check: Sequence[int]
-    boxed_forward_device_index: Optional[BoxedDeviceIndex]
-
-    _time_taken_ns: Optional[int] = None
-    _boxed_call: Optional[bool] = None
-    _fx_graph_cache_key: Optional[str] = None
-    _triton_bundle: Optional[List[TritonKernelArtifacts]] = None
-
-    def __init__(
-        self,
-        current_callable: Optional[Callable[..., Any]],
-        graph: GraphLowering,
-        gm: torch.fx.GraphModule,
-        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
-        disabled_cudagraphs_reason: Optional[str],
-        metrics_deltas: metrics.CachedMetricsDeltas,
-        counter_deltas: Counter[str],
-    ) -> None:
-        self.current_callable = current_callable
-        self.cache_key = graph.cache_key
-        if graph.cache_path:
-            with open(graph.cache_path) as f:
-                self.source_code = f.read()
-        self.cache_linemap = graph.cache_linemap
-        # TODO - ordered set
-        self.device_types = set(graph.device_types)
-        self.device_idxs = set(graph.device_idxs)
-        self.mutated_inputs = set(graph.mutated_inputs)
-        self.mutated_input_idxs = set(graph.mutated_input_idxs)
-        if has_frozen_params(gm):
-            self.allocated_constant_name = graph.allocated_constant_name
-            self.constants = None
-        else:
-            self.allocated_constant_name = None
-            self.constants = graph.constants
-        self.torchbind_constants = graph.torchbind_constants
-        self.output_strides = output_strides
-        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
-        self.metrics_deltas = metrics_deltas
-        self.counter_deltas = counter_deltas
-        self.guards_expr = None
-        self.cudagraph_info = None
-        self.fx_kwargs = {}
-        self.inputs_to_check = ()
-        self.boxed_forward_device_index = None
-
-    def __call__(self, inputs: Sequence[Any]) -> Any:
-        assert self.current_callable is not None
-        try:
-            return self.current_callable(inputs)
-        finally:
-            AutotuneCacheBundler.end_compile()
-
-    def get_constants(
-        self, gm: Optional[torch.fx.GraphModule]
-    ) -> Dict[str, torch.Tensor]:
-        """
-        Get the constant attributes.
-        """
-        # Normal case: The constants are stored in the entry.
-        if self.constants is not None:
-            return self.constants
-
-        # Freezing case: Look up the constants from attributes on the GraphModule using
-        # the allocated_constant_name map.
-        assert gm is not None
-        assert self.allocated_constant_name is not None
-        constants = {
-            name: getattr(gm, orig_name)
-            for name, orig_name in self.allocated_constant_name.items()
-        }
-        return constants
-
-
 def run_command_and_check(cmd_: str) -> None:
     with dynamo_timed("run_command_and_check", log_pt2_compile_event=True):
         cmd = shlex.split(cmd_)
@@ -51,12 +51,7 @@ from torch._dynamo.utils import (
 )
 from torch._functorch import config as functorch_config
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
-from torch._inductor.codecache import (
-    _StrideExprStr,
-    code_hash,
-    CompiledFxGraph,
-    FxGraphCache,
-)
+from torch._inductor.codecache import code_hash, FxGraphCache
 from torch._inductor.cudagraph_utils import (
     BoxedDeviceIndex,
     CudagraphCachedInfo,
@@ -65,6 +60,7 @@ from torch._inductor.cudagraph_utils import (
     PlaceholderInfo,
 )
 from torch._inductor.debug import save_args_for_compile_fx_inner
+from torch._inductor.output_code import CompiledFxGraph
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
     BoxedBool,
@@ -110,6 +106,7 @@ from .virtualized import V


 if TYPE_CHECKING:
+    from torch._inductor.output_code import _StrideExprStr
     from torch._ops import OpOverload

     from .ir import ExternKernelNode
torch/_inductor/output_code.py (new file, 179 lines)
@@ -0,0 +1,179 @@
+"""
+This provides an abstract class which parametrizes over an "output code" concept
+for Inductor. Intuitively, this represents the compiled callable which Inductor
+produces which you can call to get optimized code. However, this callable
+has some other capabilities:
+
+- It is serializable, so you can save/load this product from disk without
+  having to do compilation again.
+
+- (When using remote cache) it is addressable, so you can save just a key
+  which you can use to load this product from remote cache later.
+
+This class is abstract because we have several different implementations of
+serialized format:
+
+- Python wrapper (the default)
+
+- AOTInductor (this produces ABI stable binaries which work across PyTorch
+  versions)
+
+"""
+
+from __future__ import annotations
+
+import dataclasses
+from typing import (
+    Any,
+    Callable,
+    Counter,
+    Dict,
+    List,
+    Optional,
+    Protocol,
+    Sequence,
+    Set,
+    Tuple,
+    TYPE_CHECKING,
+)
+from typing_extensions import TypeAlias
+
+from .runtime.autotune_cache import AutotuneCacheBundler
+
+
+if TYPE_CHECKING:
+    import torch
+    from torch._inductor import metrics
+    from torch._inductor.cudagraph_utils import BoxedDeviceIndex, CudagraphCachedInfo
+    from torch._inductor.graph import GraphLowering
+
+    from .compile_fx import _CompileFxKwargs
+    from .triton_bundler import TritonKernelArtifacts
+
+
+class OutputCode(Protocol):
+    def __call__(self, inputs: Sequence[Any]) -> Any:
+        ...
+
+
+_StrideExprStr: TypeAlias = str
+
+
+def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
+    return getattr(gm, "_has_frozen_params", False)
+
+
+@dataclasses.dataclass
+class CompiledFxGraph:
+    """
+    Class holding a compiled FX graph. This is the object serialized on disk
+    to support FxGraph caching.
+    """
+
+    current_callable: Optional[Callable[..., Any]]
+    cache_key: str
+    source_code: str = dataclasses.field(repr=False)  # Do not display source_code
+    cache_linemap: Optional[List[Tuple[int, str]]]
+    device_types: Set[str]
+    device_idxs: Set[int]
+    mutated_inputs: Set[str]
+    mutated_input_idxs: Set[int]
+    # We populate exactly one of the next two fields. In the common case, we store the
+    # constant attirbutes in the cache entry and re-attach them to the module created in
+    # PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
+    # however, we save the mapping from attribute names in the GraphLowering to the
+    # original name of the attribute in the GraphModule. When we create the module from
+    # the cache entry, we then look up the constants from the current GraphModule. This
+    # scheme allows us to support caching with freezing.
+    allocated_constant_name: Optional[Dict[str, str]]
+    constants: Optional[Dict[str, torch.Tensor]]
+    torchbind_constants: Dict[str, torch._C.ScriptObject]
+    output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
+    disabled_cudagraphs_reason: Optional[str]
+    metrics_deltas: metrics.CachedMetricsDeltas
+    counter_deltas: Counter[str]
+    # This is a string representation of an expression we serialize
+    # with the object so the guards can be evaluated in a different
+    # context in order to verify the validity of serving a cached
+    # fx graph. The expression must be generated by:
+    # ShapeEnv.produce_guards_expression()
+    guards_expr: Optional[str]
+
+    cudagraph_info: Optional[CudagraphCachedInfo]
+    fx_kwargs: _CompileFxKwargs
+    inputs_to_check: Sequence[int]
+    boxed_forward_device_index: Optional[BoxedDeviceIndex]
+
+    _time_taken_ns: Optional[int] = None
+    _boxed_call: Optional[bool] = None
+    _fx_graph_cache_key: Optional[str] = None
+    _triton_bundle: Optional[List[TritonKernelArtifacts]] = None
+
+    def __init__(
+        self,
+        current_callable: Optional[Callable[..., Any]],
+        graph: GraphLowering,
+        gm: torch.fx.GraphModule,
+        output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
+        disabled_cudagraphs_reason: Optional[str],
+        metrics_deltas: metrics.CachedMetricsDeltas,
+        counter_deltas: Counter[str],
+    ) -> None:
+        self.current_callable = current_callable
+        self.cache_key = graph.cache_key
+        if graph.cache_path:
+            with open(graph.cache_path) as f:
+                self.source_code = f.read()
+        self.cache_linemap = graph.cache_linemap
+        # TODO - ordered set
+        self.device_types = set(graph.device_types)
+        self.device_idxs = set(graph.device_idxs)
+        self.mutated_inputs = set(graph.mutated_inputs)
+        self.mutated_input_idxs = set(graph.mutated_input_idxs)
+        if has_frozen_params(gm):
+            self.allocated_constant_name = graph.allocated_constant_name
+            self.constants = None
+        else:
+            self.allocated_constant_name = None
+            self.constants = graph.constants
+        self.torchbind_constants = graph.torchbind_constants
+        self.output_strides = output_strides
+        self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
+        self.metrics_deltas = metrics_deltas
+        self.counter_deltas = counter_deltas
+        self.guards_expr = None
+        self.cudagraph_info = None
+        self.fx_kwargs = {}
+        self.inputs_to_check = ()
+        self.boxed_forward_device_index = None
+
+    def __call__(self, inputs: Sequence[Any]) -> Any:
+        assert self.current_callable is not None
+        try:
+            return self.current_callable(inputs)
+        finally:
+            AutotuneCacheBundler.end_compile()
+
+    def get_constants(
+        self, gm: Optional[torch.fx.GraphModule]
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Get the constant attributes.
+        """
+        # Normal case: The constants are stored in the entry.
+        if self.constants is not None:
+            return self.constants
+
+        # Freezing case: Look up the constants from attributes on the GraphModule using
+        # the allocated_constant_name map.
+        assert gm is not None
+        assert self.allocated_constant_name is not None
+        constants = {
+            name: getattr(gm, orig_name)
+            for name, orig_name in self.allocated_constant_name.items()
+        }
+        return constants
+
+
+def _typecheck_CompiledFxGraph(h: CompiledFxGraph) -> OutputCode:
+    return h
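A note on _typecheck_CompiledFxGraph at the end of the new file: it is a compile-time-only assertion that CompiledFxGraph structurally satisfies the OutputCode Protocol; nothing inherits from OutputCode, and the helper is never meant to be called at runtime. A minimal sketch of the same pattern with toy names (none of the names below are PyTorch APIs):

# Toy sketch of a structural Protocol check, mirroring _typecheck_CompiledFxGraph.
from typing import Any, Protocol, Sequence


class ToyOutputCode(Protocol):
    def __call__(self, inputs: Sequence[Any]) -> Any: ...


class ToyCompiledGraph:
    def __call__(self, inputs: Sequence[Any]) -> Any:
        return list(inputs)  # pretend this runs the compiled code


def _typecheck_ToyCompiledGraph(h: ToyCompiledGraph) -> ToyOutputCode:
    # A type checker accepts this return only if ToyCompiledGraph's __call__
    # is compatible with the protocol; a signature mismatch is flagged here.
    return h


print(_typecheck_ToyCompiledGraph(ToyCompiledGraph())([1, 2, 3]))  # [1, 2, 3]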
@@ -1454,8 +1454,8 @@ def run_and_get_triton_code(fn, *args, **kwargs):


 def run_and_get_graph_lowering(fn, *args, **kwargs):
-    from torch._inductor.codecache import CompiledFxGraph
     from torch._inductor.graph import GraphLowering
+    from torch._inductor.output_code import CompiledFxGraph

     real_init = CompiledFxGraph.__init__
     graph_lowerings = []
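The last hunk only retargets an import, but the helper it lives in rests on a small trick worth spelling out: run_and_get_graph_lowering stashes real_init = CompiledFxGraph.__init__ and swaps in a wrapper so it can record each GraphLowering that flows through construction. A self-contained sketch of that wrap-and-restore pattern with a toy class (Widget and record_payloads are illustrative names, not part of the PR):

# Toy sketch: temporarily wrap a class's __init__, record one argument, then
# call the original and always restore it afterwards.
import contextlib


class Widget:
    def __init__(self, payload):
        self.payload = payload


@contextlib.contextmanager
def record_payloads(cls):
    real_init = cls.__init__
    captured = []

    def wrapper(self, payload, *args, **kwargs):
        captured.append(payload)              # record the argument of interest
        real_init(self, payload, *args, **kwargs)

    cls.__init__ = wrapper
    try:
        yield captured
    finally:
        cls.__init__ = real_init              # restore the original __init__


with record_payloads(Widget) as seen:
    Widget("a")
    Widget("b")
print(seen)  # ['a', 'b']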