diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index 8e82c4fecd33..f8f24a68ddd1 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.utils import counters
 from torch._inductor import comms
-from torch._inductor.utils import is_fallback_op, run_and_get_code_before_compile
+from torch._inductor.utils import is_fallback_op, run_and_get_code
 from torch.distributed._tensor import init_device_mesh
 from torch.distributed.fsdp import (
     fully_shard,
@@ -743,7 +743,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
             if fwd_fullgraph
             else None
         ):
-            _, triton_codes = run_and_get_code_before_compile(
+            _, triton_codes = run_and_get_code(
                 lambda: self._test_traceable_fsdp(
                     *self._create_nested_fully_shard_factory_fns(
                         fwd_fullgraph=fwd_fullgraph
@@ -751,7 +751,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                     "inductor",
                     fwd_fullgraph=fwd_fullgraph,
                     bwd_resize_count_before_inductor=48 if fwd_fullgraph else None,
-                ),
+                )
             )
         if fwd_fullgraph:
             self.assertEqual(
@@ -829,12 +829,12 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
        self.skipTestForOldSm()
-        _, triton_codes = run_and_get_code_before_compile(
+        _, triton_codes = run_and_get_code(
            lambda: self._test_traceable_fsdp(
                *self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False),
                "inductor",
                fwd_fullgraph=False,
-            ),
+            )
        )
        # TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
        # there are several recompiles, need to figure out why.
@@ -969,7 +969,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
             if fwd_fullgraph
             else None
         ):
-            _, triton_codes = run_and_get_code_before_compile(
+            _, triton_codes = run_and_get_code(
                 lambda: self._test_traceable_fsdp(
                     *self._create_transformer_factory_fns(
                         all_requires_grad=all_requires_grad,
@@ -978,7 +978,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                     "inductor",
                     fwd_fullgraph=fwd_fullgraph,
                     bwd_resize_count_before_inductor=76 if fwd_fullgraph else None,
-                ),
+                )
             )
         if fwd_fullgraph:
             self.assertEqual(
@@ -1063,7 +1063,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                 f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}"  # noqa: G004, G001, B950
             )
             with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
-                _, triton_codes = run_and_get_code_before_compile(
+                _, triton_codes = run_and_get_code(
                     lambda: self._test_traceable_fsdp(
                         *self._create_transformer_factory_fns(
                             all_requires_grad=all_requires_grad,
@@ -1071,7 +1071,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                         ),
                         "inductor",
                         fwd_fullgraph=fwd_fullgraph,
-                    ),
+                    )
                 )
                 # TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
                 # there are several recompiles, need to figure out why.
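For context on the helper these tests switch to: `run_and_get_code` (its definition appears later in this diff, in torch/_inductor/utils.py) runs a callable with `GraphLowering.save_output_code` patched and returns the callable's result together with every generated source string. A minimal sketch of the pattern, with a toy function that is illustrative rather than taken from the test suite:

import torch
from torch._inductor.utils import run_and_get_code

def toy(x):
    return torch.relu(x) + 1

compiled = torch.compile(toy, backend="inductor")
# Returns (result, list of generated source strings); tests typically
# grep the strings for expected kernels or resize_ calls.
result, source_codes = run_and_get_code(compiled, torch.randn(8))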
diff --git a/test/fx/test_graph_pickler.py b/test/fx/test_graph_pickler.py
deleted file mode 100644
index 1cfd6a2ef576..000000000000
--- a/test/fx/test_graph_pickler.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Owner(s): ["module: fx"]
-
-#
-# Tests the graph pickler by using pickling on all the inductor tests.
-#
-
-import contextlib
-import importlib
-import os
-import sys
-from unittest.mock import patch
-
-import torch
-import torch.library
-from torch._dynamo.testing import make_test_cls_with_patches
-from torch._inductor.test_case import TestCase
-from torch.testing._internal.common_utils import TEST_WITH_ASAN
-from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
-
-
-# Make the helper files in test/ importable
-pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-sys.path.append(pytorch_test_dir)
-from inductor.test_torchinductor import (  # @manual=fbcode//caffe2/test/inductor:test_inductor-library
-    check_model,
-    CommonTemplate,
-    copy_tests,
-)
-
-
-importlib.import_module("filelock")
-
-# xfail by default, set is_skip=True to skip
-test_failures = {}
-
-
-def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
-    return make_test_cls_with_patches(
-        cls,
-        "GraphPickler",
-        "_graph_pickler",
-        (
-            torch._inductor.compile_fx,
-            "fx_compile_mode",
-            torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
-        ),
-        xfail_prop=xfail_prop,
-    )
-
-
-GraphPicklerCommonTemplate = make_test_cls(CommonTemplate)
-
-
-if HAS_CPU:
-
-    class GraphPicklerCpuTests(TestCase):
-        common = check_model
-        device = "cpu"
-
-    copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures)
-
-
-class TestGraphPickler(TestCase):
-    def setUp(self):
-        torch._dynamo.reset()
-        TestCase.setUp(self)
-
-        self._stack = contextlib.ExitStack()
-        self._stack.enter_context(
-            patch(
-                "torch._inductor.compile_fx.fx_compile_mode",
-                torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
-            )
-        )
-
-    def tearDown(self):
-        self._stack.close()
-        TestCase.tearDown(self)
-        torch._dynamo.reset()
-
-    def test_simple(self):
-        # Make sure that compiling works when we pass the input + output from
-        # fx_codegen_and_compile() through serde.
-
-        def fn(a, b):
-            return a + b
-
-        check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True])))
-
-
-if __name__ == "__main__":
-    from torch._inductor.test_case import run_tests
-
-    # Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
-    if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
-        run_tests(needs="filelock")
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 9969765fa8fc..64ff88001931 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -1211,10 +1211,6 @@ class TypingVariable(VariableTracker):

 @functools.lru_cache(maxsize=1)
 def get_np_to_tnp_map():
-    """
-    This generates a mapping from numpy modules to their torch._numpy
-    modules equivalents.
-    """
     from ..utils import NP_TO_TNP_MODULE

     np_fn_to_tnp_fn = {}
@@ -1230,16 +1226,6 @@ def get_np_to_tnp_map():
     return np_fn_to_tnp_fn


-@functools.lru_cache(maxsize=1)
-def get_tnp_to_np_map():
-    """
-    This is just the reverse mapping of get_np_to_tnp_map() - mapping from
-    torch._numpy modules to numpy equivalents.
-    """
-    m = get_np_to_tnp_map()
-    return {v: k for k, v in m.items()}
-
-
 class NumpyVariable(VariableTracker):
     """
     Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
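The deleted `get_tnp_to_np_map` above was simply the cached inverse of `get_np_to_tnp_map`. The general idiom, sketched here with stand-in data rather than the real numpy-to-torch._numpy tables:

import functools

@functools.lru_cache(maxsize=1)
def forward_map() -> dict[str, str]:
    # stand-in for the numpy -> torch._numpy mapping built above
    return {"numpy.add": "torch._numpy.add", "numpy.sum": "torch._numpy.sum"}

@functools.lru_cache(maxsize=1)
def reverse_map() -> dict[str, str]:
    # invert once and memoize; assumes the forward map is one-to-one,
    # otherwise the last duplicate value silently wins
    return {v: k for k, v in forward_map().items()}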
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index d5d390bef7f7..e3798478e323 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -1,25 +1,21 @@
 from __future__ import annotations

 import contextlib
-import enum
 import functools
 import io
 import itertools
 import json
 import logging
-import os
 import sys
 import time
 import warnings
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from inspect import currentframe
 from itertools import count
 from typing import (
     Any,
     Callable,
     ContextManager,
-    Mapping,
     Optional,
     TYPE_CHECKING,
     TypeVar,
@@ -60,18 +56,12 @@ from torch._functorch.aot_autograd import (
     make_boxed_func,
     SerializableAOTDispatchCompiler,
 )
-from torch._inductor.codecache import (
-    BypassFxGraphCache,
-    code_hash,
-    FxGraphCache,
-    output_code_log,
-)
+from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
 from torch._inductor.output_code import (
     CompiledAOTI,
     CompiledFxGraph,
-    CompiledFxGraphConstants,
     CompiledFxGraphConstantsWithGm,
     get_expanded_dims,
     index_expanded_dims,
@@ -156,43 +146,6 @@ if TYPE_CHECKING:
         GraphSignature,
     )

-
-# For testing - use the serde FxCompile scheme to debug serialization and
-# deserialization of GraphModule and CompiledFxGraph.
-class FxCompileMode(enum.Enum):
-    NORMAL = 0
-    # For testing - use the serde FxCompile scheme to debug serialization and
-    # deserialization of GraphModule and CompiledFxGraph.
-    SERIALIZE = 1
-
-
-def _fx_compile_mode_default() -> FxCompileMode:
-    name = "TORCHINDUCTOR_FX_COMPILE_MODE"
-    value = os.environ.get(name)
-    NORMAL = FxCompileMode.NORMAL
-    if value is None:
-        return NORMAL
-    try:
-        value = value.upper()
-        return FxCompileMode[value]
-    except KeyError:
-        import logging
-
-        log = logging.getLogger(__name__)
-        log.error(
-            "Invalid value of %s for %s. Expected one of %s. Using default.",
-            value,
-            name,
-            ", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())),
-        )
-        # Remove from the environment so subprocesses don't ALSO complain.
-        os.environ.pop(name)
-        return FxCompileMode.NORMAL
-
-
-fx_compile_mode = _fx_compile_mode_default()
-
-
 log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
@@ -798,11 +751,9 @@ def _compile_fx_inner(
             cache_event_time=start_time,
             key=cache_info.get("key") if cache_info else None,
             components=cache_info.get("components") if cache_info else None,
-            cache_bypass_reason=(
-                cache_info.get("cache_bypass_reason")
-                if cache_info
-                else "cache not enabled"
-            ),
+            cache_bypass_reason=cache_info.get("cache_bypass_reason")
+            if cache_info
+            else "cache not enabled",
             remote_cache_enabled=remote,
             local_cache_enabled=local,
         )
@@ -832,11 +783,6 @@ def _compile_fx_inner(


 class FxCompile(ABC):
-    """
-    An FxCompile represents a mechanism that can turn a GraphModule into an
-    OutputCode.
-    """
-
     # TODO: We should probably eventually add some kind of async version of this
     # so we can kick off a compile and then go do other things - but we'll need
     # to know what kind of API we want for that first.
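The removed `_fx_compile_mode_default` illustrates a reusable pattern: a case-insensitive enum lookup driven by an environment variable, where a bad value is popped from the environment so spawned subprocesses do not repeat the warning. A standalone sketch with illustrative names:

import enum
import os

class Mode(enum.Enum):
    NORMAL = 0
    SERIALIZE = 1

def mode_from_env(name: str = "EXAMPLE_COMPILE_MODE") -> Mode:
    value = os.environ.get(name)
    if value is None:
        return Mode.NORMAL
    try:
        # Enum[...] indexes by member name, so "serialize" -> Mode.SERIALIZE
        return Mode[value.upper()]
    except KeyError:
        # drop the bad value so child processes don't also complain
        os.environ.pop(name)
        return Mode.NORMAL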
@@ -1187,195 +1133,6 @@ class _InProcessFxCompile(FxCompile):
         )


-def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
-    fake_mode = None
-    if context := torch._guards.TracingContext.try_get():
-        fake_mode = context.fake_mode
-    if fake_mode is not None:
-        return fake_mode
-
-    shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
-    return torch._subclasses.FakeTensorMode(shape_env=shape_env)
-
-
-@dataclass
-class _WireProtocolInput:
-    """
-    For _SerializedFxCompile - encapsulates all the data being transferred
-    (sent) from the parent to the child.
-    """
-
-    gm: torch.fx.GraphModule
-    example_inputs: Sequence[InputType]
-    inputs_to_check: Sequence[int]
-    graph_kwargs: _CompileFxKwargs
-    # TODO: Add additional state to transfer to the child.
-
-    def serialize(self) -> _WireProtocolPickledInput:
-        """
-        Turns this object into a _WireProtocolPickledInput which can be
-        directly transferred across a stream.
-        """
-        from torch.fx._graph_pickler import GraphPickler
-
-        return _WireProtocolPickledInput(GraphPickler.dumps(self))
-
-
-@dataclass
-class _WireProtocolPickledInput:
-    value: bytes
-
-    def deserialize(self) -> _WireProtocolInput:
-        """
-        Turn this streamable object back into a _WireProtocolInput.
-        """
-        from torch.fx._graph_pickler import GraphPickler
-
-        fake_mode = _current_fake_mode()
-        result = GraphPickler.loads(self.value, fake_mode)
-        assert isinstance(result, _WireProtocolInput)
-        return result
-
-
-@dataclass
-class _WireProtocolOutput:
-    """
-    For _SerializedFxCompile - encapsulates all the data being transferred
-    (returned) back from the child to the parent.
-    """
-
-    graph: OutputCode
-
-    def serialize(self) -> _WireProtocolPickledOutput:
-        """
-        Turns this object into a _WireProtocolPickledOutput which can be
-        directly transferred across a stream.
-        """
-        from torch.fx._graph_pickler import GraphPickler
-
-        if isinstance(self.graph, CompiledFxGraph):
-            self.graph.prepare_for_serialization()
-        return _WireProtocolPickledOutput(GraphPickler.dumps(self))
-
-
-@dataclass
-class _WireProtocolPickledOutput:
-    value: bytes
-
-    def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
-        """
-        Turn this streamable object back into a _WireProtocolOutput.
-        """
-        from torch.fx._graph_pickler import GraphPickler
-
-        fake_mode = _current_fake_mode()
-        result = GraphPickler.loads(self.value, fake_mode)
-        assert isinstance(result, _WireProtocolOutput)
-        if isinstance(result.graph, CompiledFxGraph):
-            result.graph.after_deserialization(constants)
-        return result
-
-
-class _SerializedFxCompile(FxCompile):
-    """
-    This is used to represent an FxCompile which occurs across a serialized
-    boundary.
-    """
-
-    @override
-    def codegen_and_compile(
-        self,
-        gm: GraphModule,
-        example_inputs: Sequence[InputType],
-        inputs_to_check: Sequence[int],
-        graph_kwargs: _CompileFxKwargs,
-    ) -> OutputCode:
-        # _context = torch._guards.TracingContext.try_get()
-        constants = CompiledFxGraphConstantsWithGm(gm)
-
-        try:
-            input = _WireProtocolInput(
-                gm,
-                example_inputs,
-                inputs_to_check,
-                graph_kwargs,
-            ).serialize()
-        except (AttributeError, BypassFxGraphCache):
-            # For example: AttributeError: Can't pickle local object
-            # 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
-
-            # TODO: scuba record about not being able to do this?
- log.debug("Unable to pickle input graph or example inputs", exc_info=True) - - # Fallback to in-process - return _InProcessFxCompile().codegen_and_compile( - gm, example_inputs, inputs_to_check, graph_kwargs - ) - - output = self._send_to_child(input).deserialize(constants) - - self._postprocess(output) - - # TODO: Do we need to figure out what changed in TracingContext in the - # child and plumb that back up to the parent? - - return output.graph - - @abstractmethod - def _send_to_child( - self, pickled_input: _WireProtocolPickledInput - ) -> _WireProtocolPickledOutput: - # The implementation of this should transfer `input` to the child, call - # `_run_in_child(input)` and transfer the result back. - ... - - def _postprocess(self, output: _WireProtocolOutput) -> None: - pass - - @classmethod - def _run_in_child( - cls, - pickled_input: _WireProtocolPickledInput, - extra_env: Optional[Mapping[str, str]] = None, - ) -> _WireProtocolPickledOutput: - with contextlib.ExitStack() as stack: - if extra_env is not None: - import unittest - - stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env)) - - # TODO: Should we split the input into multiple sections where each - # section sets up state for the previous section? (i.e. a Config section - # which we decode and apply, followed by a FakeTensorMode section which - # we decode and apply, etc) - input = pickled_input.deserialize() - - stack.enter_context(DebugContext()) - - output_graph = _InProcessFxCompile().codegen_and_compile( - input.gm, - input.example_inputs, - input.inputs_to_check, - input.graph_kwargs, - ) - - return _WireProtocolOutput( - output_graph, - ).serialize() - - -# This is a debugging/testing implementation of FxCompile which serializes the -# input and output but still runs the FxCompile in-process. -class _DebugSerdeFxCompile(_SerializedFxCompile): - @override - def _send_to_child( - self, pickled_input: _WireProtocolPickledInput - ) -> _WireProtocolPickledOutput: - # For debugging just serde the input and output but don't run in a - # subprocess. 
-        return self._run_in_child(pickled_input)
-
-
 def fx_codegen_and_compile(
     gm: GraphModule,
     example_inputs: Sequence[InputType],
@@ -1384,13 +1141,7 @@ def fx_codegen_and_compile(
     inputs_to_check: Sequence[int],
     **graph_kwargs: Unpack[_CompileFxKwargs],
 ) -> OutputCode:
-    scheme: FxCompile
-    if fx_compile_mode == FxCompileMode.NORMAL:
-        scheme = _InProcessFxCompile()
-    elif fx_compile_mode == FxCompileMode.SERIALIZE:
-        scheme = _DebugSerdeFxCompile()
-    else:
-        raise NotImplementedError
+    scheme: FxCompile = _InProcessFxCompile()

     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)

@@ -1517,13 +1268,11 @@ def cudagraphify_impl(

     # allocate static tensor inputs
     static_inputs = [
-        (
-            x
-            if not isinstance(x, torch.Tensor)
-            else static_input(x)
-            if idx not in static_input_idxs
-            else x.detach()
-        )
+        x
+        if not isinstance(x, torch.Tensor)
+        else static_input(x)
+        if idx not in static_input_idxs
+        else x.detach()
         for idx, x in enumerate(inputs)
     ]

@@ -1755,11 +1504,9 @@ def fw_compiler_freezing(
 def get_cpp_wrapper_config() -> dict[str, object]:
     return {
         # Set autotune_at_compile_time to True as default if the option is not explicitly set
-        "triton.autotune_at_compile_time": (
-            config.triton.autotune_at_compile_time
-            if config.triton.autotune_at_compile_time is not None
-            else has_triton()
-        ),
+        "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
+        if config.triton.autotune_at_compile_time is not None
+        else has_triton(),
         "triton.autotune_cublasLt": False,
         "triton.cudagraphs": False,  # TODO: to be removed
         "triton.store_cubin": True,
@@ -2093,11 +1840,9 @@ def compile_fx(
                 model_outputs_node.meta["user_visible_output_idxs"] = []

         fixed = count_tangents(gm)
-        with (
-            config.patch(get_cpp_wrapper_config())
-            if config.cpp_wrapper
-            else contextlib.nullcontext()
-        ):
+        with config.patch(
+            get_cpp_wrapper_config()
+        ) if config.cpp_wrapper else contextlib.nullcontext():
             return inner_compile(
                 gm,
                 example_inputs,
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 4faf1803c9a6..a79a3e05a434 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1,5 +1,4 @@
 import contextlib
-import enum
 import functools
 import itertools
 import logging
@@ -274,12 +273,6 @@ def mark_nodes_dislike_padding(
             cur.meta["dislike_padding"] = True


-class SaveOutputCodeContext(enum.Enum):
-    BEFORE_COMPILE = 0
-    AFTER_DESERIALIZATION = 1
-    AFTER_COMPILE = 2
-
-
 class GraphLowering(torch.fx.Interpreter):
     graph_outputs: list[ir.IRNode]

@@ -1989,7 +1982,7 @@ class GraphLowering(torch.fx.Interpreter):
         return total_bytes, node_counts, node_runtimes

     @staticmethod
-    def save_output_code(code: str, context: SaveOutputCodeContext) -> None:
+    def save_output_code(code: str) -> None:
         # No-op to be patched for unit tests
         pass

@@ -2017,7 +2010,7 @@ class GraphLowering(torch.fx.Interpreter):
                 + '"""\n'
             )
             code = tuning_code + code
-        GraphLowering.save_output_code(code, SaveOutputCodeContext.BEFORE_COMPILE)
+        GraphLowering.save_output_code(code)
         output_code_log.debug("Output code: \n%s", code)

         inductor_meta = autotune_cache.inductor_meta_from_config()
diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py
index 41a195be0cf7..08ee0cd51e9e 100644
--- a/torch/_inductor/output_code.py
+++ b/torch/_inductor/output_code.py
@@ -460,15 +460,7 @@ class CompiledFxGraph(OutputCode):
     def __call__(self, inputs: Sequence[Any]) -> Any:
         assert self.current_callable is not None
         try:
-            result = self.current_callable(inputs)
-
-            from torch._inductor.graph import GraphLowering, SaveOutputCodeContext
-
-            GraphLowering.save_output_code(
-                self.source_code, SaveOutputCodeContext.AFTER_COMPILE
-            )
-
-            return result
+            return self.current_callable(inputs)
         finally:
             get_runtime_metrics_context().finish()
             AutotuneCacheBundler.end_compile()
@@ -557,12 +549,10 @@ class CompiledFxGraph(OutputCode):

         write_atomic(artifact_path, code, make_dirs=True)

-        from .graph import GraphLowering, SaveOutputCodeContext
+        from .graph import GraphLowering

         # This is used by tests to check the output for specific details.
-        GraphLowering.save_output_code(
-            code, SaveOutputCodeContext.AFTER_DESERIALIZATION
-        )
+        GraphLowering.save_output_code(code)

         try:
             with dynamo_timed(
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 6d8b2e38232c..bca55585450c 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -49,7 +49,6 @@ if TYPE_CHECKING:
     from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND

     from .codegen.common import WorkspaceArg
-    from .graph import SaveOutputCodeContext

 from torch.utils._ordered_set import OrderedSet
 from torch.utils._pytree import tree_map_only
@@ -1466,19 +1465,13 @@ class DebugDirManager:
         torch._dynamo.config.debug_dir_root = self.prev_debug_name


-def _run_and_get_code_for_context(
-    fn: Callable[P, _T],
-    for_context: SaveOutputCodeContext,
-    args: P.args,
-    kwargs: P.kwargs,
-) -> tuple[_T, list[str]]:
+def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
     from .graph import GraphLowering

     source_codes: list[str] = []

-    def save_output_code(code: str, context: SaveOutputCodeContext):
-        if context == for_context:
-            source_codes.append(code)
+    def save_output_code(code: str):
+        source_codes.append(code)

     with mock.patch.object(GraphLowering, "save_output_code", save_output_code):
         torch._dynamo.reset()
@@ -1486,22 +1479,6 @@ def _run_and_get_code_for_context(
     return result, source_codes


-def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
-    from .graph import SaveOutputCodeContext
-
-    return _run_and_get_code_for_context(
-        fn, SaveOutputCodeContext.AFTER_COMPILE, args, kwargs
-    )
-
-
-def run_and_get_code_before_compile(fn, *args, **kwargs) -> tuple[Any, list[str]]:
-    from .graph import SaveOutputCodeContext
-
-    return _run_and_get_code_for_context(
-        fn, SaveOutputCodeContext.BEFORE_COMPILE, args, kwargs
-    )
-
-
 def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, list[str]]:
     result, source_codes = run_and_get_code(fn, *args, **kwargs)
     kernels = []
@@ -1521,13 +1498,12 @@ def run_fw_bw_and_get_code(fn):

 def get_code(fn, *args, **kwargs):
     """Get the inductor-generated code, but skip any actual compilation or running."""
-    from .graph import GraphLowering, SaveOutputCodeContext
+    from .graph import GraphLowering

     source_codes: list[str] = []

-    def save_output_code(code: str, context: SaveOutputCodeContext):
-        if context == SaveOutputCodeContext.AFTER_COMPILE:
-            source_codes.append(code)
+    def save_output_code(code: str):
+        source_codes.append(code)

     def patched_compile_to_module(self: GraphLowering):
         class DummyModule:
@@ -1545,7 +1521,7 @@ def get_code(fn, *args, **kwargs):
             )
             # Skip all the actual compiling.
             nonlocal save_output_code
-            save_output_code(code, SaveOutputCodeContext.BEFORE_COMPILE)
+            save_output_code(code)

             return DummyModule()

diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 7890d8d69dd5..6456c0f35d6a 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -379,7 +379,7 @@ class FakeTensorConverter:
         out = self.meta_converter(
             t,
             shape_env=shape_env,
-            callback=mk_fake_tensor,
+            callback=mk_fake_tensor,  # type: ignore[arg-type]
             source=source,
             symbolic_context=symbolic_context,
             trace=trace,
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 3ffbaa96ec56..d0d1905ae6ea 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -541,27 +541,11 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
         return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)


-# A callback where the device is either optional or required.
-# All of these satisfy this protocol:
-#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
-#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
-#   def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
-class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
-    def __call__(
-        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
-    ) -> _TensorT_cov:
-        ...
-
-
 class _MetaTensorCallbackKwargs(TypedDict, total=False):
     device: Union[torch.device, str]


-# A callback where the device may not be provided (is optional).
-# All of these satisfy this protocol:
-#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
-#   def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
-class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
+class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
     def __call__(
         self,
         arg: Callable[[], torch.Tensor],
@@ -848,13 +832,11 @@ class MetaConverter(Generic[_TensorT]):
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback_: _MetaTensorCallback[_TensorT],
+        callback: _MetaTensorCallback[_TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
-        callback: _MetaTensorCallbackOptDevice = functools.partial(
-            callback_, device=t.device
-        )
+        callback = functools.partial(callback, device=t.device)

         if source is None:
             from torch._dynamo.source import ConstantSource
@@ -999,7 +981,7 @@ class MetaConverter(Generic[_TensorT]):
         symbolic_context: Optional[
             torch.fx.experimental.symbolic_shapes.SymbolicContext
         ],
-        callback: _MetaTensorCallbackOptDevice[_TensorT],
+        callback: _MetaTensorCallback[_TensorT],
         source: torch._guards.Source,
     ) -> _TensorT:
         # We are hitting plain meta_desc tensor so actually
@@ -1234,7 +1216,7 @@ class MetaConverter(Generic[_TensorT]):
                 shape_env: Optional[
                     torch.fx.experimental.symbolic_shapes.ShapeEnv
                 ] = shape_env,
-                callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
+                callback: _MetaTensorCallback[_TensorT] = callback,
             ) -> torch.Tensor:
                 # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                 if visited_t is None:
@@ -1787,9 +1769,7 @@ class MetaConverter(Generic[_TensorT]):
         # Thanks to storage resizing, it's possible to end up with a tensor
         # that advertises a real size, but has a storage that actually has zero bytes.
         # Need to reflect this in the generated FakeTensor.
-        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
-
-        if t.storage is not None and guard_size_oblivious(t.storage.size == 0):
+        if t.storage is not None and t.storage.size == 0:
             r.untyped_storage().resize_(0)

         if t.is_parameter:
diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py
deleted file mode 100644
index 5a5c77cd5881..000000000000
--- a/torch/fx/_graph_pickler.py
+++ /dev/null
@@ -1,582 +0,0 @@
-import dataclasses
-import importlib
-import io
-import pickle
-from abc import abstractmethod
-from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union
-from typing_extensions import override, Self
-
-import torch
-import torch.utils._pytree as pytree
-from torch._guards import TracingContext
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
-from torch._subclasses.meta_utils import (
-    MetaConverter,
-    MetaTensorDesc,
-    MetaTensorDescriber,
-)
-from torch.fx.experimental.sym_node import SymNode
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
-from torch.utils._mode_utils import no_dispatch
-
-
-_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
-
-
-class GraphPickler(pickle.Pickler):
-    """
-    GraphPickler is a Pickler which helps pickling fx graph - in particular
-    GraphModule.
-    """
-
-    def __init__(self, file: io.BytesIO) -> None:
-        super().__init__(file)
-
-        # This abomination is so we can pass external decoding state to the
-        # unpickler functions. We serialize _unpickle_state as a persistent
-        # external item and when we deserialize it we return the common state
-        # object.
-        self._unpickle_state = _UnpickleStateToken(object())
-
-        # This is used to describe tensors. It needs to be common across the
-        # pickle so that duplicates and views are properly handled.
-        self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
-
-    @override
-    def reducer_override(
-        self, obj: object
-    ) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
-        # This function is supposed to return either NotImplemented (meaning to
-        # do the default pickle behavior) or a pair of (unpickle callable, data
-        # to pass to unpickle).
-
-        # We could instead teach individual classes how to pickle themselves but
-        # that has a few problems:
-        #
-        # 1. If we have some special needs (maybe for this use-case we don't
-        #    want to fully serialize every field) then we're adding private
-        #    details to a public interface.
-        #
-        # 2. If we need to have some common shared data (such as a
-        #    FakeTensorMode) which is passed to each value it's harder to
-        #    support.
-
-        # These are the types that need special handling. See the individual
-        # *PickleData classes for details on pickling that particular type.
-        if isinstance(obj, FakeTensor):
-            return _TensorPickleData.reduce_helper(self, obj)
-        elif isinstance(obj, torch.fx.GraphModule):
-            return _GraphModulePickleData.reduce_helper(self, obj)
-        elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
-            return _OpPickleData.reduce_helper(self, obj)
-        elif isinstance(obj, ShapeEnv):
-            return _ShapeEnvPickleData.reduce_helper(self, obj)
-        elif isinstance(obj, torch.SymInt):
-            return _SymNodePickleData.reduce_helper(self, obj)
-        elif isinstance(obj, torch._guards.TracingContext):
-            return _TracingContextPickleData.reduce_helper(self, obj)
-        else:
-            # We should never get a raw Node!
-            assert not isinstance(obj, torch.fx.Node)
-
-            if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
-                return reduce
-
-            # returning `NotImplemented` causes pickle to revert to the default
-            # behavior for this object.
-            return NotImplemented
-
-    @override
-    def persistent_id(self, obj: object) -> Optional[str]:
-        if obj is self._unpickle_state:
-            return "unpickle_state"
-        else:
-            return None
-
-    @classmethod
-    def dumps(cls, obj: object) -> bytes:
-        """
-        Pickle an object.
-        """
-        with io.BytesIO() as stream:
-            pickler = cls(stream)
-            pickler.dump(obj)
-            return stream.getvalue()
-
-    @staticmethod
-    def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
-        """
-        Unpickle an object.
-        """
-        state = _UnpickleState(fake_mode)
-        with io.BytesIO(data) as stream:
-            unpickler = _GraphUnpickler(stream, state)
-            return unpickler.load()
-
-
-class _UnpickleState:
-    def __init__(self, fake_mode: FakeTensorMode) -> None:
-        self.fake_mode = fake_mode
-        self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
-
-
-# This token is passed when pickling to indicate that we want to use the
-# unpickler's _UnpickleState as a parameter in that position.
-_UnpickleStateToken = NewType("_UnpickleStateToken", object)
-
-
-class _GraphUnpickler(pickle.Unpickler):
-    def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
-        super().__init__(stream)
-        self._unpickle_state = unpickle_state
-
-    @override
-    def persistent_load(self, pid: object) -> object:
-        if pid == "unpickle_state":
-            return self._unpickle_state
-        else:
-            raise pickle.UnpicklingError("Invalid persistent ID")
-
-
-class _ShapeEnvPickleData:
-    data: Dict[str, object]
-
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, obj: ShapeEnv
-    ) -> Tuple[
-        Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken]
-    ]:
-        return cls.unpickle, (cls(obj), pickler._unpickle_state)
-
-    def __init__(self, env: ShapeEnv) -> None:
-        # In theory pickle should recognize that a given ShapeEnv was already
-        # pickled and reuse the resulting _ShapeEnvPickleData (so two objects
-        # pointing at the same ShapeEnv get the same ShapeEnv out).
-        assert not env._translation_validation_enabled
-        self.data = env.__dict__.copy()
-        del self.data["tracked_fakes"]
-        del self.data["fake_tensor_cache"]
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
-        # Fill in the existing ShapeEnv rather than creating a new one
-        assert unpickle_state.fake_mode
-        assert unpickle_state.fake_mode.shape_env
-
-        for k, v in self.data.items():
-            setattr(unpickle_state.fake_mode.shape_env, k, v)
-
-        return unpickle_state.fake_mode.shape_env
-
-
-class _SymNodePickleData:
-    @classmethod
-    def reduce_helper(
-        cls,
-        pickler: GraphPickler,
-        obj: _SymNodeT,
-    ) -> Tuple[
-        Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken]
-    ]:
-        args = (cls(obj.node), pickler._unpickle_state)
-        if isinstance(obj, torch.SymInt):
-            return _SymNodePickleData.unpickle_sym_int, args
-        else:
-            raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
-
-    def __init__(self, node: SymNode) -> None:
-        self.expr = node._expr
-        self.shape_env = node.shape_env
-        self.pytype = node.pytype
-        self.hint = node._hint
-
-    def _to_sym_node(self) -> SymNode:
-        from torch.fx.experimental.sym_node import SymNode
-
-        assert self.shape_env is not None
-        return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
-
-    def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
-        return torch.SymInt(self._to_sym_node())
-
-
-class _TensorPickleData:
-    metadata: MetaTensorDesc[FakeTensor]
-
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, obj: FakeTensor
-    ) -> Tuple[
-        Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken]
-    ]:
-        return cls.unpickle, (
-            cls(pickler._meta_tensor_describer, obj),
-            pickler._unpickle_state,
-        )
-
-    def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
-        # THINGS TO WORRY ABOUT:
-        # 1. Need to make sure that two tensors with the same id end up with the
-        #    same id on the other side of the wire.
-
-        metadata = describer.describe_tensor(t)
-
-        # view_func is fine if it's either None or a _FakeTensorViewFunc. A
-        # custom one (which is basically a lambda) can't be serialized.
-        assert not metadata.view_func or isinstance(
-            metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
-        )
-        self.metadata = dataclasses.replace(metadata, fake_mode=None)
-
-        # Some debugging/verification
-        for k in MetaTensorDesc._UNSERIALIZABLE:
-            if k in ("fake_mode", "view_func"):
-                continue
-            assert (
-                getattr(self.metadata, k) is None
-            ), f"not None: {k}: {getattr(self.metadata, k)}"
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
-        # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
-        metadata = dataclasses.replace(
-            self.metadata,
-            fake_mode=unpickle_state.fake_mode,
-        )
-
-        def with_fake(
-            make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
-        ) -> FakeTensor:
-            with no_dispatch():
-                return FakeTensor(
-                    unpickle_state.fake_mode,
-                    make_meta_t(),
-                    device,
-                )
-
-        return unpickle_state.meta_converter.meta_tensor(
-            metadata,
-            unpickle_state.fake_mode.shape_env,
-            with_fake,
-            None,
-            None,
-        )
-
-
-class _TorchNumpyPickleData:
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, obj: object
-    ) -> Optional[
-        Tuple[
-            Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken]
-        ]
-    ]:
-        if data := cls.from_object(obj):
-            return (cls.unpickle, (data, pickler._unpickle_state))
-        else:
-            return None
-
-    def __init__(self, mod: str, name: str) -> None:
-        self.mod = mod
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
-        np = getattr(importlib.import_module(self.mod), self.name)
-        return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
-
-    @classmethod
-    def from_object(cls, tnp: object) -> Optional[Self]:
-        if not callable(tnp):
-            return None
-
-        tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
-        try:
-            if not (np := tnp_to_np.get(tnp)):
-                return None
-        except TypeError:
-            return None
-
-        if not (mod := getattr(np, "__module__", None)):
-            mod = "numpy"
-
-        if not (name := getattr(np, "__name__", None)):
-            return None
-
-        assert np == getattr(importlib.import_module(mod), name)
-        return cls(mod, name)
-
-
-class _GraphModulePickleData:
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, obj: torch.fx.GraphModule
-    ) -> Tuple[
-        Callable[[Self, _UnpickleState], torch.fx.GraphModule],
-        Tuple[Self, _UnpickleStateToken],
-    ]:
-        return cls.unpickle, (
-            cls(obj),
-            pickler._unpickle_state,
-        )
-
-    def __init__(self, gm: torch.fx.GraphModule) -> None:
-        # Need to do this to ensure the code is created for later pickling.
-        if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
-            _python_code = gm._real_recompile()
-        else:
-            _python_code = gm.recompile()
-        self.gm_dict = gm.__dict__.copy()
-        del self.gm_dict["_graph"]
-        self.graph = _GraphPickleData(gm._graph)
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
-        gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
-        gm.__dict__ = self.gm_dict
-        gm._graph = self.graph.unpickle(gm, unpickle_state)
-        return gm
-
-
-class _NodePickleData:
-    def __init__(
-        self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"]
-    ) -> None:
-        self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
-        self.kwargs = pytree.tree_map_only(
-            torch.fx.Node, lambda n: mapping[n], node.kwargs
-        )
-        # -- self.graph = node.graph
-        self.name = node.name
-        self.op = node.op
-        self.target = _OpPickleData.pickle(node.target)
-        # self.input_nodes = node._input_nodes
-        # self.users = node.users
-        self.type = node.type
-        # self.sort_key = node._sort_key
-        # self.repr_fn = node._repr_fn
-        # self.meta = node.meta
-        self.meta = node.meta
-
-    def unpickle(
-        self,
-        graph: torch.fx.Graph,
-        mapping: Dict["_NodePickleData", torch.fx.Node],
-        unpickle_state: _UnpickleState,
-    ) -> torch.fx.Node:
-        args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
-        kwargs = pytree.tree_map_only(
-            _NodePickleData, lambda n: mapping[n], self.kwargs
-        )
-        target = self.target.unpickle(unpickle_state)
-        assert callable(target) or isinstance(target, str)
-        node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
-        node.meta = self.meta
-        return node
-
-
-class _OpPickleData:
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, op: object
-    ) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]:
-        result = cls.pickle(op)
-        return (result.unpickle, (pickler._unpickle_state,))
-
-    @classmethod
-    def pickle(cls, op: object) -> "_OpPickleData":
-        if isinstance(op, str):
-            return _OpStrPickleData(op)
-
-        name = torch.fx.Node._pretty_print_target(op)
-        if isinstance(op, torch._ops.OpOverload):
-            return cls._pickle_op(name, _OpOverloadPickleData)
-        elif isinstance(op, torch._ops.OpOverloadPacket):
-            return cls._pickle_op(name, _OpOverloadPacketPickleData)
-        elif name.startswith(("builtins.", "math.", "torch.")):
-            root, detail = name.split(".", 1)
-            return _OpBuiltinPickleData(root, detail)
-        elif name.startswith("operator."):
-            _, detail = name.split(".", 1)
-            return _OpOperatorPickleData(detail)
-        else:
-            # TODO: raise a BypassFxGraphCache so we will just bypass this one...
-            raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
-
-    @staticmethod
-    def _pickle_op(
-        name: str,
-        datacls: Union[
-            Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"]
-        ],
-    ) -> "_OpPickleData":
-        if not name.startswith("torch.ops.aten"):  # TODO: What's the full list?
-            from torch._inductor.codecache import BypassFxGraphCache
-
-            raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
-        return datacls(name)
-
-    @abstractmethod
-    def unpickle(self, unpickle_state: _UnpickleState) -> object:
-        pass
-
-    @classmethod
-    def _lookup_global_by_name(cls, name: str) -> object:
-        """
-        Like `globals()[name]` but supports dotted names.
-        """
-        if "." in name:
-            mod, rest = name.split(".", 1)
-            root = globals()[mod]
-            return cls._getattr_by_name(root, rest)
-        else:
-            return globals()[name]
-
-    @staticmethod
-    def _getattr_by_name(root: object, name: str) -> object:
-        """
-        Like `getattr(root, name)` but supports dotted names.
-        """
-        while "." in name:
-            mod, name = name.split(".", 1)
-            root = getattr(root, mod)
-        return getattr(root, name)
-
-
-class _OpStrPickleData(_OpPickleData):
-    def __init__(self, name: str) -> None:
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> str:
-        return self.name
-
-
-class _OpOverloadPickleData(_OpPickleData):
-    def __init__(self, name: str) -> None:
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
-        obj = self._lookup_global_by_name(self.name)
-        assert isinstance(obj, torch._ops.OpOverload)
-        return obj
-
-
-class _OpOverloadPacketPickleData(_OpPickleData):
-    def __init__(self, name: str) -> None:
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
-        obj = self._lookup_global_by_name(self.name)
-        assert isinstance(obj, torch._ops.OpOverloadPacket)
-        return obj
-
-
-class _OpBuiltinPickleData(_OpPickleData):
-    def __init__(self, root: str, name: str) -> None:
-        self.root = root
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> object:
-        if self.root == "builtins":
-            return __builtins__.get(self.name)  # type: ignore[attr-defined]
-        elif self.root == "math":
-            import math
-
-            return self._getattr_by_name(math, self.name)
-        elif self.root == "torch":
-            return self._getattr_by_name(torch, self.name)
-        else:
-            raise NotImplementedError
-
-
-class _OpOperatorPickleData(_OpPickleData):
-    def __init__(self, name: str) -> None:
-        self.name = name
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> object:
-        import operator
-
-        return self._getattr_by_name(operator, self.name)
-
-
-class _GraphPickleData:
-    def __init__(self, graph: torch.fx.Graph) -> None:
-        self.tracer_cls = graph._tracer_cls
-        self.tracer_extras = graph._tracer_extras
-
-        nodes: Dict[torch.fx.Node, _NodePickleData] = {}
-        for node in graph.nodes:
-            nodes[node] = _NodePickleData(node, nodes)
-        self.nodes = tuple(nodes.values())
-
-        # Unpickled variables:
-        # self._used_names = graph._used_names
-        # -- self._insert = self._root.prepend
-        # self._len = graph._len
-        # self._graph_namespace = graph._graph_namespace
-        # self._owning_module = graph._owning_module
-        # self._codegen = graph._codegen
-        # self._co_fields: Dict[str, Any] = graph._co_fields
-        # -- self._find_nodes_lookup_table = _FindNodesLookupTable()
-
-    def unpickle(
-        self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
-    ) -> torch.fx.Graph:
-        graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
-
-        nodes: Dict[_NodePickleData, torch.fx.Node] = {}
-        for nd in self.nodes:
-            nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
-
-        return graph
-
-
-class _TracingContextPickleData:
-    @classmethod
-    def reduce_helper(
-        cls, pickler: GraphPickler, obj: torch._guards.TracingContext
-    ) -> Tuple[
-        Callable[[Self, _UnpickleState], torch._guards.TracingContext],
-        Tuple[Self, _UnpickleStateToken],
-    ]:
-        return (
-            cls.unpickle,
-            (
-                cls(obj),
-                pickler._unpickle_state,
-            ),
-        )
-
-    def __init__(self, context: TracingContext) -> None:
-        # TODO: Do we really need all of this?
-        self.module_context = context.module_context
-        self.frame_summary_stack = context.frame_summary_stack
-        self.loc_in_frame = context.loc_in_frame
-        self.aot_graph_name = context.aot_graph_name
-        self.params_flat = context.params_flat
-        self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
-        self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
-        self.output_strides = context.output_strides
-        self.force_unspec_int_unbacked_size_like = (
-            context.force_unspec_int_unbacked_size_like
-        )
-        # Not saved (because it's difficult and maybe not needed?):
-        # self.fw_metadata = context.fw_metadata
-        # self.guards_context = None
-        # self.global_context = None
-        # self.fake_mode = None
-        # self.fakify_first_call = None
-        # self.hop_dispatch_set_cache = None
-        # self.tensor_to_context = context.tensor_to_context
-
-    def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
-        context = TracingContext(unpickle_state.fake_mode)
-        context.module_context = self.module_context
-        context.frame_summary_stack = self.frame_summary_stack
-        context.loc_in_frame = self.loc_in_frame
-        context.aot_graph_name = self.aot_graph_name
-        context.params_flat = self.params_flat
-        context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
-        context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
-        context.output_strides = self.output_strides
-        context.force_unspec_int_unbacked_size_like = (
-            self.force_unspec_int_unbacked_size_like
-        )
-        return context
diff --git a/torch/fx/node.py b/torch/fx/node.py
index d017f3a817b5..184e9dee0123 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -598,8 +598,7 @@ class Node(_NodeBase):
             return self._repr_fn(self)
         return self.name

-    @staticmethod
-    def _pretty_print_target(target: object) -> str:
+    def _pretty_print_target(self, target: object) -> str:
         """
         Make target printouts more user-friendly.
             1) builtins will be printed as `builtins.xyz`
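The final hunk restores `Node._pretty_print_target` to an instance method, so the staticmethod calling convention the deleted pickler relied on no longer applies. Roughly, with hypothetical call sites:

# before the revert (staticmethod): usable without a Node instance
name = torch.fx.Node._pretty_print_target(some_target)

# after the revert (instance method): called on a Node, target passed explicitly
name = node._pretty_print_target(node.target)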