pickler for GraphModule (#141659)

Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
This commit is contained in:
Aaron Orenstein
2025-01-30 14:05:26 -08:00
committed by PyTorch MergeBot
parent f9227e7c33
commit 57d8278ab9
13 changed files with 1014 additions and 41 deletions

View File

@ -751,7 +751,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
"inductor",
fwd_fullgraph=fwd_fullgraph,
bwd_resize_count_before_inductor=48 if fwd_fullgraph else None,
)
),
)
if fwd_fullgraph:
self.assertEqual(
@ -834,7 +834,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
*self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False),
"inductor",
fwd_fullgraph=False,
)
),
)
# TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
# there are several recompiles, need to figure out why.
@ -978,7 +978,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
"inductor",
fwd_fullgraph=fwd_fullgraph,
bwd_resize_count_before_inductor=76 if fwd_fullgraph else None,
)
),
)
if fwd_fullgraph:
self.assertEqual(
@ -1071,7 +1071,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
),
"inductor",
fwd_fullgraph=fwd_fullgraph,
)
),
)
# TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
# there are several recompiles, need to figure out why.

View File

@ -0,0 +1,96 @@
# Owner(s): ["module: fx"]
#
# Tests the graph pickler by using pickling on all the inductor tests.
#
import contextlib
import importlib
import os
import sys
from unittest.mock import patch
import torch
import torch.library
from torch._dynamo.testing import make_test_cls_with_patches
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
check_model,
CommonTemplate,
copy_tests,
)
importlib.import_module("filelock")
# xfail by default, set is_skip=True to skip
test_failures = {}
def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
return make_test_cls_with_patches(
cls,
"GraphPickler",
"_graph_pickler",
(
torch._inductor.compile_fx,
"fx_compile_mode",
torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
),
xfail_prop=xfail_prop,
)
GraphPicklerCommonTemplate = make_test_cls(CommonTemplate)
if HAS_CPU:
class GraphPicklerCpuTests(TestCase):
common = check_model
device = "cpu"
copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures)
class TestGraphPickler(TestCase):
def setUp(self):
torch._dynamo.reset()
TestCase.setUp(self)
self._stack = contextlib.ExitStack()
self._stack.enter_context(
patch(
"torch._inductor.compile_fx.fx_compile_mode",
torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
)
)
def tearDown(self):
self._stack.close()
TestCase.tearDown(self)
torch._dynamo.reset()
def test_simple(self):
# Make sure that compiling works when we pass the input + output from
# fx_codegen_and_compile() through serde.
def fn(a, b):
return a + b
check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True])))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
# Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
run_tests(needs="filelock")

View File

