mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "pickler for GraphModule (#141659)"
This reverts commit c6ad08357bf8e766b5220bfb5cbbfdb2a4ec0ca5. Reverted https://github.com/pytorch/pytorch/pull/141659 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, please take a look at D68694181 for more details. ([comment](https://github.com/pytorch/pytorch/pull/141659#issuecomment-2617045120))
This commit is contained in:
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor import comms
|
||||
from torch._inductor.utils import is_fallback_op, run_and_get_code_before_compile
|
||||
from torch._inductor.utils import is_fallback_op, run_and_get_code
|
||||
from torch.distributed._tensor import init_device_mesh
|
||||
from torch.distributed.fsdp import (
|
||||
fully_shard,
|
||||
@ -743,7 +743,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
if fwd_fullgraph
|
||||
else None
|
||||
):
|
||||
_, triton_codes = run_and_get_code_before_compile(
|
||||
_, triton_codes = run_and_get_code(
|
||||
lambda: self._test_traceable_fsdp(
|
||||
*self._create_nested_fully_shard_factory_fns(
|
||||
fwd_fullgraph=fwd_fullgraph
|
||||
@ -751,7 +751,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
"inductor",
|
||||
fwd_fullgraph=fwd_fullgraph,
|
||||
bwd_resize_count_before_inductor=48 if fwd_fullgraph else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
if fwd_fullgraph:
|
||||
self.assertEqual(
|
||||
@ -829,12 +829,12 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_nested_fully_shard_backend_inductor_fullgraph_False(self):
|
||||
self.skipTestForOldSm()
|
||||
_, triton_codes = run_and_get_code_before_compile(
|
||||
_, triton_codes = run_and_get_code(
|
||||
lambda: self._test_traceable_fsdp(
|
||||
*self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False),
|
||||
"inductor",
|
||||
fwd_fullgraph=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
# TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
|
||||
# there are several recompiles, need to figure out why.
|
||||
@ -969,7 +969,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
if fwd_fullgraph
|
||||
else None
|
||||
):
|
||||
_, triton_codes = run_and_get_code_before_compile(
|
||||
_, triton_codes = run_and_get_code(
|
||||
lambda: self._test_traceable_fsdp(
|
||||
*self._create_transformer_factory_fns(
|
||||
all_requires_grad=all_requires_grad,
|
||||
@ -978,7 +978,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
"inductor",
|
||||
fwd_fullgraph=fwd_fullgraph,
|
||||
bwd_resize_count_before_inductor=76 if fwd_fullgraph else None,
|
||||
),
|
||||
)
|
||||
)
|
||||
if fwd_fullgraph:
|
||||
self.assertEqual(
|
||||
@ -1063,7 +1063,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001, B950
|
||||
)
|
||||
with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph):
|
||||
_, triton_codes = run_and_get_code_before_compile(
|
||||
_, triton_codes = run_and_get_code(
|
||||
lambda: self._test_traceable_fsdp(
|
||||
*self._create_transformer_factory_fns(
|
||||
all_requires_grad=all_requires_grad,
|
||||
@ -1071,7 +1071,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
|
||||
),
|
||||
"inductor",
|
||||
fwd_fullgraph=fwd_fullgraph,
|
||||
),
|
||||
)
|
||||
)
|
||||
# TODO: when fwd_fullgraph=False and there is graph break in FWD graph,
|
||||
# there are several recompiles, need to figure out why.
|
||||
|
@ -1,96 +0,0 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
#
|
||||
# Tests the graph pickler by using pickling on all the inductor tests.
|
||||
#
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.library
|
||||
from torch._dynamo.testing import make_test_cls_with_patches
|
||||
from torch._inductor.test_case import TestCase
|
||||
from torch.testing._internal.common_utils import TEST_WITH_ASAN
|
||||
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
|
||||
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
||||
check_model,
|
||||
CommonTemplate,
|
||||
copy_tests,
|
||||
)
|
||||
|
||||
|
||||
importlib.import_module("filelock")
|
||||
|
||||
# xfail by default, set is_skip=True to skip
|
||||
test_failures = {}
|
||||
|
||||
|
||||
def make_test_cls(cls, xfail_prop="_expected_failure_graph_pickler"):
|
||||
return make_test_cls_with_patches(
|
||||
cls,
|
||||
"GraphPickler",
|
||||
"_graph_pickler",
|
||||
(
|
||||
torch._inductor.compile_fx,
|
||||
"fx_compile_mode",
|
||||
torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
|
||||
),
|
||||
xfail_prop=xfail_prop,
|
||||
)
|
||||
|
||||
|
||||
GraphPicklerCommonTemplate = make_test_cls(CommonTemplate)
|
||||
|
||||
|
||||
if HAS_CPU:
|
||||
|
||||
class GraphPicklerCpuTests(TestCase):
|
||||
common = check_model
|
||||
device = "cpu"
|
||||
|
||||
copy_tests(GraphPicklerCommonTemplate, GraphPicklerCpuTests, "cpu", test_failures)
|
||||
|
||||
|
||||
class TestGraphPickler(TestCase):
|
||||
def setUp(self):
|
||||
torch._dynamo.reset()
|
||||
TestCase.setUp(self)
|
||||
|
||||
self._stack = contextlib.ExitStack()
|
||||
self._stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.compile_fx.fx_compile_mode",
|
||||
torch._inductor.compile_fx.FxCompileMode.SERIALIZE,
|
||||
)
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
self._stack.close()
|
||||
TestCase.tearDown(self)
|
||||
torch._dynamo.reset()
|
||||
|
||||
def test_simple(self):
|
||||
# Make sure that compiling works when we pass the input + output from
|
||||
# fx_codegen_and_compile() through serde.
|
||||
|
||||
def fn(a, b):
|
||||
return a + b
|
||||
|
||||
check_model(self, fn, (torch.tensor([False, True]), torch.tensor([True, True])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._inductor.test_case import run_tests
|
||||
|
||||
# Slow on ASAN after https://github.com/pytorch/pytorch/pull/94068
|
||||
if (HAS_CPU or HAS_GPU) and not TEST_WITH_ASAN:
|
||||
run_tests(needs="filelock")
|
@ -1211,10 +1211,6 @@ class TypingVariable(VariableTracker):
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_np_to_tnp_map():
|
||||
"""
|
||||
This generates a mapping from numpy modules to their torch._numpy
|
||||
modules equivalents.
|
||||
"""
|
||||
from ..utils import NP_TO_TNP_MODULE
|
||||
|
||||
np_fn_to_tnp_fn = {}
|
||||
@ -1230,16 +1226,6 @@ def get_np_to_tnp_map():
|
||||
return np_fn_to_tnp_fn
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def get_tnp_to_np_map():
|
||||
"""
|
||||
This is just the reverse mapping of get_np_to_tnp_map() - mapping from
|
||||
torch._numpy modules to numpy equivalents.
|
||||
"""
|
||||
m = get_np_to_tnp_map()
|
||||
return {v: k for k, v in m.items()}
|
||||
|
||||
|
||||
class NumpyVariable(VariableTracker):
|
||||
"""
|
||||
Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
|
||||
|
@ -1,25 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import functools
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from inspect import currentframe
|
||||
from itertools import count
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Mapping,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
@ -60,18 +56,12 @@ from torch._functorch.aot_autograd import (
|
||||
make_boxed_func,
|
||||
SerializableAOTDispatchCompiler,
|
||||
)
|
||||
from torch._inductor.codecache import (
|
||||
BypassFxGraphCache,
|
||||
code_hash,
|
||||
FxGraphCache,
|
||||
output_code_log,
|
||||
)
|
||||
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
|
||||
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
|
||||
from torch._inductor.debug import save_args_for_compile_fx_inner
|
||||
from torch._inductor.output_code import (
|
||||
CompiledAOTI,
|
||||
CompiledFxGraph,
|
||||
CompiledFxGraphConstants,
|
||||
CompiledFxGraphConstantsWithGm,
|
||||
get_expanded_dims,
|
||||
index_expanded_dims,
|
||||
@ -156,43 +146,6 @@ if TYPE_CHECKING:
|
||||
GraphSignature,
|
||||
)
|
||||
|
||||
|
||||
# For testing - use the serde FxCompile scheme to debug serialization and
|
||||
# deserialization of GraphMoule and CompiledFxGraph.
|
||||
class FxCompileMode(enum.Enum):
|
||||
NORMAL = 0
|
||||
# For testing - use the serde FxCompile scheme to debug serialization and
|
||||
# deserialization of GraphMoule and CompiledFxGraph.
|
||||
SERIALIZE = 1
|
||||
|
||||
|
||||
def _fx_compile_mode_default() -> FxCompileMode:
|
||||
name = "TORCHINDUCTOR_FX_COMPILE_MODE"
|
||||
value = os.environ.get(name)
|
||||
NORMAL = FxCompileMode.NORMAL
|
||||
if value is None:
|
||||
return NORMAL
|
||||
try:
|
||||
value = value.upper()
|
||||
return FxCompileMode[value]
|
||||
except KeyError:
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.error(
|
||||
"Invalid value of %s for %s. Expected one of %s. Using default.",
|
||||
value,
|
||||
name,
|
||||
", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())),
|
||||
)
|
||||
# Remove from the environment so subprocesses don't ALSO complain.
|
||||
os.environ.pop(name)
|
||||
return FxCompileMode.NORMAL
|
||||
|
||||
|
||||
fx_compile_mode = _fx_compile_mode_default()
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
||||
pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs")
|
||||
@ -798,11 +751,9 @@ def _compile_fx_inner(
|
||||
cache_event_time=start_time,
|
||||
key=cache_info.get("key") if cache_info else None,
|
||||
components=cache_info.get("components") if cache_info else None,
|
||||
cache_bypass_reason=(
|
||||
cache_info.get("cache_bypass_reason")
|
||||
if cache_info
|
||||
else "cache not enabled"
|
||||
),
|
||||
cache_bypass_reason=cache_info.get("cache_bypass_reason")
|
||||
if cache_info
|
||||
else "cache not enabled",
|
||||
remote_cache_enabled=remote,
|
||||
local_cache_enabled=local,
|
||||
)
|
||||
@ -832,11 +783,6 @@ def _compile_fx_inner(
|
||||
|
||||
|
||||
class FxCompile(ABC):
|
||||
"""
|
||||
An FxCompile represents a mechanism that can turn a GraphModule into an
|
||||
OutputCode.
|
||||
"""
|
||||
|
||||
# TODO: We should probably eventually add some kind of async version of this
|
||||
# so we can kick off a compile and then go do other things - but we'll need
|
||||
# to know what kind of API we want for that first.
|
||||
@ -1187,195 +1133,6 @@ class _InProcessFxCompile(FxCompile):
|
||||
)
|
||||
|
||||
|
||||
def _current_fake_mode() -> torch._subclasses.FakeTensorMode:
|
||||
fake_mode = None
|
||||
if context := torch._guards.TracingContext.try_get():
|
||||
fake_mode = context.fake_mode
|
||||
if fake_mode is not None:
|
||||
return fake_mode
|
||||
|
||||
shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
||||
return torch._subclasses.FakeTensorMode(shape_env=shape_env)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolInput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(sent) from the parent to the child.
|
||||
"""
|
||||
|
||||
gm: torch.fx.GraphModule
|
||||
example_inputs: Sequence[InputType]
|
||||
inputs_to_check: Sequence[int]
|
||||
graph_kwargs: _CompileFxKwargs
|
||||
# TODO: Add additional state to transfer to the child.
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledInput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledInput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
return _WireProtocolPickledInput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledInput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self) -> _WireProtocolInput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolInput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolInput)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolOutput:
|
||||
"""
|
||||
For _SerializedFxCompile - encapsulates all the data being transferred
|
||||
(returned) back from the child to the parent.
|
||||
"""
|
||||
|
||||
graph: OutputCode
|
||||
|
||||
def serialize(self) -> _WireProtocolPickledOutput:
|
||||
"""
|
||||
Turns this object into a _WireProtocolPickledOutput which can be
|
||||
directly transferred across a stream.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
if isinstance(self.graph, CompiledFxGraph):
|
||||
self.graph.prepare_for_serialization()
|
||||
return _WireProtocolPickledOutput(GraphPickler.dumps(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class _WireProtocolPickledOutput:
|
||||
value: bytes
|
||||
|
||||
def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput:
|
||||
"""
|
||||
Turn this streamable object back into a _WireProtocolOutput.
|
||||
"""
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
|
||||
fake_mode = _current_fake_mode()
|
||||
result = GraphPickler.loads(self.value, fake_mode)
|
||||
assert isinstance(result, _WireProtocolOutput)
|
||||
if isinstance(result.graph, CompiledFxGraph):
|
||||
result.graph.after_deserialization(constants)
|
||||
return result
|
||||
|
||||
|
||||
class _SerializedFxCompile(FxCompile):
|
||||
"""
|
||||
This is used to represent an FxCompile which occurs across a serialized
|
||||
boundary.
|
||||
"""
|
||||
|
||||
@override
|
||||
def codegen_and_compile(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
inputs_to_check: Sequence[int],
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> OutputCode:
|
||||
# _context = torch._guards.TracingContext.try_get()
|
||||
constants = CompiledFxGraphConstantsWithGm(gm)
|
||||
|
||||
try:
|
||||
input = _WireProtocolInput(
|
||||
gm,
|
||||
example_inputs,
|
||||
inputs_to_check,
|
||||
graph_kwargs,
|
||||
).serialize()
|
||||
except (AttributeError, BypassFxGraphCache):
|
||||
# For example: AttributeError: Can't pickle local object
|
||||
# 'make_opaque_unary_fn.<locals>.OpaqueUnaryFn'
|
||||
|
||||
# TODO: scuba record about not being able to do this?
|
||||
log.debug("Unable to pickle input graph or example inputs", exc_info=True)
|
||||
|
||||
# Fallback to in-process
|
||||
return _InProcessFxCompile().codegen_and_compile(
|
||||
gm, example_inputs, inputs_to_check, graph_kwargs
|
||||
)
|
||||
|
||||
output = self._send_to_child(input).deserialize(constants)
|
||||
|
||||
self._postprocess(output)
|
||||
|
||||
# TODO: Do we need to figure out what changed in TracingContext in the
|
||||
# child and plumb that back up to the parent?
|
||||
|
||||
return output.graph
|
||||
|
||||
@abstractmethod
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# The implementation of this should transfer `input` to the child, call
|
||||
# `_run_in_child(input)` and transfer the result back.
|
||||
...
|
||||
|
||||
def _postprocess(self, output: _WireProtocolOutput) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _run_in_child(
|
||||
cls,
|
||||
pickled_input: _WireProtocolPickledInput,
|
||||
extra_env: Optional[Mapping[str, str]] = None,
|
||||
) -> _WireProtocolPickledOutput:
|
||||
with contextlib.ExitStack() as stack:
|
||||
if extra_env is not None:
|
||||
import unittest
|
||||
|
||||
stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env))
|
||||
|
||||
# TODO: Should we split the input into multiple sections where each
|
||||
# section sets up state for the previous section? (i.e. a Config section
|
||||
# which we decode and apply, followed by a FakeTensorMode section which
|
||||
# we decode and apply, etc)
|
||||
input = pickled_input.deserialize()
|
||||
|
||||
stack.enter_context(DebugContext())
|
||||
|
||||
output_graph = _InProcessFxCompile().codegen_and_compile(
|
||||
input.gm,
|
||||
input.example_inputs,
|
||||
input.inputs_to_check,
|
||||
input.graph_kwargs,
|
||||
)
|
||||
|
||||
return _WireProtocolOutput(
|
||||
output_graph,
|
||||
).serialize()
|
||||
|
||||
|
||||
# This is a debugging/testing implementation of FxCompile which serializes the
|
||||
# input and output but still runs the FxCompile in-process.
|
||||
class _DebugSerdeFxCompile(_SerializedFxCompile):
|
||||
@override
|
||||
def _send_to_child(
|
||||
self, pickled_input: _WireProtocolPickledInput
|
||||
) -> _WireProtocolPickledOutput:
|
||||
# For debugging just serde the input and output but don't run in a
|
||||
# subprocess.
|
||||
return self._run_in_child(pickled_input)
|
||||
|
||||
|
||||
def fx_codegen_and_compile(
|
||||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
@ -1384,13 +1141,7 @@ def fx_codegen_and_compile(
|
||||
inputs_to_check: Sequence[int],
|
||||
**graph_kwargs: Unpack[_CompileFxKwargs],
|
||||
) -> OutputCode:
|
||||
scheme: FxCompile
|
||||
if fx_compile_mode == FxCompileMode.NORMAL:
|
||||
scheme = _InProcessFxCompile()
|
||||
elif fx_compile_mode == FxCompileMode.SERIALIZE:
|
||||
scheme = _DebugSerdeFxCompile()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
scheme: FxCompile = _InProcessFxCompile()
|
||||
|
||||
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
|
||||
|
||||
@ -1517,13 +1268,11 @@ def cudagraphify_impl(
|
||||
|
||||
# allocate static tensor inputs
|
||||
static_inputs = [
|
||||
(
|
||||
x
|
||||
if not isinstance(x, torch.Tensor)
|
||||
else static_input(x)
|
||||
if idx not in static_input_idxs
|
||||
else x.detach()
|
||||
)
|
||||
x
|
||||
if not isinstance(x, torch.Tensor)
|
||||
else static_input(x)
|
||||
if idx not in static_input_idxs
|
||||
else x.detach()
|
||||
for idx, x in enumerate(inputs)
|
||||
]
|
||||
|
||||
@ -1755,11 +1504,9 @@ def fw_compiler_freezing(
|
||||
def get_cpp_wrapper_config() -> dict[str, object]:
|
||||
return {
|
||||
# Set autotune_at_compile_time to True as default if the option is not explicitly set
|
||||
"triton.autotune_at_compile_time": (
|
||||
config.triton.autotune_at_compile_time
|
||||
if config.triton.autotune_at_compile_time is not None
|
||||
else has_triton()
|
||||
),
|
||||
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
|
||||
if config.triton.autotune_at_compile_time is not None
|
||||
else has_triton(),
|
||||
"triton.autotune_cublasLt": False,
|
||||
"triton.cudagraphs": False, # TODO: to be removed
|
||||
"triton.store_cubin": True,
|
||||
@ -2093,11 +1840,9 @@ def compile_fx(
|
||||
model_outputs_node.meta["user_visible_output_idxs"] = []
|
||||
|
||||
fixed = count_tangents(gm)
|
||||
with (
|
||||
config.patch(get_cpp_wrapper_config())
|
||||
if config.cpp_wrapper
|
||||
else contextlib.nullcontext()
|
||||
):
|
||||
with config.patch(
|
||||
get_cpp_wrapper_config()
|
||||
) if config.cpp_wrapper else contextlib.nullcontext():
|
||||
return inner_compile(
|
||||
gm,
|
||||
example_inputs,
|
||||
|
@ -1,5 +1,4 @@
|
||||
import contextlib
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
@ -274,12 +273,6 @@ def mark_nodes_dislike_padding(
|
||||
cur.meta["dislike_padding"] = True
|
||||
|
||||
|
||||
class SaveOutputCodeContext(enum.Enum):
|
||||
BEFORE_COMPILE = 0
|
||||
AFTER_DESERIALIZATION = 1
|
||||
AFTER_COMPILE = 2
|
||||
|
||||
|
||||
class GraphLowering(torch.fx.Interpreter):
|
||||
graph_outputs: list[ir.IRNode]
|
||||
|
||||
@ -1989,7 +1982,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
return total_bytes, node_counts, node_runtimes
|
||||
|
||||
@staticmethod
|
||||
def save_output_code(code: str, context: SaveOutputCodeContext) -> None:
|
||||
def save_output_code(code: str) -> None:
|
||||
# No-op to be patched for unit tests
|
||||
pass
|
||||
|
||||
@ -2017,7 +2010,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
+ '"""\n'
|
||||
)
|
||||
code = tuning_code + code
|
||||
GraphLowering.save_output_code(code, SaveOutputCodeContext.BEFORE_COMPILE)
|
||||
GraphLowering.save_output_code(code)
|
||||
output_code_log.debug("Output code: \n%s", code)
|
||||
|
||||
inductor_meta = autotune_cache.inductor_meta_from_config()
|
||||
|
@ -460,15 +460,7 @@ class CompiledFxGraph(OutputCode):
|
||||
def __call__(self, inputs: Sequence[Any]) -> Any:
|
||||
assert self.current_callable is not None
|
||||
try:
|
||||
result = self.current_callable(inputs)
|
||||
|
||||
from torch._inductor.graph import GraphLowering, SaveOutputCodeContext
|
||||
|
||||
GraphLowering.save_output_code(
|
||||
self.source_code, SaveOutputCodeContext.AFTER_COMPILE
|
||||
)
|
||||
|
||||
return result
|
||||
return self.current_callable(inputs)
|
||||
finally:
|
||||
get_runtime_metrics_context().finish()
|
||||
AutotuneCacheBundler.end_compile()
|
||||
@ -557,12 +549,10 @@ class CompiledFxGraph(OutputCode):
|
||||
|
||||
write_atomic(artifact_path, code, make_dirs=True)
|
||||
|
||||
from .graph import GraphLowering, SaveOutputCodeContext
|
||||
from .graph import GraphLowering
|
||||
|
||||
# This is used by tests to check the output for specific details.
|
||||
GraphLowering.save_output_code(
|
||||
code, SaveOutputCodeContext.AFTER_DESERIALIZATION
|
||||
)
|
||||
GraphLowering.save_output_code(code)
|
||||
|
||||
try:
|
||||
with dynamo_timed(
|
||||
|
@ -49,7 +49,6 @@ if TYPE_CHECKING:
|
||||
|
||||
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||
from .codegen.common import WorkspaceArg
|
||||
from .graph import SaveOutputCodeContext
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._pytree import tree_map_only
|
||||
@ -1466,19 +1465,13 @@ class DebugDirManager:
|
||||
torch._dynamo.config.debug_dir_root = self.prev_debug_name
|
||||
|
||||
|
||||
def _run_and_get_code_for_context(
|
||||
fn: Callable[P, _T],
|
||||
for_context: SaveOutputCodeContext,
|
||||
args: P.args,
|
||||
kwargs: P.kwargs,
|
||||
) -> tuple[_T, list[str]]:
|
||||
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
|
||||
from .graph import GraphLowering
|
||||
|
||||
source_codes: list[str] = []
|
||||
|
||||
def save_output_code(code: str, context: SaveOutputCodeContext):
|
||||
if context == for_context:
|
||||
source_codes.append(code)
|
||||
def save_output_code(code: str):
|
||||
source_codes.append(code)
|
||||
|
||||
with mock.patch.object(GraphLowering, "save_output_code", save_output_code):
|
||||
torch._dynamo.reset()
|
||||
@ -1486,22 +1479,6 @@ def _run_and_get_code_for_context(
|
||||
return result, source_codes
|
||||
|
||||
|
||||
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
|
||||
from .graph import SaveOutputCodeContext
|
||||
|
||||
return _run_and_get_code_for_context(
|
||||
fn, SaveOutputCodeContext.AFTER_COMPILE, args, kwargs
|
||||
)
|
||||
|
||||
|
||||
def run_and_get_code_before_compile(fn, *args, **kwargs) -> tuple[Any, list[str]]:
|
||||
from .graph import SaveOutputCodeContext
|
||||
|
||||
return _run_and_get_code_for_context(
|
||||
fn, SaveOutputCodeContext.BEFORE_COMPILE, args, kwargs
|
||||
)
|
||||
|
||||
|
||||
def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, list[str]]:
|
||||
result, source_codes = run_and_get_code(fn, *args, **kwargs)
|
||||
kernels = []
|
||||
@ -1521,13 +1498,12 @@ def run_fw_bw_and_get_code(fn):
|
||||
|
||||
def get_code(fn, *args, **kwargs):
|
||||
"""Get the inductor-generated code, but skip any actual compilation or running."""
|
||||
from .graph import GraphLowering, SaveOutputCodeContext
|
||||
from .graph import GraphLowering
|
||||
|
||||
source_codes: list[str] = []
|
||||
|
||||
def save_output_code(code: str, context: SaveOutputCodeContext):
|
||||
if context == SaveOutputCodeContext.AFTER_COMPILE:
|
||||
source_codes.append(code)
|
||||
def save_output_code(code: str):
|
||||
source_codes.append(code)
|
||||
|
||||
def patched_compile_to_module(self: GraphLowering):
|
||||
class DummyModule:
|
||||
@ -1545,7 +1521,7 @@ def get_code(fn, *args, **kwargs):
|
||||
)
|
||||
# Skip all the actual compiling.
|
||||
nonlocal save_output_code
|
||||
save_output_code(code, SaveOutputCodeContext.BEFORE_COMPILE)
|
||||
save_output_code(code)
|
||||
|
||||
return DummyModule()
|
||||
|
||||
|
@ -379,7 +379,7 @@ class FakeTensorConverter:
|
||||
out = self.meta_converter(
|
||||
t,
|
||||
shape_env=shape_env,
|
||||
callback=mk_fake_tensor,
|
||||
callback=mk_fake_tensor, # type: ignore[arg-type]
|
||||
source=source,
|
||||
symbolic_context=symbolic_context,
|
||||
trace=trace,
|
||||
|
@ -541,27 +541,11 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
|
||||
return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
|
||||
|
||||
|
||||
# A callback where the device is either optional or required.
|
||||
# All of these satisfy this protocol:
|
||||
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
|
||||
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
|
||||
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
|
||||
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
|
||||
def __call__(
|
||||
self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
|
||||
) -> _TensorT_cov:
|
||||
...
|
||||
|
||||
|
||||
class _MetaTensorCallbackKwargs(TypedDict, total=False):
|
||||
device: Union[torch.device, str]
|
||||
|
||||
|
||||
# A callback where the device may not be provided (is optional).
|
||||
# All of these satisfy this protocol:
|
||||
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
|
||||
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
|
||||
class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
|
||||
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
|
||||
def __call__(
|
||||
self,
|
||||
arg: Callable[[], torch.Tensor],
|
||||
@ -848,13 +832,11 @@ class MetaConverter(Generic[_TensorT]):
|
||||
self,
|
||||
t: MetaTensorDesc,
|
||||
shape_env: Optional[ShapeEnv],
|
||||
callback_: _MetaTensorCallback[_TensorT],
|
||||
callback: _MetaTensorCallback[_TensorT],
|
||||
source: Optional[Source],
|
||||
symbolic_context: Optional[SymbolicContext],
|
||||
) -> _TensorT:
|
||||
callback: _MetaTensorCallbackOptDevice = functools.partial(
|
||||
callback_, device=t.device
|
||||
)
|
||||
callback = functools.partial(callback, device=t.device)
|
||||
if source is None:
|
||||
from torch._dynamo.source import ConstantSource
|
||||
|
||||
@ -999,7 +981,7 @@ class MetaConverter(Generic[_TensorT]):
|
||||
symbolic_context: Optional[
|
||||
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
||||
],
|
||||
callback: _MetaTensorCallbackOptDevice[_TensorT],
|
||||
callback: _MetaTensorCallback[_TensorT],
|
||||
source: torch._guards.Source,
|
||||
) -> _TensorT:
|
||||
# We are hitting plain meta_desc tensor so actually
|
||||
@ -1234,7 +1216,7 @@ class MetaConverter(Generic[_TensorT]):
|
||||
shape_env: Optional[
|
||||
torch.fx.experimental.symbolic_shapes.ShapeEnv
|
||||
] = shape_env,
|
||||
callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
|
||||
callback: _MetaTensorCallback[_TensorT] = callback,
|
||||
) -> torch.Tensor:
|
||||
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
|
||||
if visited_t is None:
|
||||
@ -1787,9 +1769,7 @@ class MetaConverter(Generic[_TensorT]):
|
||||
# Thanks to storage resizing, it's possible to end up with a tensor
|
||||
# that advertises a real size, but has a storage that actually has zero bytes.
|
||||
# Need to reflect this in the generated FakeTensor.
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
|
||||
if t.storage is not None and guard_size_oblivious(t.storage.size == 0):
|
||||
if t.storage is not None and t.storage.size == 0:
|
||||
r.untyped_storage().resize_(0)
|
||||
|
||||
if t.is_parameter:
|
||||
|
@ -1,582 +0,0 @@
|
||||
import dataclasses
|
||||
import importlib
|
||||
import io
|
||||
import pickle
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, NewType, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing_extensions import override, Self
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._guards import TracingContext
|
||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
|
||||
from torch._subclasses.meta_utils import (
|
||||
MetaConverter,
|
||||
MetaTensorDesc,
|
||||
MetaTensorDescriber,
|
||||
)
|
||||
from torch.fx.experimental.sym_node import SymNode
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
|
||||
|
||||
_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat)
|
||||
|
||||
|
||||
class GraphPickler(pickle.Pickler):
|
||||
"""
|
||||
GraphPickler is a Pickler which helps pickling fx graph - in particular
|
||||
GraphModule.
|
||||
"""
|
||||
|
||||
def __init__(self, file: io.BytesIO) -> None:
|
||||
super().__init__(file)
|
||||
|
||||
# This abomination is so we can pass external decoding state to the
|
||||
# unpickler functions. We serialize _unpickle_state as a persistent
|
||||
# external item and when we deserialize it we return the common state
|
||||
# object.
|
||||
self._unpickle_state = _UnpickleStateToken(object())
|
||||
|
||||
# This is used to describe tensors. It needs to be common across the
|
||||
# pickle so that duplicates and views are properly handled.
|
||||
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
|
||||
|
||||
@override
|
||||
def reducer_override(
|
||||
self, obj: object
|
||||
) -> Tuple[Callable[..., Any], Tuple[Any, ...]]:
|
||||
# This function is supposed to return either NotImplemented (meaning to
|
||||
# do the default pickle behavior) or a pair of (unpickle callable, data
|
||||
# to pass to unpickle).
|
||||
|
||||
# We could instead teach individual classes how to pickle themselves but
|
||||
# that has a few problems:
|
||||
#
|
||||
# 1. If we have some special needs (maybe for this use-case we don't
|
||||
# want to fully serialize every field) then we're adding private
|
||||
# details to a public interface.
|
||||
#
|
||||
# 2. If we need to have some common shared data (such as a
|
||||
# FakeTensorMode) which is passed to each value it's harder to
|
||||
# support.
|
||||
|
||||
# These are the types that need special handling. See the individual
|
||||
# *PickleData classes for details on pickling that particular type.
|
||||
if isinstance(obj, FakeTensor):
|
||||
return _TensorPickleData.reduce_helper(self, obj)
|
||||
elif isinstance(obj, torch.fx.GraphModule):
|
||||
return _GraphModulePickleData.reduce_helper(self, obj)
|
||||
elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)):
|
||||
return _OpPickleData.reduce_helper(self, obj)
|
||||
elif isinstance(obj, ShapeEnv):
|
||||
return _ShapeEnvPickleData.reduce_helper(self, obj)
|
||||
elif isinstance(obj, torch.SymInt):
|
||||
return _SymNodePickleData.reduce_helper(self, obj)
|
||||
elif isinstance(obj, torch._guards.TracingContext):
|
||||
return _TracingContextPickleData.reduce_helper(self, obj)
|
||||
else:
|
||||
# We should never get a raw Node!
|
||||
assert not isinstance(obj, torch.fx.Node)
|
||||
if reduce := _TorchNumpyPickleData.reduce_helper(self, obj):
|
||||
return reduce
|
||||
|
||||
# returning `NotImplemented` causes pickle to revert to the default
|
||||
# behavior for this object.
|
||||
return NotImplemented
|
||||
|
||||
@override
|
||||
def persistent_id(self, obj: object) -> Optional[str]:
|
||||
if obj is self._unpickle_state:
|
||||
return "unpickle_state"
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def dumps(cls, obj: object) -> bytes:
|
||||
"""
|
||||
Pickle an object.
|
||||
"""
|
||||
with io.BytesIO() as stream:
|
||||
pickler = cls(stream)
|
||||
pickler.dump(obj)
|
||||
return stream.getvalue()
|
||||
|
||||
@staticmethod
|
||||
def loads(data: bytes, fake_mode: FakeTensorMode) -> object:
|
||||
"""
|
||||
Unpickle an object.
|
||||
"""
|
||||
state = _UnpickleState(fake_mode)
|
||||
with io.BytesIO(data) as stream:
|
||||
unpickler = _GraphUnpickler(stream, state)
|
||||
return unpickler.load()
|
||||
|
||||
|
||||
class _UnpickleState:
|
||||
def __init__(self, fake_mode: FakeTensorMode) -> None:
|
||||
self.fake_mode = fake_mode
|
||||
self.meta_converter: MetaConverter[FakeTensor] = MetaConverter()
|
||||
|
||||
|
||||
# This token is passed when pickling to indicate that we want to use the
|
||||
# unpickler's _UnpickleState as a parameter in that position.
|
||||
_UnpickleStateToken = NewType("_UnpickleStateToken", object)
|
||||
|
||||
|
||||
class _GraphUnpickler(pickle.Unpickler):
|
||||
def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None:
|
||||
super().__init__(stream)
|
||||
self._unpickle_state = unpickle_state
|
||||
|
||||
@override
|
||||
def persistent_load(self, pid: object) -> object:
|
||||
if pid == "unpickle_state":
|
||||
return self._unpickle_state
|
||||
else:
|
||||
raise pickle.UnpicklingError("Invalid persistent ID")
|
||||
|
||||
|
||||
class _ShapeEnvPickleData:
|
||||
data: Dict[str, object]
|
||||
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, obj: ShapeEnv
|
||||
) -> Tuple[
|
||||
Callable[[Self, _UnpickleState], ShapeEnv], Tuple[Self, _UnpickleStateToken]
|
||||
]:
|
||||
return cls.unpickle, (cls(obj), pickler._unpickle_state)
|
||||
|
||||
def __init__(self, env: ShapeEnv) -> None:
|
||||
# In theory pickle should recognize that a given ShapeEnv was already
|
||||
# pickled and reuse the resulting _ShapeEnvPickleData (so two objects
|
||||
# pointing at the same ShapeEnv get the same ShapeEnv out).
|
||||
assert not env._translation_validation_enabled
|
||||
self.data = env.__dict__.copy()
|
||||
del self.data["tracked_fakes"]
|
||||
del self.data["fake_tensor_cache"]
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv:
|
||||
# Fill in the existing ShapeEnv rather than creating a new one
|
||||
assert unpickle_state.fake_mode
|
||||
assert unpickle_state.fake_mode.shape_env
|
||||
|
||||
for k, v in self.data.items():
|
||||
setattr(unpickle_state.fake_mode.shape_env, k, v)
|
||||
|
||||
return unpickle_state.fake_mode.shape_env
|
||||
|
||||
|
||||
class _SymNodePickleData:
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls,
|
||||
pickler: GraphPickler,
|
||||
obj: _SymNodeT,
|
||||
) -> Tuple[
|
||||
Callable[[Self, _UnpickleState], _SymNodeT], Tuple[Self, _UnpickleStateToken]
|
||||
]:
|
||||
args = (cls(obj.node), pickler._unpickle_state)
|
||||
if isinstance(obj, torch.SymInt):
|
||||
return _SymNodePickleData.unpickle_sym_int, args
|
||||
else:
|
||||
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
|
||||
|
||||
def __init__(self, node: SymNode) -> None:
|
||||
self.expr = node._expr
|
||||
self.shape_env = node.shape_env
|
||||
self.pytype = node.pytype
|
||||
self.hint = node._hint
|
||||
|
||||
def _to_sym_node(self) -> SymNode:
|
||||
from torch.fx.experimental.sym_node import SymNode
|
||||
|
||||
assert self.shape_env is not None
|
||||
return SymNode(self.expr, self.shape_env, self.pytype, self.hint)
|
||||
|
||||
def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt:
|
||||
return torch.SymInt(self._to_sym_node())
|
||||
|
||||
|
||||
class _TensorPickleData:
|
||||
metadata: MetaTensorDesc[FakeTensor]
|
||||
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, obj: FakeTensor
|
||||
) -> Tuple[
|
||||
Callable[[Self, _UnpickleState], FakeTensor], Tuple[Self, _UnpickleStateToken]
|
||||
]:
|
||||
return cls.unpickle, (
|
||||
cls(pickler._meta_tensor_describer, obj),
|
||||
pickler._unpickle_state,
|
||||
)
|
||||
|
||||
def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None:
|
||||
# THINGS TO WORRY ABOUT:
|
||||
# 1. Need to make sure that two tensors with the same id end up with the
|
||||
# same id on the other side of the wire.
|
||||
|
||||
metadata = describer.describe_tensor(t)
|
||||
|
||||
# view_func is fine if it's either None or a _FakeTensorViewFunc. A
|
||||
# custom one (which is basically a lambda) can't be serialized.
|
||||
assert not metadata.view_func or isinstance(
|
||||
metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc
|
||||
)
|
||||
self.metadata = dataclasses.replace(metadata, fake_mode=None)
|
||||
|
||||
# Some debugging/verification
|
||||
for k in MetaTensorDesc._UNSERIALIZABLE:
|
||||
if k in ("fake_mode", "view_func"):
|
||||
continue
|
||||
assert (
|
||||
getattr(self.metadata, k) is None
|
||||
), f"not None: {k}: {getattr(self.metadata, k)}"
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor:
|
||||
# TODO: make common w/ _output_from_cache_entry() in fake_tensor.py?
|
||||
metadata = dataclasses.replace(
|
||||
self.metadata,
|
||||
fake_mode=unpickle_state.fake_mode,
|
||||
)
|
||||
|
||||
def with_fake(
|
||||
make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str]
|
||||
) -> FakeTensor:
|
||||
with no_dispatch():
|
||||
return FakeTensor(
|
||||
unpickle_state.fake_mode,
|
||||
make_meta_t(),
|
||||
device,
|
||||
)
|
||||
|
||||
return unpickle_state.meta_converter.meta_tensor(
|
||||
metadata,
|
||||
unpickle_state.fake_mode.shape_env,
|
||||
with_fake,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
class _TorchNumpyPickleData:
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, obj: object
|
||||
) -> Optional[
|
||||
Tuple[
|
||||
Callable[[Self, _UnpickleState], object], Tuple[Self, _UnpickleStateToken]
|
||||
]
|
||||
]:
|
||||
if data := cls.from_object(obj):
|
||||
return (cls.unpickle, (data, pickler._unpickle_state))
|
||||
else:
|
||||
return None
|
||||
|
||||
def __init__(self, mod: str, name: str) -> None:
|
||||
self.mod = mod
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]:
|
||||
np = getattr(importlib.import_module(self.mod), self.name)
|
||||
return torch._dynamo.variables.misc.get_np_to_tnp_map()[np]
|
||||
|
||||
@classmethod
|
||||
def from_object(cls, tnp: object) -> Optional[Self]:
|
||||
if not callable(tnp):
|
||||
return None
|
||||
|
||||
tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map()
|
||||
try:
|
||||
if not (np := tnp_to_np.get(tnp)):
|
||||
return None
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
if not (mod := getattr(np, "__module__", None)):
|
||||
mod = "numpy"
|
||||
|
||||
if not (name := getattr(np, "__name__", None)):
|
||||
return None
|
||||
|
||||
assert np == getattr(importlib.import_module(mod), name)
|
||||
return cls(mod, name)
|
||||
|
||||
|
||||
class _GraphModulePickleData:
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, obj: torch.fx.GraphModule
|
||||
) -> Tuple[
|
||||
Callable[[Self, _UnpickleState], torch.fx.GraphModule],
|
||||
Tuple[Self, _UnpickleStateToken],
|
||||
]:
|
||||
return cls.unpickle, (
|
||||
cls(obj),
|
||||
pickler._unpickle_state,
|
||||
)
|
||||
|
||||
def __init__(self, gm: torch.fx.GraphModule) -> None:
|
||||
# Need to do this to ensure the code is created for later pickling.
|
||||
if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule):
|
||||
_python_code = gm._real_recompile()
|
||||
else:
|
||||
_python_code = gm.recompile()
|
||||
self.gm_dict = gm.__dict__.copy()
|
||||
del self.gm_dict["_graph"]
|
||||
self.graph = _GraphPickleData(gm._graph)
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule:
|
||||
gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule)
|
||||
gm.__dict__ = self.gm_dict
|
||||
gm._graph = self.graph.unpickle(gm, unpickle_state)
|
||||
return gm
|
||||
|
||||
|
||||
class _NodePickleData:
|
||||
def __init__(
|
||||
self, node: torch.fx.Node, mapping: Dict[torch.fx.Node, "_NodePickleData"]
|
||||
) -> None:
|
||||
self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args)
|
||||
self.kwargs = pytree.tree_map_only(
|
||||
torch.fx.Node, lambda n: mapping[n], node.kwargs
|
||||
)
|
||||
# -- self.graph = node.graph
|
||||
self.name = node.name
|
||||
self.op = node.op
|
||||
self.target = _OpPickleData.pickle(node.target)
|
||||
# self.input_nodes = node._input_nodes
|
||||
# self.users = node.users
|
||||
self.type = node.type
|
||||
# self.sort_key = node._sort_key
|
||||
# self.repr_fn = node._repr_fn
|
||||
# self.meta = node.meta
|
||||
self.meta = node.meta
|
||||
|
||||
def unpickle(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
mapping: Dict["_NodePickleData", torch.fx.Node],
|
||||
unpickle_state: _UnpickleState,
|
||||
) -> torch.fx.Node:
|
||||
args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args)
|
||||
kwargs = pytree.tree_map_only(
|
||||
_NodePickleData, lambda n: mapping[n], self.kwargs
|
||||
)
|
||||
target = self.target.unpickle(unpickle_state)
|
||||
assert callable(target) or isinstance(target, str)
|
||||
node = graph.create_node(self.op, target, args, kwargs, self.name, self.type)
|
||||
node.meta = self.meta
|
||||
return node
|
||||
|
||||
|
||||
class _OpPickleData:
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, op: object
|
||||
) -> Tuple[Callable[[_UnpickleState], object], Tuple[_UnpickleStateToken]]:
|
||||
result = cls.pickle(op)
|
||||
return (result.unpickle, (pickler._unpickle_state,))
|
||||
|
||||
@classmethod
|
||||
def pickle(cls, op: object) -> "_OpPickleData":
|
||||
if isinstance(op, str):
|
||||
return _OpStrPickleData(op)
|
||||
|
||||
name = torch.fx.Node._pretty_print_target(op)
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
return cls._pickle_op(name, _OpOverloadPickleData)
|
||||
elif isinstance(op, torch._ops.OpOverloadPacket):
|
||||
return cls._pickle_op(name, _OpOverloadPacketPickleData)
|
||||
elif name.startswith(("builtins.", "math.", "torch.")):
|
||||
root, detail = name.split(".", 1)
|
||||
return _OpBuiltinPickleData(root, detail)
|
||||
elif name.startswith("operator."):
|
||||
_, detail = name.split(".", 1)
|
||||
return _OpOperatorPickleData(detail)
|
||||
else:
|
||||
# TODO: raise a BypassFxGraphCache so we will just bypass this one...
|
||||
raise NotImplementedError(f"TARGET: {type(op)} {op} {name}")
|
||||
|
||||
@staticmethod
|
||||
def _pickle_op(
|
||||
name: str,
|
||||
datacls: Union[
|
||||
Type["_OpOverloadPickleData"], Type["_OpOverloadPacketPickleData"]
|
||||
],
|
||||
) -> "_OpPickleData":
|
||||
if not name.startswith("torch.ops.aten"): # TODO: What's the full list?
|
||||
from torch._inductor.codecache import BypassFxGraphCache
|
||||
|
||||
raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}")
|
||||
return datacls(name)
|
||||
|
||||
@abstractmethod
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _lookup_global_by_name(cls, name: str) -> object:
|
||||
"""
|
||||
Like `globals()[name]` but supports dotted names.
|
||||
"""
|
||||
if "." in name:
|
||||
mod, rest = name.split(".", 1)
|
||||
root = globals()[mod]
|
||||
return cls._getattr_by_name(root, rest)
|
||||
else:
|
||||
return globals()[name]
|
||||
|
||||
@staticmethod
|
||||
def _getattr_by_name(root: object, name: str) -> object:
|
||||
"""
|
||||
Like `getattr(root, name)` but supports dotted names.
|
||||
"""
|
||||
while "." in name:
|
||||
mod, name = name.split(".", 1)
|
||||
root = getattr(root, mod)
|
||||
return getattr(root, name)
|
||||
|
||||
|
||||
class _OpStrPickleData(_OpPickleData):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> str:
|
||||
return self.name
|
||||
|
||||
|
||||
class _OpOverloadPickleData(_OpPickleData):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload:
|
||||
obj = self._lookup_global_by_name(self.name)
|
||||
assert isinstance(obj, torch._ops.OpOverload)
|
||||
return obj
|
||||
|
||||
|
||||
class _OpOverloadPacketPickleData(_OpPickleData):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket:
|
||||
obj = self._lookup_global_by_name(self.name)
|
||||
assert isinstance(obj, torch._ops.OpOverloadPacket)
|
||||
return obj
|
||||
|
||||
|
||||
class _OpBuiltinPickleData(_OpPickleData):
|
||||
def __init__(self, root: str, name: str) -> None:
|
||||
self.root = root
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
||||
if self.root == "builtins":
|
||||
return __builtins__.get(self.name) # type: ignore[attr-defined]
|
||||
elif self.root == "math":
|
||||
import math
|
||||
|
||||
return self._getattr_by_name(math, self.name)
|
||||
elif self.root == "torch":
|
||||
return self._getattr_by_name(torch, self.name)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _OpOperatorPickleData(_OpPickleData):
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> object:
|
||||
import operator
|
||||
|
||||
return self._getattr_by_name(operator, self.name)
|
||||
|
||||
|
||||
class _GraphPickleData:
|
||||
def __init__(self, graph: torch.fx.Graph) -> None:
|
||||
self.tracer_cls = graph._tracer_cls
|
||||
self.tracer_extras = graph._tracer_extras
|
||||
|
||||
nodes: Dict[torch.fx.Node, _NodePickleData] = {}
|
||||
for node in graph.nodes:
|
||||
nodes[node] = _NodePickleData(node, nodes)
|
||||
self.nodes = tuple(nodes.values())
|
||||
|
||||
# Unpickled variables:
|
||||
# self._used_names = graph._used_names
|
||||
# -- self._insert = self._root.prepend
|
||||
# self._len = graph._len
|
||||
# self._graph_namespace = graph._graph_namespace
|
||||
# self._owning_module = graph._owning_module
|
||||
# self._codegen = graph._codegen
|
||||
# self._co_fields: Dict[str, Any] = graph._co_fields
|
||||
# -- self._find_nodes_lookup_table = _FindNodesLookupTable()
|
||||
|
||||
def unpickle(
|
||||
self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState
|
||||
) -> torch.fx.Graph:
|
||||
graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras)
|
||||
|
||||
nodes: Dict[_NodePickleData, torch.fx.Node] = {}
|
||||
for nd in self.nodes:
|
||||
nodes[nd] = nd.unpickle(graph, nodes, unpickle_state)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
class _TracingContextPickleData:
|
||||
@classmethod
|
||||
def reduce_helper(
|
||||
cls, pickler: GraphPickler, obj: torch._guards.TracingContext
|
||||
) -> Tuple[
|
||||
Callable[[Self, _UnpickleState], torch._guards.TracingContext],
|
||||
Tuple[Self, _UnpickleStateToken],
|
||||
]:
|
||||
return (
|
||||
cls.unpickle,
|
||||
(
|
||||
cls(obj),
|
||||
pickler._unpickle_state,
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, context: TracingContext) -> None:
|
||||
# TODO: Do we really need all of this?
|
||||
self.module_context = context.module_context
|
||||
self.frame_summary_stack = context.frame_summary_stack
|
||||
self.loc_in_frame = context.loc_in_frame
|
||||
self.aot_graph_name = context.aot_graph_name
|
||||
self.params_flat = context.params_flat
|
||||
self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses
|
||||
self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index
|
||||
self.output_strides = context.output_strides
|
||||
self.force_unspec_int_unbacked_size_like = (
|
||||
context.force_unspec_int_unbacked_size_like
|
||||
)
|
||||
# Not saved (because it's difficult and maybe not needed?):
|
||||
# self.fw_metadata = context.fw_metadata
|
||||
# self.guards_context = None
|
||||
# self.global_context = None
|
||||
# self.fake_mode = None
|
||||
# self.fakify_first_call = None
|
||||
# self.hop_dispatch_set_cache = None
|
||||
# self.tensor_to_context = context.tensor_to_context
|
||||
|
||||
def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext:
|
||||
context = TracingContext(unpickle_state.fake_mode)
|
||||
context.module_context = self.module_context
|
||||
context.frame_summary_stack = self.frame_summary_stack
|
||||
context.loc_in_frame = self.loc_in_frame
|
||||
context.aot_graph_name = self.aot_graph_name
|
||||
context.params_flat = self.params_flat
|
||||
context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses
|
||||
context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index
|
||||
context.output_strides = self.output_strides
|
||||
context.force_unspec_int_unbacked_size_like = (
|
||||
self.force_unspec_int_unbacked_size_like
|
||||
)
|
||||
return context
|
@ -598,8 +598,7 @@ class Node(_NodeBase):
|
||||
return self._repr_fn(self)
|
||||
return self.name
|
||||
|
||||
@staticmethod
|
||||
def _pretty_print_target(target: object) -> str:
|
||||
def _pretty_print_target(self, target: object) -> str:
|
||||
"""
|
||||
Make target printouts more user-friendly.
|
||||
1) builtins will be printed as `builtins.xyz`
|
||||
|
Reference in New Issue
Block a user