Code motion CompiledFxGraph to a dedicated file (#141654)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141654
Approved by: https://github.com/aorenste, https://github.com/jansel
ghstack dependencies: #141491, #141492, #141574
Edward Z. Yang
2024-11-27 06:04:14 -08:00
committed by PyTorch MergeBot
parent a7ca6a9113
commit dbbebee9d7
7 changed files with 199 additions and 140 deletions


@@ -36,22 +36,25 @@ from typing import (
Any,
Callable,
cast,
Counter,
Dict,
Generator,
List,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import TypeAlias
import torch
# WARNING: Do not import has_frozen_params directly; it is monkeypatched in
# python test/inductor/test_codecache.py
# TestFxGraphCache.test_constant_handling_device_cpu
# TODO: Why are we monkeypatching it......
import torch._inductor.output_code as output_code
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import (
@@ -72,7 +75,7 @@ from torch._utils_internal import log_cache_bypass
from .remote_cache import create_cache
from .runtime import autotune_cache
from .runtime.autotune_cache import AutotuneCacheBundler
from .triton_bundler import TritonBundler, TritonKernelArtifacts
from .triton_bundler import TritonBundler
T = TypeVar("T")
@@ -82,6 +85,7 @@ if TYPE_CHECKING:
from collections.abc import KeysView
from .compile_fx import _CompileFxKwargs
from .output_code import CompiledFxGraph
from .remote_cache import JsonDataTy, RemoteCache
from .utils import InputType
@@ -101,11 +105,7 @@ from torch._inductor.cpp_builder import (
normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import (
BoxedDeviceIndex,
CudagraphCachedInfo,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
_module_to_triton_kernel,
_reload_python_module,
@@ -885,10 +885,6 @@ class FxGraphHashDetails:
return custom_pass.uuid()
def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
return getattr(gm, "_has_frozen_params", False)
def compiled_fx_graph_hash(
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
@@ -901,7 +897,7 @@ def compiled_fx_graph_hash(
# To support caching when the graph has frozen params, we ignore the tensor values
# of non-inlined constants since they won't be included in the cache entry. Without
# freezing, we want to include the values of any constant attribute.
include_non_inlined = not has_frozen_params(gm)
include_non_inlined = not output_code.has_frozen_params(gm)
details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
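The hunk above switches the call site from a bare has_frozen_params(gm) to output_code.has_frozen_params(gm) so that the test-suite monkeypatching mentioned in the import comment still takes effect. A minimal, self-contained sketch of why the lookup must go through the module (it uses a stand-in module rather than torch._inductor.output_code):

import types
import unittest.mock

# Stand-in for torch._inductor.output_code; the helper mirrors the real default behavior.
mod = types.ModuleType("fake_output_code")
mod.has_frozen_params = lambda gm: False

# What "from ... import has_frozen_params" would capture: a direct reference.
direct_ref = mod.has_frozen_params

with unittest.mock.patch.object(mod, "has_frozen_params", lambda gm: True):
    assert mod.has_frozen_params(None) is True   # attribute lookup sees the patch
    assert direct_ref(None) is False             # the captured reference does not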
@@ -1399,7 +1395,9 @@ class FxGraphCache:
raise BypassFxGraphCache("Unsupported post grad custom pass")
# Freezing can embed constants that wouldn't be static across runs.
if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
if output_code.has_frozen_params(
gm
) and not torch._utils_internal.justknobs_check(
"pytorch/inductor:allow_freezing_with_caching"
):
raise BypassFxGraphCache("Skipping graph with frozen constants")
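For context, the guard above only bypasses the cache when the graph has frozen parameters and the rollout knob is off. A hedged sketch of that control flow with stand-in names (BypassCache and check_can_cache are made up; the real code raises BypassFxGraphCache after consulting justknobs_check):

class BypassCache(Exception):
    """Stand-in for BypassFxGraphCache."""

def check_can_cache(has_frozen: bool, knob_enabled: bool) -> None:
    # Mirrors the condition above: frozen params are only cacheable when the knob allows it.
    if has_frozen and not knob_enabled:
        raise BypassCache("Skipping graph with frozen constants")

try:
    check_can_cache(has_frozen=True, knob_enabled=False)
except BypassCache as exc:
    print(f"cache bypassed: {exc}")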
@@ -1683,121 +1681,6 @@ class FxGraphCache:
pass
_StrideExprStr: TypeAlias = str
@dataclasses.dataclass
class CompiledFxGraph:
"""
Class holding a compiled FX graph. This is the object serialized on disk
to support FxGraph caching.
"""
current_callable: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[Tuple[int, str]]]
device_types: Set[str]
device_idxs: Set[int]
mutated_inputs: Set[str]
mutated_input_idxs: Set[int]
# We populate exactly one of the next two fields. In the common case, we store the
# constant attributes in the cache entry and re-attach them to the module created in
# PyCodeCache.load_by_key_path. In the case that the graph has frozen parameters,
# however, we save the mapping from attribute names in the GraphLowering to the
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[Tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]
# This is a string representation of an expression we serialize
# with the object so the guards can be evaluated in a different
# context in order to verify the validity of serving a cached
# fx graph. The expression must be generated by:
# ShapeEnv.produce_guards_expression()
guards_expr: Optional[str]
cudagraph_info: Optional[CudagraphCachedInfo]
fx_kwargs: _CompileFxKwargs
inputs_to_check: Sequence[int]
boxed_forward_device_index: Optional[BoxedDeviceIndex]
_time_taken_ns: Optional[int] = None
_boxed_call: Optional[bool] = None
_fx_graph_cache_key: Optional[str] = None
_triton_bundle: Optional[List[TritonKernelArtifacts]] = None
def __init__(
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[Tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
counter_deltas: Counter[str],
) -> None:
self.current_callable = current_callable
self.cache_key = graph.cache_key
if graph.cache_path:
with open(graph.cache_path) as f:
self.source_code = f.read()
self.cache_linemap = graph.cache_linemap
# TODO - ordered set
self.device_types = set(graph.device_types)
self.device_idxs = set(graph.device_idxs)
self.mutated_inputs = set(graph.mutated_inputs)
self.mutated_input_idxs = set(graph.mutated_input_idxs)
if has_frozen_params(gm):
self.allocated_constant_name = graph.allocated_constant_name
self.constants = None
else:
self.allocated_constant_name = None
self.constants = graph.constants
self.torchbind_constants = graph.torchbind_constants
self.output_strides = output_strides
self.disabled_cudagraphs_reason = disabled_cudagraphs_reason
self.metrics_deltas = metrics_deltas
self.counter_deltas = counter_deltas
self.guards_expr = None
self.cudagraph_info = None
self.fx_kwargs = {}
self.inputs_to_check = ()
self.boxed_forward_device_index = None
def __call__(self, inputs: Sequence[Any]) -> Any:
assert self.current_callable is not None
try:
return self.current_callable(inputs)
finally:
AutotuneCacheBundler.end_compile()
def get_constants(
self, gm: Optional[torch.fx.GraphModule]
) -> Dict[str, torch.Tensor]:
"""
Get the constant attributes.
"""
# Normal case: The constants are stored in the entry.
if self.constants is not None:
return self.constants
# Freezing case: Look up the constants from attributes on the GraphModule using
# the allocated_constant_name map.
assert gm is not None
assert self.allocated_constant_name is not None
constants = {
name: getattr(gm, orig_name)
for name, orig_name in self.allocated_constant_name.items()
}
return constants
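get_constants above is what makes caching work with freezing: either the tensor values were serialized into the cache entry, or only a name mapping was and the values are pulled off the live GraphModule. A hypothetical sketch of the two modes with made-up names (M, constants, and the allocated_constant_name entries are illustrative only, not the PR's data):

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(2))

gm = M()  # stands in for the GraphModule handed to get_constants

# Normal case: tensor values ride along inside the cache entry itself.
constants = {"_tensor_constant0": torch.ones(2)}

# Freezing case: only the name mapping is cached; values come from the live module.
allocated_constant_name = {"_frozen_param0": "w"}
rehydrated = {name: getattr(gm, orig) for name, orig in allocated_constant_name.items()}
assert torch.equal(rehydrated["_frozen_param0"], gm.w)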
def run_command_and_check(cmd_: str) -> None:
with dynamo_timed("run_command_and_check", log_pt2_compile_event=True):
cmd = shlex.split(cmd_)
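The body of run_command_and_check is truncated above. As a general pattern (an assumption about the elided code, not the PR's exact implementation), splitting the command with shlex and failing loudly on a nonzero exit looks like this (run_and_check is a hypothetical name):

import shlex
import subprocess
import sys

def run_and_check(cmd_: str) -> None:
    cmd = shlex.split(cmd_)     # tokenize the command string like a POSIX shell
    subprocess.check_call(cmd)  # raises CalledProcessError on a nonzero exit code

run_and_check(f"{shlex.quote(sys.executable)} --version")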