@ -12370,7 +12370,7 @@ def copy_tests(
new_test = unittest.expectedFailure(new_test)
tf = test_failures and test_failures.get(name)
if tf is not None and suffix in tf.suffixes:
if tf and suffix in tf.suffixes:
skip_func = (
unittest.skip("Skipped!")
if tf.is_skip

View File

@ -1237,6 +1237,10 @@ class TypingVariable(VariableTracker):
@functools.lru_cache(maxsize=1)
def get_np_to_tnp_map():
"""
This generates a mapping from numpy modules to their torch._numpy
modules equivalents.
"""
from ..utils import NP_TO_TNP_MODULE
np_fn_to_tnp_fn = {}
@ -1252,6 +1256,16 @@ def get_np_to_tnp_map():
return np_fn_to_tnp_fn
@functools.lru_cache(maxsize=1)
def get_tnp_to_np_map():
"""
This is just the reverse mapping of get_np_to_tnp_map() - mapping from
torch._numpy modules to numpy equivalents.
"""
m = get_np_to_tnp_map()
return {v: k for k, v in m.items()}
class NumpyVariable(VariableTracker):
"""
Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.

View File

@ -1053,6 +1053,13 @@ class FxGraphCache:
try:
artifact_path = graph.after_deserialization(constants)
from .graph import GraphLowering
# This is used by tests to check the output for specific details.
if GraphLowering.save_output_code is not None:
GraphLowering.save_output_code(graph.source_code)
except OSError:
# Not expected, but in case the PyCodeCache entry is removed from
# underneath us, treat it as a cache miss and recompile.

View File

@ -1,21 +1,25 @@
from __future__ import annotations
import contextlib
import enum
import functools
import io
import itertools
import json
import logging
import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from inspect import currentframe
from itertools import count
from typing import (
Any,
Callable,
ContextManager,
Mapping,
Optional,
TYPE_CHECKING,
TypeVar,
@ -56,12 +60,18 @@ from torch._functorch.aot_autograd import (
make_boxed_func,
SerializableAOTDispatchCompiler,
)
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
from torch._inductor.codecache import (
BypassFxGraphCache,
code_hash,
FxGraphCache,
output_code_log,
)
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
from torch._inductor.debug import save_args_for_compile_fx_inner
from torch._inductor.output_code import (
CompiledAOTI,
CompiledFxGraph,
CompiledFxGraphConstants,
CompiledFxGraphConstantsWithGm,
get_expanded_dims,
index_expanded_dims,
@ -146,6 +156,43 @@ if TYPE_CHECKING:
GraphSignature,
)
# For testing - use the serde FxCompile scheme to debug serialization and
# deserialization of GraphMoule and CompiledFxGraph.
class FxCompileMode(enum.Enum):
NORMAL = 0
# For testing - use the serde FxCompile scheme to debug serialization and
# deserialization of GraphMoule and CompiledFxGraph.
SERIALIZE = 1
def _fx_compile_mode_default() -> FxCompileMode:
name = "TORCHINDUCTOR_FX_COMPILE_MODE"
value = os.environ.get(name)
NORMAL = FxCompileMode.NORMAL
if value is None:
return NORMAL
try:
value = value.upper()
return FxCompileMode[value]
except KeyError:
import logging
log = logging.getLogger(__name__)
log.error(
"Invalid value of %s for %s. Expected one of %s. Using default.",
value,
name,
", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())),
)
# Remove from the environment so subprocesses don't ALSO complain.
os.environ.pop(name)
return FxCompileMode.NORMAL
fx_compile_mode = _fx_compile_mode_default()
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
@ -755,9 +802,11 @@ def _compile_fx_inner(
cache_event_time=start_time,
key=cache_info.get("key") if cache_info else None,
components=cache_info.get("components") if cache_info else None,
cache_bypass_reason=cache_info.get("cache_bypass_reason")
if cache_info
else "cache not enabled",
cache_bypass_reason=(
cache_info.get("cache_bypass_reason")
if cache_info
else "cache not enabled"
),
remote_cache_enabled=remote,
local_cache_enabled=local,
)
@ -787,6 +836,11 @@ def _compile_fx_inner(
class FxCompile(ABC):
"""
An FxCompile represents a mechanism that can turn a GraphModule into an
OutputCode.
"""
# TODO: We should probably eventually add some kind of async version of this
# so we can kick off a compile and then go do other things - but we'll need
# to know what kind of API we want for that first.
@ -1137,6 +1191,195 @@ class _InProcessFxCompile(FxCompile):
)
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
fake_mode = None
if context := torch._guards.TracingContext.try_get():
fake_mode = context.fake_mode
if fake_mode is not None:
return fake_mode
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
@dataclass
class _WireProtocolInput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(sent) from the parent to the child.
"""
gm: torch.fx.GraphModule
example_inputs: Sequence[InputType]
inputs_to_check: Sequence[int]
graph_kwargs: _CompileFxKwargs
# TODO: Add additional state to transfer to the child.
def serialize(self) -> _WireProtocolPickledInput:
"""
Turns this object into a _WireProtocolPickledInput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
return _WireProtocolPickledInput(GraphPickler.dumps(self))
@dataclass
class _WireProtocolPickledInput:
value: bytes
def deserialize(self) -> _WireProtocolInput:
"""
Turn this streamable object back into a _WireProtocolInput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolInput)
return result
@dataclass
class _WireProtocolOutput:
"""
For _SerializedFxCompile - encapsulates all the data being transferred
(returned) back from the child to the parent.
"""
graph: OutputCode
def serialize(self) -> _WireProtocolPickledOutput:
"""
Turns this object into a _WireProtocolPickledOutput which can be
directly transferred across a stream.
"""
from torch.fx._graph_pickler import GraphPickler
if isinstance(self.graph, CompiledFxGraph):
self.graph.prepare_for_serialization()
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
@dataclass
class _WireProtocolPickledOutput:
value: bytes
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
"""
Turn this streamable object back into a _WireProtocolOutput.
"""
from torch.fx._graph_pickler import GraphPickler
fake_mode = _current_fake_mode()
result = GraphPickler.loads(self.value, fake_mode)
assert isinstance(result, _WireProtocolOutput)
if isinstance(result.graph, CompiledFxGraph):
result.graph.after_deserialization(constants)
return result
class _SerializedFxCompile(FxCompile):
"""
This is used to represent an FxCompile which occurs across a serialized
boundary.
"""
@override
def codegen_and_compile(
self,
gm: GraphModule,
example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs,
) -> OutputCode:
# _context = torch._guards.TracingContext.try_get()
constants = CompiledFxGraphConstantsWithGm(gm)
try:
input = _WireProtocolInput(
gm,
example_inputs,
inputs_to_check,
graph_kwargs,
).serialize()
except (AttributeError, BypassFxGraphCache):
# For example: AttributeError: Can't pickle local object
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
# TODO: scuba record about not being able to do this?
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
# Fallback to in-process
return _InProcessFxCompile().codegen_and_compile(
gm, example_inputs, inputs_to_check, graph_kwargs
)
output = self._send_to_child(input).deserialize(constants)
self._postprocess(output)
# TODO: Do we need to figure out what changed in TracingContext in the
# child and plumb that back up to the parent?
return output.graph
@abstractmethod
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# The implementation of this should transfer `input` to the child, call
# `_run_in_child(input)` and transfer the result back.
...
def _postprocess(self, output: _WireProtocolOutput) -> None:
pass
@classmethod
def _run_in_child(
cls,
pickled_input: _WireProtocolPickledInput,
extra_env: Optional[Mapping[str, str]] = None,
) -> _WireProtocolPickledOutput:
with contextlib.ExitStack() as stack:
if extra_env is not None:
import unittest
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
# TODO: Should we split the input into multiple sections where each
# section sets up state for the previous section? (i.e. a Config section
# which we decode and apply, followed by a FakeTensorMode section which
# we decode and apply, etc)
input = pickled_input.deserialize()
stack.enter_context(DebugContext())
output_graph = _InProcessFxCompile().codegen_and_compile(
input.gm,
input.example_inputs,
input.inputs_to_check,
input.graph_kwargs,
)
return _WireProtocolOutput(
output_graph,
).serialize()
# This is a debugging/testing implementation of FxCompile which serializes the
# input and output but still runs the FxCompile in-process.
class _DebugSerdeFxCompile(_SerializedFxCompile):
@override
def _send_to_child(
self, pickled_input: _WireProtocolPickledInput
) -> _WireProtocolPickledOutput:
# For debugging just serde the input and output but don't run in a
# subprocess.
return self._run_in_child(pickled_input)
def fx_codegen_and_compile(
gm: GraphModule,
example_inputs: Sequence[InputType],
@ -1145,7 +1388,13 @@ def fx_codegen_and_compile(
inputs_to_check: Sequence[int],
**graph_kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
scheme: FxCompile = _InProcessFxCompile()
scheme: FxCompile
if fx_compile_mode == FxCompileMode.NORMAL:
scheme = _InProcessFxCompile()
elif fx_compile_mode == FxCompileMode.SERIALIZE:
scheme = _DebugSerdeFxCompile()
else:
raise NotImplementedError
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
@ -1272,11 +1521,13 @@ def cudagraphify_impl(
# allocate static tensor inputs
static_inputs = [
x
if not isinstance(x, torch.Tensor)
else static_input(x)
if idx not in static_input_idxs
else x.detach()
(
x
if not isinstance(x, torch.Tensor)
else static_input(x)
if idx not in static_input_idxs
else x.detach()
)
for idx, x in enumerate(inputs)
]
@ -1506,9 +1757,11 @@ def fw_compiler_freezing(
def get_cpp_wrapper_config() -> dict[str, object]:
return {
# Set autotune_at_compile_time to True as default if the option is not explicitly set
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
if config.triton.autotune_at_compile_time is not None
else has_triton(),
"triton.autotune_at_compile_time": (
config.triton.autotune_at_compile_time
if config.triton.autotune_at_compile_time is not None
else has_triton()
),
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False, # TODO: to be removed
"triton.store_cubin": True,
@ -1842,9 +2095,11 @@ def compile_fx(
model_outputs_node.meta["user_visible_output_idxs"] = []
fixed = count_tangents(gm)
with config.patch(
get_cpp_wrapper_config()
) if config.cpp_wrapper else contextlib.nullcontext():
with (
config.patch(get_cpp_wrapper_config())
if config.cpp_wrapper
else contextlib.nullcontext()
):
return inner_compile(
gm,
example_inputs,

View File

@ -1989,10 +1989,8 @@ class GraphLowering(torch.fx.Interpreter):
return total_bytes, node_counts, node_runtimes
@staticmethod
def save_output_code(code: str) -> None:
# No-op to be patched for unit tests
pass
# No-op to be patched for unit tests
save_output_code: Optional[Callable[[str], None]] = None
def compile_to_module(self) -> ModuleType:
with dynamo_timed(
@ -2018,7 +2016,8 @@ class GraphLowering(torch.fx.Interpreter):
+ '"""\n'
)
code = tuning_code + code
GraphLowering.save_output_code(code)
if GraphLowering.save_output_code is not None:
GraphLowering.save_output_code(code)
output_code_log.debug("Output code: \n%s", code)
inductor_meta = autotune_cache.inductor_meta_from_config()

View File

@ -546,11 +546,6 @@ class CompiledFxGraph(OutputCode):
write_atomic(artifact_path, code, make_dirs=True)
from .graph import GraphLowering
# This is used by tests to check the output for specific details.
GraphLowering.save_output_code(code)
try:
with dynamo_timed(
"PyCodeCache.load_by_key_path",

View File

@ -1465,12 +1465,16 @@ class DebugDirManager:
torch._dynamo.config.debug_dir_root = self.prev_debug_name
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
def run_and_get_code(
fn: Callable[P, _T],
*args: P.args,
**kwargs: P.kwargs,
) -> tuple[_T, list[str]]:
from .graph import GraphLowering
source_codes: list[str] = []
def save_output_code(code: str):
def save_output_code(code: str) -> None:
source_codes.append(code)
with mock.patch.object(GraphLowering, "save_output_code", save_output_code):

View File

@ -379,7 +379,7 @@ class FakeTensorConverter:
out = self.meta_converter(
t,
shape_env=shape_env,
callback=mk_fake_tensor, # type: ignore[arg-type]
callback=mk_fake_tensor,
source=source,
symbolic_context=symbolic_context,
trace=trace,

View File

@ -541,11 +541,27 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
# A callback where the device is either optional or required.
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
def __call__(
self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
) -> _TensorT_cov:
...
class _MetaTensorCallbackKwargs(TypedDict, total=False):
device: Union[torch.device, str]
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
# A callback where the device may not be provided (is optional).
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
def __call__(
self,
arg: Callable[[], torch.Tensor],
@ -832,11 +848,13 @@ class MetaConverter(Generic[_TensorT]):
self,
t: MetaTensorDesc,
shape_env: Optional[ShapeEnv],
callback: _MetaTensorCallback[_TensorT],
callback_: _MetaTensorCallback[_TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
) -> _TensorT:
callback = functools.partial(callback, device=t.device)
callback: _MetaTensorCallbackOptDevice = functools.partial(
callback_, device=t.device
)
if source is None:
from torch._dynamo.source import ConstantSource
@ -981,7 +999,7 @@ class MetaConverter(Generic[_TensorT]):
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
callback: _MetaTensorCallback[_TensorT],
callback: _MetaTensorCallbackOptDevice[_TensorT],
source: torch._guards.Source,
) -> _TensorT:
# We are hitting plain meta_desc tensor so actually
@ -1216,7 +1234,7 @@ class MetaConverter(Generic[_TensorT]):
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
callback: _MetaTensorCallback[_TensorT] = callback,
callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
@ -1769,7 +1787,9 @@ class MetaConverter(Generic[_TensorT]):
# Thanks to storage resizing, it's possible to end up with a tensor
# that advertises a real size, but has a storage that actually has zero bytes.
# Need to reflect this in the generated FakeTensor.
if t.storage is not None and t.storage.size == 0:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if t.storage is not None and guard_size_oblivious(t.storage.size == 0):
r.untyped_storage().resize_(0)
if t.is_parameter:

582
torch/fx/_graph_pickler.py Normal file
View File

@ -0,0 +1,582 @@
import dataclasses
import importlib
import io
import pickle
from abc import abstractmethod
from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union
from typing_extensions import override, Self
import torch
import torch.utils._pytree as pytree
from torch._guards import TracingContext
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
from torch._subclasses.meta_utils import (
MetaConverter,
MetaTensorDesc,
MetaTensorDescriber,
)
from torch.fx.experimental.sym_node import SymNode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._mode_utils import no_dispatch
_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
class GraphPickler(pickle.Pickler):
"""
GraphPickler is a Pickler which helps pickling fx graph - in particular
GraphModule.
"""
def __init__(self, file: io.BytesIO) -> None:
super().__init__(file)
# This abomination is so we can pass external decoding state to the
# unpickler functions. We serialize _unpickle_state as a persistent
# external item and when we deserialize it we return the common state
# object.
self._unpickle_state = _UnpickleStateToken(object())
# This is used to describe tensors. It needs to be common across the
# pickle so that duplicates and views are properly handled.
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
@override
def reducer_override(
self, obj: object
) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
# This function is supposed to return either NotImplemented (meaning to
# do the default pickle behavior) or a pair of (unpickle callable, data
# to pass to unpickle).
# We could instead teach individual classes how to pickle themselves but
# that has a few problems:
#
# 1. If we have some special needs (maybe for this use-case we don't
# want to fully serialize every field) then we're adding private
# details to a public interface.
#
# 2. If we need to have some common shared data (such as a
# FakeTensorMode) which is passed to each value it's harder to
# support.
# These are the types that need special handling. See the individual
# *PickleData classes for details on pickling that particular type.
if isinstance(obj, FakeTensor):
return _TensorPickleData.reduce_helper(self, obj)
elif isinstance(obj, torch.fx.GraphModule):
return _GraphModulePickleData.reduce_helper(self, obj)
elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
return _OpPickleData.reduce_helper(self, obj)
elif isinstance(obj, ShapeEnv):
return _ShapeEnvPickleData.reduce_helper(self, obj)
elif isinstance(obj, torch.SymInt):
return _SymNodePickleData.reduce_helper(self, obj)
elif isinstance(obj, torch._guards.TracingContext):
return _TracingContextPickleData.reduce_helper(self, obj)
else:
# We should never get a raw Node!
assert not isinstance(obj, torch.fx.Node)
if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
return reduce
# returning `NotImplemented` causes pickle to revert to the default
# behavior for this object.
return NotImplemented
@override
def persistent_id(self, obj: object) -> Optional[str]:
if obj is self._unpickle_state:
return "unpickle_state"
else:
return None
@classmethod
def dumps(cls, obj: object) -> bytes:
"""
Pickle an object.
"""
with io.BytesIO() as stream:
pickler = cls(stream)
pickler.dump(obj)
return stream.getvalue()
@staticmethod
def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
"""
Unpickle an object.
"""
state = _UnpickleState(fake_mode)
with io.BytesIO(data) as stream:
unpickler = _GraphUnpickler(stream, state)
return unpickler.load()
class _UnpickleState:
def __init__(self, fake_mode: FakeTensorMode) -> None:
self.fake_mode = fake_mode
self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
# This token is passed when pickling to indicate that we want to use the
# unpickler's _UnpickleState as a parameter in that position.
_UnpickleStateToken = NewType("_UnpickleStateToken", object)
class _GraphUnpickler(pickle.Unpickler):
def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
super().__init__(stream)
self._unpickle_state = unpickle_state
@override
def persistent_load(self, pid: object) -> object:
if pid == "unpickle_state":
return self._unpickle_state
else:
raise pickle.UnpicklingError("Invalid persistent ID")
class _ShapeEnvPickleData:
data: Dict[str, object]
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: ShapeEnv
) -> Tuple[
Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken]
]:
return cls.unpickle, (cls(obj), pickler._unpickle_state)
def __init__(self, env: ShapeEnv) -> None:
# In theory pickle should recognize that a given ShapeEnv was already
# pickled and reuse the resulting _ShapeEnvPickleData (so two objects
# pointing at the same ShapeEnv get the same ShapeEnv out).
assert not env._translation_validation_enabled
self.data = env.__dict__.copy()
del self.data["tracked_fakes"]
del self.data["fake_tensor_cache"]
def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
# Fill in the existing ShapeEnv rather than creating a new one
assert unpickle_state.fake_mode
assert unpickle_state.fake_mode.shape_env
for k, v in self.data.items():
setattr(unpickle_state.fake_mode.shape_env, k, v)
return unpickle_state.fake_mode.shape_env
class _SymNodePickleData:
@classmethod
def reduce_helper(
cls,
pickler: GraphPickler,
obj: _SymNodeT,
) -> Tuple[
Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken]
]:
args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt):
return _SymNodePickleData.unpickle_sym_int, args
else:
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
def __init__(self, node: SymNode) -> None:
self.expr = node._expr
self.shape_env = node.shape_env
self.pytype = node.pytype
self.hint = node._hint
def _to_sym_node(self) -> SymNode:
from torch.fx.experimental.sym_node import SymNode
assert self.shape_env is not None
return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
return torch.SymInt(self._to_sym_node())
class _TensorPickleData:
metadata: MetaTensorDesc[FakeTensor]
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: FakeTensor
) -> Tuple[
Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken]
]:
return cls.unpickle, (
cls(pickler._meta_tensor_describer, obj),
pickler._unpickle_state,
)
def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
# THINGS TO WORRY ABOUT:
# 1. Need to make sure that two tensors with the same id end up with the
# same id on the other side of the wire.
metadata = describer.describe_tensor(t)
# view_func is fine if it's either None or a _FakeTensorViewFunc. A
# custom one (which is basically a lambda) can't be serialized.
assert not metadata.view_func or isinstance(
metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
)
self.metadata = dataclasses.replace(metadata, fake_mode=None)
# Some debugging/verification
for k in MetaTensorDesc._UNSERIALIZABLE:
if k in ("fake_mode", "view_func"):
continue
assert (
getattr(self.metadata, k) is None
), f"not None: {k}: {getattr(self.metadata, k)}"
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
# TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
metadata = dataclasses.replace(
self.metadata,
fake_mode=unpickle_state.fake_mode,
)
def with_fake(
make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
) -> FakeTensor:
with no_dispatch():
return FakeTensor(
unpickle_state.fake_mode,
make_meta_t(),
device,
)
return unpickle_state.meta_converter.meta_tensor(
metadata,
unpickle_state.fake_mode.shape_env,
with_fake,
None,
None,
)
class _TorchNumpyPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: object
) -> Optional[
Tuple[
Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken]
]
]:
if data := cls.from_object(obj):
return (cls.unpickle, (data, pickler._unpickle_state))
else:
return None
def __init__(self, mod: str, name: str) -> None:
self.mod = mod
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
np = getattr(importlib.import_module(self.mod), self.name)
return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
@classmethod
def from_object(cls, tnp: object) -> Optional[Self]:
if not callable(tnp):
return None
tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
try:
if not (np := tnp_to_np.get(tnp)):
return None
except TypeError:
return None
if not (mod := getattr(np, "__module__", None)):
mod = "numpy"
if not (name := getattr(np, "__name__", None)):
return None
assert np == getattr(importlib.import_module(mod), name)
return cls(mod, name)
class _GraphModulePickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: torch.fx.GraphModule
) -> Tuple[
Callable[[Self, _UnpickleState], torch.fx.GraphModule],
Tuple[Self, _UnpickleStateToken],
]:
return cls.unpickle, (
cls(obj),
pickler._unpickle_state,
)
def __init__(self, gm: torch.fx.GraphModule) -> None:
# Need to do this to ensure the code is created for later pickling.
if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
_python_code = gm._real_recompile()
else:
_python_code = gm.recompile()
self.gm_dict = gm.__dict__.copy()
del self.gm_dict["_graph"]
self.graph = _GraphPickleData(gm._graph)
def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
gm.__dict__ = self.gm_dict
gm._graph = self.graph.unpickle(gm, unpickle_state)
return gm
class _NodePickleData:
def __init__(
self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"]
) -> None:
self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
self.kwargs = pytree.tree_map_only(
torch.fx.Node, lambda n: mapping[n], node.kwargs
)
# -- self.graph = node.graph
self.name = node.name
self.op = node.op
self.target = _OpPickleData.pickle(node.target)
# self.input_nodes = node._input_nodes
# self.users = node.users
self.type = node.type
# self.sort_key = node._sort_key
# self.repr_fn = node._repr_fn
# self.meta = node.meta
self.meta = node.meta
def unpickle(
self,
graph: torch.fx.Graph,
mapping: Dict["_NodePickleData", torch.fx.Node],
unpickle_state: _UnpickleState,
) -> torch.fx.Node:
args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
kwargs = pytree.tree_map_only(
_NodePickleData, lambda n: mapping[n], self.kwargs
)
target = self.target.unpickle(unpickle_state)
assert callable(target) or isinstance(target, str)
node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
node.meta = self.meta
return node
class _OpPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, op: object
) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]:
result = cls.pickle(op)
return (result.unpickle, (pickler._unpickle_state,))
@classmethod
def pickle(cls, op: object) -> "_OpPickleData":
if isinstance(op, str):
return _OpStrPickleData(op)
name = torch.fx.Node._pretty_print_target(op)
if isinstance(op, torch._ops.OpOverload):
return cls._pickle_op(name, _OpOverloadPickleData)
elif isinstance(op, torch._ops.OpOverloadPacket):
return cls._pickle_op(name, _OpOverloadPacketPickleData)
elif name.startswith(("builtins.", "math.", "torch.")):
root, detail = name.split(".", 1)
return _OpBuiltinPickleData(root, detail)
elif name.startswith("operator."):
_, detail = name.split(".", 1)
return _OpOperatorPickleData(detail)
else:
# TODO: raise a BypassFxGraphCache so we will just bypass this one...
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
@staticmethod
def _pickle_op(
name: str,
datacls: Union[
Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"]
],
) -> "_OpPickleData":
if not name.startswith("torch.ops.aten"): # TODO: What's the full list?
from torch._inductor.codecache import BypassFxGraphCache
raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
return datacls(name)
@abstractmethod
def unpickle(self, unpickle_state: _UnpickleState) -> object:
pass
@classmethod
def _lookup_global_by_name(cls, name: str) -> object:
"""
Like `globals()[name]` but supports dotted names.
"""
if "." in name:
mod, rest = name.split(".", 1)
root = globals()[mod]
return cls._getattr_by_name(root, rest)
else:
return globals()[name]
@staticmethod
def _getattr_by_name(root: object, name: str) -> object:
"""
Like `getattr(root, name)` but supports dotted names.
"""
while "." in name:
mod, name = name.split(".", 1)
root = getattr(root, mod)
return getattr(root, name)
class _OpStrPickleData(_OpPickleData):
def __init__(self, name: str) -> None:
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> str:
return self.name
class _OpOverloadPickleData(_OpPickleData):
def __init__(self, name: str) -> None:
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
obj = self._lookup_global_by_name(self.name)
assert isinstance(obj, torch._ops.OpOverload)
return obj
class _OpOverloadPacketPickleData(_OpPickleData):
def __init__(self, name: str) -> None:
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
obj = self._lookup_global_by_name(self.name)
assert isinstance(obj, torch._ops.OpOverloadPacket)
return obj
class _OpBuiltinPickleData(_OpPickleData):
def __init__(self, root: str, name: str) -> None:
self.root = root
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> object:
if self.root == "builtins":
return __builtins__.get(self.name) # type: ignore[attr-defined]
elif self.root == "math":
import math
return self._getattr_by_name(math, self.name)
elif self.root == "torch":
return self._getattr_by_name(torch, self.name)
else:
raise NotImplementedError
class _OpOperatorPickleData(_OpPickleData):
def __init__(self, name: str) -> None:
self.name = name
def unpickle(self, unpickle_state: _UnpickleState) -> object:
import operator
return self._getattr_by_name(operator, self.name)
class _GraphPickleData:
def __init__(self, graph: torch.fx.Graph) -> None:
self.tracer_cls = graph._tracer_cls
self.tracer_extras = graph._tracer_extras
nodes: Dict[torch.fx.Node, _NodePickleData] = {}
for node in graph.nodes:
nodes[node] = _NodePickleData(node, nodes)
self.nodes = tuple(nodes.values())
# Unpickled variables:
# self._used_names = graph._used_names
# -- self._insert = self._root.prepend
# self._len = graph._len
# self._graph_namespace = graph._graph_namespace
# self._owning_module = graph._owning_module
# self._codegen = graph._codegen
# self._co_fields: Dict[str, Any] = graph._co_fields
# -- self._find_nodes_lookup_table = _FindNodesLookupTable()
def unpickle(
self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
) -> torch.fx.Graph:
graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
nodes: Dict[_NodePickleData, torch.fx.Node] = {}
for nd in self.nodes:
nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
return graph
class _TracingContextPickleData:
@classmethod
def reduce_helper(
cls, pickler: GraphPickler, obj: torch._guards.TracingContext
) -> Tuple[
Callable[[Self, _UnpickleState], torch._guards.TracingContext],
Tuple[Self, _UnpickleStateToken],
]:
return (
cls.unpickle,
(
cls(obj),
pickler._unpickle_state,
),
)
def __init__(self, context: TracingContext) -> None:
# TODO: Do we really need all of this?
self.module_context = context.module_context
self.frame_summary_stack = context.frame_summary_stack
self.loc_in_frame = context.loc_in_frame
self.aot_graph_name = context.aot_graph_name
self.params_flat = context.params_flat
self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
self.output_strides = context.output_strides
self.force_unspec_int_unbacked_size_like = (
context.force_unspec_int_unbacked_size_like
)
# Not saved (because it's difficult and maybe not needed?):
# self.fw_metadata = context.fw_metadata
# self.guards_context = None
# self.global_context = None
# self.fake_mode = None
# self.fakify_first_call = None
# self.hop_dispatch_set_cache = None
# self.tensor_to_context = context.tensor_to_context
def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
context = TracingContext(unpickle_state.fake_mode)
context.module_context = self.module_context
context.frame_summary_stack = self.frame_summary_stack
context.loc_in_frame = self.loc_in_frame
context.aot_graph_name = self.aot_graph_name
context.params_flat = self.params_flat
context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
context.output_strides = self.output_strides
context.force_unspec_int_unbacked_size_like = (
self.force_unspec_int_unbacked_size_like
)
return context

View File

@ -602,7 +602,8 @@ class Node(_NodeBase):
return self._repr_fn(self)
return self.name
def _pretty_print_target(self, target: object) -> str:
@staticmethod
def _pretty_print_target(target: object) -> str:
"""
Make target printouts more user-friendly.
1) builtins will be printed as `builtins.xyz`