mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add experimental torch.export prototype (#95070)
This is WIP PR for adding torch.export API in OSS. Couple of points: - I intentionally named it as experimental_export so that ppl don't get confused thinking this is our official API - We don't plan to use AOTAutograd backend just yet. The reason we have it here is because the functionalization AOTAutograd uses is what we need for export (handling of param/buffer mutation etc). In the near future, I will extract the functionalization part and use it on top of make_fx. What we have right now is merely a placeholder. - The reason we want to do it now is because we want to have some minimal tests running in OSS so that we can catch regressions earlier. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95070 Approved by: https://github.com/gmagogsfm, https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
801b3f8fc7
commit
454c48b987
79
test/export/test_export.py
Normal file
79
test/export/test_export.py
Normal file
@ -0,0 +1,79 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from functorch.experimental.control_flow import cond
|
||||
from torch._export import do_not_use_experimental_export
|
||||
import torch._dynamo as torchdynamo
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
class TestExport(TestCase):
|
||||
@unittest.skip("dynamo failure -> RuntimeError: Could not infer dtype of SymBool")
|
||||
def test_export_cond(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
||||
def false_fn(x):
|
||||
return x.cos()
|
||||
|
||||
def foo(x):
|
||||
return cond(torch.tensor(x.shape[0] > 4), true_fn, false_fn, [x])
|
||||
|
||||
exported_program = do_not_use_experimental_export(foo, (torch.ones(6, 4, requires_grad=True),))
|
||||
print(exported_program.graph_module.graph)
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||
def test_export_simple_model_with_attr(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, float_val):
|
||||
super().__init__()
|
||||
self.float_val = float_val
|
||||
|
||||
def forward(self, x):
|
||||
y = x + self.float_val
|
||||
return y.cos()
|
||||
|
||||
inp = (torch.ones(6, 4, requires_grad=True),)
|
||||
mod = Foo(0.5)
|
||||
|
||||
exported_program = do_not_use_experimental_export(mod, inp)
|
||||
self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||
def test_export_simple_model(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, float_val):
|
||||
super().__init__()
|
||||
self.float_val = float_val
|
||||
|
||||
def forward(self, x):
|
||||
return x.cos()
|
||||
|
||||
inp = (torch.ones(6, 4, requires_grad=True),)
|
||||
mod = Foo(0.5)
|
||||
|
||||
exported_program = do_not_use_experimental_export(mod, inp)
|
||||
self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))
|
||||
|
||||
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
|
||||
def test_export_simple_model_buffer_mutation(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, float_val):
|
||||
super().__init__()
|
||||
self.register_buffer("buffer1", torch.ones(6, 1))
|
||||
|
||||
def forward(self, x):
|
||||
self.buffer1.add_(2)
|
||||
return x.cos() + self.buffer1.sin()
|
||||
|
||||
inp = (torch.ones(6, 4, requires_grad=True),)
|
||||
mod = Foo(0.5)
|
||||
|
||||
exported_program = do_not_use_experimental_export(mod, inp)
|
||||
mutated_buffer, output = exported_program.fw_module(*inp)
|
||||
# TODO (tmanlaibaatar) enable this once we figure out
|
||||
# how to do buffer mutation
|
||||
# self.assertEqual(mutated_buffer.sum().item(), 30)
|
||||
self.assertEqual(output, mod(*inp))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
@ -6,6 +6,7 @@ from .eval_frame import (
|
||||
disable,
|
||||
explain,
|
||||
export,
|
||||
is_dynamo_supported,
|
||||
optimize,
|
||||
optimize_assert,
|
||||
OptimizedModule,
|
||||
|
@ -383,6 +383,14 @@ def check_if_dynamo_supported():
|
||||
raise RuntimeError("Python 3.11+ not yet supported for torch.compile")
|
||||
|
||||
|
||||
def is_dynamo_supported():
|
||||
try:
|
||||
check_if_dynamo_supported()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def optimize(
|
||||
backend="inductor",
|
||||
*,
|
||||
|
@ -0,0 +1,215 @@
|
||||
import contextlib
|
||||
import copy
|
||||
from typing import Callable, Tuple, Generator, Dict
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch._decomp import core_aten_decompositions
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch.nn.utils import stateless
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
from torch._functorch.aot_autograd import (
|
||||
AOTConfig,
|
||||
create_aot_dispatcher_function,
|
||||
default_partition,
|
||||
run_functionalized_fw_and_collect_metadata,
|
||||
)
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
get_proxy_slot,
|
||||
get_torch_dispatch_modes,
|
||||
has_proxy_slot,
|
||||
make_fx,
|
||||
ProxyTorchDispatchMode,
|
||||
set_proxy_slot,
|
||||
)
|
||||
|
||||
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional
|
||||
|
||||
from .workflow import ExportedProgram
|
||||
|
||||
CORE_ATEN_DECOMPOSITIONS_TABLE = core_aten_decompositions()
|
||||
|
||||
__all__ = ["experimental_export"]
|
||||
|
||||
|
||||
def _aot_capture(mod, flat_args):
|
||||
"""
|
||||
A wrapper around aot_autograd() to mix AOT Autograd + torch.export.
|
||||
Some assumptions were made about the AOT Autograd internal:
|
||||
1. The functionalization metadata format.
|
||||
2. Calling convention of returned forward graph.
|
||||
3. make_fx() internal proxy storage.
|
||||
|
||||
In the current context we're just experimenting the idea so it's possible things
|
||||
could break. For the next step we should find a way to upstream something reasonable.
|
||||
"""
|
||||
param_list = [
|
||||
*mod.named_parameters(remove_duplicate=False),
|
||||
*mod.named_buffers(remove_duplicate=False),
|
||||
]
|
||||
params = dict(param_list)
|
||||
params_flat, params_spec = pytree.tree_flatten(params)
|
||||
params_len = len(params_flat)
|
||||
|
||||
full_args = []
|
||||
full_args.extend(params_flat)
|
||||
full_args.extend(flat_args)
|
||||
|
||||
def functional_call(*args):
|
||||
|
||||
with stateless._reparametrize_module(
|
||||
mod,
|
||||
pytree.tree_unflatten(args[:params_len], params_spec), # type: ignore[arg-type]
|
||||
):
|
||||
return torch.fx.Interpreter(mod).run(*args[params_len:])
|
||||
|
||||
out_spec = None
|
||||
|
||||
with enable_python_dispatcher():
|
||||
fw_metadata, _ = run_functionalized_fw_and_collect_metadata(
|
||||
lambda *args: pytree.tree_flatten(functional_call(*args))[0],
|
||||
keep_input_mutations=False,
|
||||
)(*copy.deepcopy(full_args)) # type: ignore[operator]
|
||||
|
||||
assert len(fw_metadata.input_info) == len(full_args)
|
||||
mutated_input_indices = [
|
||||
i
|
||||
for i, input_info in enumerate(fw_metadata.input_info)
|
||||
if input_info.mutates_data or input_info.mutates_metadata
|
||||
]
|
||||
|
||||
graph_module = None
|
||||
|
||||
def fw_compiler(gm, inputs):
|
||||
nonlocal graph_module
|
||||
graph_module = gm
|
||||
|
||||
num_fwd_returns = None
|
||||
|
||||
def partition_fn(joint_module, joint_inputs, *, num_fwd_outputs, **kwargs):
|
||||
nonlocal num_fwd_returns
|
||||
num_fwd_returns = num_fwd_outputs
|
||||
return default_partition(
|
||||
joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs
|
||||
)
|
||||
|
||||
def set_state_proxies(state_args):
|
||||
modes = get_torch_dispatch_modes()
|
||||
proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
|
||||
if len(proxy_tensor_modes) == 0:
|
||||
return
|
||||
assert len(state_args) == len(params_flat)
|
||||
for i, arg in enumerate(state_args):
|
||||
tracer = next(
|
||||
m.tracer for m in proxy_tensor_modes if has_proxy_slot(arg, m.tracer)
|
||||
)
|
||||
set_proxy_slot(arg, tracer, params_flat[i])
|
||||
|
||||
aot_config = AOTConfig(
|
||||
fw_compiler=fw_compiler,
|
||||
bw_compiler=lambda gm, inputs: None,
|
||||
partition_fn=partition_fn,
|
||||
decompositions=CORE_ATEN_DECOMPOSITIONS_TABLE, # type: ignore[arg-type]
|
||||
num_params_buffers=params_len,
|
||||
aot_id=-1,
|
||||
keep_inference_input_mutations=False,
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def setup_dynamic_shape():
|
||||
prev, torch._functorch.config.use_dynamic_shapes = (
|
||||
torch._functorch.config.use_dynamic_shapes,
|
||||
True,
|
||||
)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._functorch.config.use_dynamic_shapes = prev
|
||||
|
||||
def exported_call(*args):
|
||||
state_args = args[:params_len]
|
||||
unwrapped_state_args = _unwrap_all_tensors_from_functional(
|
||||
state_args, reapply_views=False
|
||||
)
|
||||
set_state_proxies(unwrapped_state_args)
|
||||
with torch.fx.traceback.preserve_node_meta():
|
||||
outputs = functional_call(*args)
|
||||
nonlocal out_spec
|
||||
outputs, out_spec = pytree.tree_flatten(outputs)
|
||||
return outputs
|
||||
|
||||
with torch.enable_grad(), setup_dynamic_shape():
|
||||
create_aot_dispatcher_function(
|
||||
exported_call,
|
||||
full_args,
|
||||
aot_config,
|
||||
)
|
||||
|
||||
assert graph_module is not None
|
||||
|
||||
for i, node in enumerate(graph_module.graph.nodes):
|
||||
if i == len(params_flat):
|
||||
break
|
||||
assert node.op == "placeholder" and len(node.users) == 0
|
||||
graph_module.graph.erase_node(node)
|
||||
|
||||
output_node = next(iter(reversed(graph_module.graph.nodes)))
|
||||
assert output_node.op == "output" and len(output_node.args) == 1
|
||||
assert num_fwd_returns is not None
|
||||
# Turncate the output so we only output what we need.
|
||||
output_node.args = (
|
||||
output_node.args[0][
|
||||
: len(mutated_input_indices) + len(fw_metadata.output_info)
|
||||
],
|
||||
)
|
||||
|
||||
graph_module.graph.eliminate_dead_code()
|
||||
graph_module.recompile()
|
||||
|
||||
def find_mutation_destinations(gm, w):
|
||||
assert isinstance(w, torch.Tensor)
|
||||
ret = [
|
||||
name for name, x in [*gm.named_parameters(), *gm.named_buffers()] if x is w
|
||||
]
|
||||
assert len(ret) != 0, "Cannot find mutation destination."
|
||||
return ret
|
||||
|
||||
mutation = [
|
||||
(
|
||||
"copy_",
|
||||
output_node.args[0][k].name,
|
||||
find_mutation_destinations(graph_module, param_list[i][1]),
|
||||
)
|
||||
for k, i in enumerate(mutated_input_indices)
|
||||
]
|
||||
assert out_spec is not None
|
||||
return graph_module, mutation, out_spec
|
||||
|
||||
|
||||
@patch.object(torchdynamo.config, "dynamic_shapes", True)
|
||||
@patch.object(torchdynamo.config, "capture_scalar_outputs", True)
|
||||
@patch.object(torchdynamo.config, "guard_nn_modules", True)
|
||||
@patch.object(torchdynamo.config, "specialize_int_float", True)
|
||||
@patch.object(torchdynamo.config, "allow_rnn", True)
|
||||
@patch.object(torchdynamo.config, "verbose", True)
|
||||
def do_not_use_experimental_export(f: Callable, args: Tuple, training=False):
|
||||
"""
|
||||
This prototype is under heavy development. Pls don't use it if you are
|
||||
not part of PyTorch 2.0 Export team.
|
||||
"""
|
||||
if training:
|
||||
NotImplementedError("training mode is not supported yet")
|
||||
|
||||
flattened_args, in_spec = pytree.tree_flatten(args)
|
||||
# Doing it twice so that if graph_module accidentally modifies the input
|
||||
# we still get the same original input.
|
||||
original_flat_args = tuple(flattened_args)
|
||||
flat_args = tuple(flattened_args)
|
||||
|
||||
graph_module, guards = torchdynamo.export(f, *args, aten_graph=False)
|
||||
# TODO (tmanlaibaatar) do sth with guards?
|
||||
graph_module, _, out_spec = _aot_capture(graph_module, flat_args)
|
||||
return ExportedProgram(fw_module=graph_module, example_inputs=original_flat_args, in_spec=in_spec, out_spec=out_spec)
|
||||
|
@ -5,6 +5,14 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"PassManager",
|
||||
"inplace_wrapper",
|
||||
"log_hook",
|
||||
"loop_pass",
|
||||
"this_before_that_pass_constraint",
|
||||
"these_before_those_pass_constraint",
|
||||
]
|
||||
|
||||
# for callables which modify object inplace and return something other than
|
||||
# the object on which they act
|
||||
|
Reference in New Issue
Block a user