Add experimental torch.export prototype (#95070)

This is a WIP PR for adding a torch.export API in OSS. A few points:
- I intentionally named the entry point experimental_export so that people don't mistake it for our official API.
- We don't plan to use the AOTAutograd backend just yet. It is here only because the functionalization that AOTAutograd performs is what export needs (handling of param/buffer mutation, etc.). In the near future, I will extract the functionalization part and use it on top of make_fx; what we have right now is merely a placeholder.
- We want to do this now so that some minimal tests run in OSS and we can catch regressions earlier; a usage sketch is shown after this list.
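
For orientation, here is a minimal sketch of how the prototype is exercised, modeled on the tests added in this PR; `Foo` and `inp` mirror those tests, and anything about the returned `ExportedProgram` beyond the `graph_module` / `fw_module` usage visible in this diff is an assumption.

```python
# Minimal sketch, assuming the torch/_export module layout added by this PR.
import torch
from torch._export import do_not_use_experimental_export


class Foo(torch.nn.Module):
    def forward(self, x):
        return x.cos()


# Example inputs are passed as a tuple; requires_grad mirrors the tests below.
inp = (torch.ones(6, 4, requires_grad=True),)
exported_program = do_not_use_experimental_export(Foo(), inp)

# Inspect the captured forward graph and run it on the flat example inputs.
print(exported_program.graph_module.graph)
print(exported_program.fw_module(*inp)[0])
```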

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95070
Approved by: https://github.com/gmagogsfm, https://github.com/zhxchen17
Tugsbayasgalan Manlaibaatar
2023-02-27 11:40:32 -08:00
committed by PyTorch MergeBot
parent 801b3f8fc7
commit 454c48b987
5 changed files with 311 additions and 0 deletions


@@ -0,0 +1,79 @@
# Owner(s): ["module: dynamo"]
from torch.testing._internal.common_utils import run_tests, TestCase
from functorch.experimental.control_flow import cond
from torch._export import do_not_use_experimental_export
import torch._dynamo as torchdynamo
import torch
import unittest


class TestExport(TestCase):
    @unittest.skip("dynamo failure -> RuntimeError: Could not infer dtype of SymBool")
    def test_export_cond(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def foo(x):
            return cond(torch.tensor(x.shape[0] > 4), true_fn, false_fn, [x])

        exported_program = do_not_use_experimental_export(foo, (torch.ones(6, 4, requires_grad=True),))
        print(exported_program.graph_module.graph)

    @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
    def test_export_simple_model_with_attr(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                y = x + self.float_val
                return y.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
    def test_export_simple_model(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.float_val = float_val

            def forward(self, x):
                return x.cos()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))

    @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
    def test_export_simple_model_buffer_mutation(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(6, 1))

            def forward(self, x):
                self.buffer1.add_(2)
                return x.cos() + self.buffer1.sin()

        inp = (torch.ones(6, 4, requires_grad=True),)
        mod = Foo(0.5)
        exported_program = do_not_use_experimental_export(mod, inp)
        mutated_buffer, output = exported_program.fw_module(*inp)
        # TODO (tmanlaibaatar) enable this once we figure out
        # how to do buffer mutation
        # self.assertEqual(mutated_buffer.sum().item(), 30)
        self.assertEqual(output, mod(*inp))


if __name__ == '__main__':
    run_tests()


@@ -6,6 +6,7 @@ from .eval_frame import (
    disable,
    explain,
    export,
    is_dynamo_supported,
    optimize,
    optimize_assert,
    OptimizedModule,


@@ -383,6 +383,14 @@ def check_if_dynamo_supported():
        raise RuntimeError("Python 3.11+ not yet supported for torch.compile")


def is_dynamo_supported():
    try:
        check_if_dynamo_supported()
        return True
    except Exception:
        return False


def optimize(
    backend="inductor",
    *,


@@ -0,0 +1,215 @@
import contextlib
import copy
from typing import Callable, Tuple, Generator, Dict
from unittest.mock import patch

import torch
import torch._dynamo as torchdynamo
from torch._decomp import core_aten_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch.nn.utils import stateless
from torch.utils import _pytree as pytree
from torch._functorch.aot_autograd import (
    AOTConfig,
    create_aot_dispatcher_function,
    default_partition,
    run_functionalized_fw_and_collect_metadata,
)
from torch.fx.experimental.proxy_tensor import (
    get_proxy_slot,
    get_torch_dispatch_modes,
    has_proxy_slot,
    make_fx,
    ProxyTorchDispatchMode,
    set_proxy_slot,
)
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional

from .workflow import ExportedProgram

CORE_ATEN_DECOMPOSITIONS_TABLE = core_aten_decompositions()

__all__ = ["do_not_use_experimental_export"]


def _aot_capture(mod, flat_args):
    """
    A wrapper around aot_autograd() to mix AOT Autograd + torch.export.
    Some assumptions were made about the AOT Autograd internals:
        1. The functionalization metadata format.
        2. The calling convention of the returned forward graph.
        3. make_fx() internal proxy storage.
    In the current context we're just experimenting with the idea, so it's
    possible things could break. For the next step we should find a way to
    upstream something reasonable.
    """
    # Parameters and buffers are prepended to the user inputs so the traced
    # graph sees them as ordinary arguments.
    param_list = [
        *mod.named_parameters(remove_duplicate=False),
        *mod.named_buffers(remove_duplicate=False),
    ]
    params = dict(param_list)
    params_flat, params_spec = pytree.tree_flatten(params)
    params_len = len(params_flat)

    full_args = []
    full_args.extend(params_flat)
    full_args.extend(flat_args)

    def functional_call(*args):
        with stateless._reparametrize_module(
            mod,
            pytree.tree_unflatten(args[:params_len], params_spec),  # type: ignore[arg-type]
        ):
            return torch.fx.Interpreter(mod).run(*args[params_len:])

    out_spec = None

    # Collect functionalization metadata up front to find out which inputs
    # (params/buffers) are mutated by the forward pass.
    with enable_python_dispatcher():
        fw_metadata, _ = run_functionalized_fw_and_collect_metadata(
            lambda *args: pytree.tree_flatten(functional_call(*args))[0],
            keep_input_mutations=False,
        )(*copy.deepcopy(full_args))  # type: ignore[operator]

    assert len(fw_metadata.input_info) == len(full_args)
    mutated_input_indices = [
        i
        for i, input_info in enumerate(fw_metadata.input_info)
        if input_info.mutates_data or input_info.mutates_metadata
    ]

    graph_module = None

    def fw_compiler(gm, inputs):
        nonlocal graph_module
        graph_module = gm

    num_fwd_returns = None

    def partition_fn(joint_module, joint_inputs, *, num_fwd_outputs, **kwargs):
        nonlocal num_fwd_returns
        num_fwd_returns = num_fwd_outputs
        return default_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs
        )

    def set_state_proxies(state_args):
        modes = get_torch_dispatch_modes()
        proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
        if len(proxy_tensor_modes) == 0:
            return
        assert len(state_args) == len(params_flat)
        for i, arg in enumerate(state_args):
            tracer = next(
                m.tracer for m in proxy_tensor_modes if has_proxy_slot(arg, m.tracer)
            )
            set_proxy_slot(arg, tracer, params_flat[i])

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=lambda gm, inputs: None,
        partition_fn=partition_fn,
        decompositions=CORE_ATEN_DECOMPOSITIONS_TABLE,  # type: ignore[arg-type]
        num_params_buffers=params_len,
        aot_id=-1,
        keep_inference_input_mutations=False,
    )

    @contextlib.contextmanager
    def setup_dynamic_shape():
        prev, torch._functorch.config.use_dynamic_shapes = (
            torch._functorch.config.use_dynamic_shapes,
            True,
        )
        try:
            yield
        finally:
            torch._functorch.config.use_dynamic_shapes = prev

    def exported_call(*args):
        state_args = args[:params_len]
        unwrapped_state_args = _unwrap_all_tensors_from_functional(
            state_args, reapply_views=False
        )
        set_state_proxies(unwrapped_state_args)
        with torch.fx.traceback.preserve_node_meta():
            outputs = functional_call(*args)
        nonlocal out_spec
        outputs, out_spec = pytree.tree_flatten(outputs)
        return outputs

    with torch.enable_grad(), setup_dynamic_shape():
        create_aot_dispatcher_function(
            exported_call,
            full_args,
            aot_config,
        )
    assert graph_module is not None

    # The leading placeholders correspond to params/buffers and end up with no
    # users (their values are routed through set_state_proxies above), so drop
    # them from the graph.
    for i, node in enumerate(graph_module.graph.nodes):
        if i == len(params_flat):
            break
        assert node.op == "placeholder" and len(node.users) == 0
        graph_module.graph.erase_node(node)

    output_node = next(iter(reversed(graph_module.graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    assert num_fwd_returns is not None

    # Truncate the output so we only return what we need.
    output_node.args = (
        output_node.args[0][
            : len(mutated_input_indices) + len(fw_metadata.output_info)
        ],
    )
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()

    def find_mutation_destinations(gm, w):
        assert isinstance(w, torch.Tensor)
        ret = [
            name for name, x in [*gm.named_parameters(), *gm.named_buffers()] if x is w
        ]
        assert len(ret) != 0, "Cannot find mutation destination."
        return ret

    # Record which forward outputs should be copied back into which params/buffers.
    mutation = [
        (
            "copy_",
            output_node.args[0][k].name,
            find_mutation_destinations(graph_module, param_list[i][1]),
        )
        for k, i in enumerate(mutated_input_indices)
    ]
    assert out_spec is not None
    return graph_module, mutation, out_spec


@patch.object(torchdynamo.config, "dynamic_shapes", True)
@patch.object(torchdynamo.config, "capture_scalar_outputs", True)
@patch.object(torchdynamo.config, "guard_nn_modules", True)
@patch.object(torchdynamo.config, "specialize_int_float", True)
@patch.object(torchdynamo.config, "allow_rnn", True)
@patch.object(torchdynamo.config, "verbose", True)
def do_not_use_experimental_export(f: Callable, args: Tuple, training=False):
    """
    This prototype is under heavy development. Please don't use it if you are
    not part of the PyTorch 2.0 Export team.
    """
    if training:
        raise NotImplementedError("training mode is not supported yet")
    flattened_args, in_spec = pytree.tree_flatten(args)
    # Keep two copies so that if graph_module accidentally modifies the input
    # we still have the original input around.
    original_flat_args = tuple(flattened_args)
    flat_args = tuple(flattened_args)

    graph_module, guards = torchdynamo.export(f, *args, aten_graph=False)
    # TODO (tmanlaibaatar) do something with guards?
    graph_module, _, out_spec = _aot_capture(graph_module, flat_args)
    return ExportedProgram(
        fw_module=graph_module,
        example_inputs=original_flat_args,
        in_spec=in_spec,
        out_spec=out_spec,
    )
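
To make the calling convention above concrete, here is a hedged sketch of how a caller might round-trip pytree inputs and outputs through the captured graph. It assumes only what is visible in this diff (the `ExportedProgram` constructor arguments and that `fw_module` takes and returns flattened values), it ignores buffer mutation (the `mutation` triples from `_aot_capture` are currently discarded), and `run_exported` is an illustrative helper, not part of this PR.

```python
import torch
from torch.utils import _pytree as pytree
from torch._export import do_not_use_experimental_export


def run_exported(ep, *args):
    # Flatten nested inputs the same way do_not_use_experimental_export did,
    # run the captured forward graph, and rebuild the nested output structure.
    flat_args, in_spec = pytree.tree_flatten(args)
    assert in_spec == ep.in_spec, "inputs must match the exported input structure"
    flat_outs = ep.fw_module(*flat_args)
    return pytree.tree_unflatten(list(flat_outs), ep.out_spec)


def foo(x):
    return x.cos()


ep = do_not_use_experimental_export(foo, (torch.ones(6, 4),))
print(run_exported(ep, torch.ones(6, 4)))
```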


@@ -5,6 +5,14 @@ import logging
logger = logging.getLogger(__name__)

__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]
# for callables which modify object inplace and return something other than
# the object on which they act