Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-02 23:15:01 +08:00
[export] Support user input mutation. [1/2] (#114496)
Summary: Serialization not implemented yet. Will do in the next diff.

Resolving Github issues:
https://github.com/pytorch/pytorch/issues/112429
https://github.com/pytorch/pytorch/issues/114142

Test Plan: buck2 run mode/opt caffe2/test:test_export -- -r test_export_input_mutation

Differential Revision: D51556962

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114496
Approved by: https://github.com/tugsbayasgalan
Committed by: PyTorch MergeBot
Parent: 624f202522
Commit: b62c0d96bc
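For context, a minimal sketch of the behavior this change enables, adapted from the new test_export_input_mutation_static_shape test in the diff below. It is not part of the commit and assumes a PyTorch build that includes this change:

import copy

import torch


class MutationModel(torch.nn.Module):
    def forward(self, x, y):
        # In-place mutation of the user input `x`.
        x.view(3, 2, -1).add_(y)
        return x


model = MutationModel()
inputs = (torch.randn(12), 2.0)
ep = torch.export.export(model, inputs)  # no longer rejected for mutating a user input

inputs_export = copy.deepcopy(inputs)
inputs_model = copy.deepcopy(inputs)
assert torch.allclose(ep(*inputs_export), model(*inputs_model))
# The exported program replays the input mutation just like eager mode does.
assert torch.allclose(inputs_export[0], inputs[0] + 2.0)
assert torch.allclose(inputs_model[0], inputs[0] + 2.0)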
@@ -1,14 +1,15 @@
 # Owner(s): ["module: dynamo"]

+import copy
 import unittest

 import torch._dynamo as torchdynamo
-from torch.export import export
 from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
 from torch._export.db.examples import (
     filter_examples_by_support_level,
     get_rewrite_cases,
 )
+from torch.export import export
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -28,18 +29,19 @@ class ExampleTests(TestCase):
     def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
         model = case.model

-        inputs = normalize_inputs(case.example_inputs)
+        inputs_export = normalize_inputs(case.example_inputs)
+        inputs_model = copy.deepcopy(inputs_export)
         exported_program = export(
             model,
-            inputs.args,
-            inputs.kwargs,
+            inputs_export.args,
+            inputs_export.kwargs,
             dynamic_shapes=case.dynamic_shapes,
         )
         exported_program.graph_module.print_readable()

         self.assertEqual(
-            exported_program(*inputs.args, **inputs.kwargs),
-            model(*inputs.args, **inputs.kwargs),
+            exported_program(*inputs_export.args, **inputs_export.kwargs),
+            model(*inputs_model.args, **inputs_model.kwargs),
         )

         if case.extra_inputs is not None:
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 # flake8: noqa
+import copy
 import dataclasses
 import unittest
 from contextlib import contextmanager
@@ -1092,13 +1093,13 @@ class TestExport(TestCase):
             torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
         )
         with self.assertRaisesRegex(
-            RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time"
+            RuntimeError, "is specialized to be 5 at tracing time"
         ):
             _ = exported(torch.ones(8, 5), 6)

         exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
         with self.assertRaisesRegex(
-            RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time"
+            RuntimeError, "is specialized to be 5.0 at tracing time"
         ):
             _ = exported(torch.ones(7, 5), 6.0)

@@ -1109,7 +1110,7 @@ class TestExport(TestCase):

         inps = (torch.randn(4, 4), torch.randn(4), "trunc")
         exported = torch._export.export(g, inps)
-        with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"):
+        with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"):
             _ = exported(torch.randn(4, 4), torch.randn(4), "floor")
         self.assertTrue(torch.allclose(exported(*inps), g(*inps)))

@@ -1190,7 +1191,7 @@ class TestExport(TestCase):
         dim0_x = torch.export.Dim("dim0_x")
         exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})
         reexported = torch.export.export(exported, (inp,))
-        with self.assertRaisesRegex(RuntimeError, "Input arg2_1\.shape\[0\] is specialized at 5"):
+        with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"):
             reexported(torch.ones(7, 5))

         reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},))
@@ -1199,7 +1200,7 @@ class TestExport(TestCase):
         # can't retrace with invalid inputs with respect to the original ExportedProgram
         dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
         exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}})
-        with self.assertRaisesRegex(RuntimeError, "Input arg2_1"):
+        with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
             torch.export.export(exported_v2, (torch.randn(2, 2),))

     def test_retrace_graph_level_meta_preservation(self):
@@ -1472,8 +1473,8 @@ class TestExport(TestCase):

         ep = export(f, (torch.tensor([3]),))
         self.assertExpectedInline(str(ep.graph_module.code).strip(), """\
-def forward(self, arg0_1):
-    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
+def forward(self, l_x_):
+    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None
     ge = _local_scalar_dense >= 0
     scalar_tensor = torch.ops.aten.scalar_tensor.default(ge); ge = None
     _assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, inf].'); scalar_tensor = None
@@ -1492,7 +1493,7 @@ def forward(self, arg0_1):
         self.assertEqual(ep(*test_inp), foo(*test_inp))

         ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
-        with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"):
+        with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
             ep_v2(*test_inp)

     def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
@@ -1540,7 +1541,7 @@ def forward(self, arg0_1):
         test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
         with self.assertRaisesRegex(
             RuntimeError,
-            "Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]"
+            "shape\[0\] is outside of specified dynamic range \[3, inf\]"
         ):
             ep(*test_inp)

@@ -1721,6 +1722,42 @@ def forward(self, arg0_1):
         optimized_model = torch.compile(exported_model)
         optimized_model(tensor_cpu, mask_cpu)

+    def test_export_input_mutation_static_shape(self):
+        class MutationModel(torch.nn.Module):
+            def forward(self, x, y):
+                x.view(3, 2, -1).add_(y)
+                return x
+        inputs = (torch.randn(12), 2.0)
+        model = MutationModel()
+        ep = torch.export.export(model, inputs)
+        inputs_export = copy.deepcopy(inputs)
+        inputs_model = copy.deepcopy(inputs)
+        self.assertEqual(ep(*inputs_export), model(*inputs_model))
+        self.assertEqual(inputs[0] + 2.0, inputs_model[0])
+        self.assertEqual(inputs[0] + 2.0, inputs_export[0])
+
+    def test_export_input_mutation_dynamic_shape(self):
+        class MutationModel(torch.nn.Module):
+            def forward(self, x, y):
+                x[0].mul_(y)
+                return x
+        inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0)
+        model = MutationModel()
+        ep = torch.export.export(
+            model,
+            inputs,
+            dynamic_shapes={'x': ({0: torch.export.Dim("dim")}, None), "y": None}
+        )
+        nodes = list(ep.graph.nodes)
+        self.assertEqual(nodes[0].op, "placeholder")
+        self.assertIsInstance(nodes[0].meta['val'], torch.Tensor)
+        self.assertIsInstance(nodes[0].meta['val'].shape[0], torch.SymInt)
+
+        inputs_export = copy.deepcopy(inputs)
+        inputs_model = copy.deepcopy(inputs)
+        self.assertEqual(ep(*inputs_export), model(*inputs_model))
+        self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0])
+        self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0])
+
 if __name__ == '__main__':
     run_tests()
@@ -76,7 +76,7 @@ class TestPasses(TestCase):
         dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
         ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
             ep(torch.zeros(2, 7, 3))

         self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@@ -99,10 +99,10 @@ class TestPasses(TestCase):
             M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
         )

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

-        with self.assertRaisesRegex(RuntimeError, "Input arg1_1"):
+        with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

     def test_runtime_assert_some_dims_not_specified(self) -> None:
@@ -123,12 +123,12 @@ class TestPasses(TestCase):
             M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
         )

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

         # y is specialized to 5
         with self.assertRaisesRegex(
-            RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
+            RuntimeError, r"shape\[0\] is specialized at 5"
         ):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -152,12 +152,12 @@ class TestPasses(TestCase):
         dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
         ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})

-        with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
+        with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"):
             ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

         # y is specialized to 5
         with self.assertRaisesRegex(
-            RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
+            RuntimeError, r"shape\[0\] is specialized at 5"
         ):
             ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

@@ -302,34 +302,6 @@ class TestPasses(TestCase):
         with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
             ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

-    def test_runtime_assert_equality_constraint(self):
-        class Adder(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x + y
-
-        m = Adder()
-        x = torch.rand(3, 4)
-        y = torch.rand(3, 4)
-        dim1 = torch.export.Dim("dim1")
-        exported = torch.export.export(
-            m, (x, y), dynamic_shapes={"x": {1: dim1}, "y": {1: dim1}}
-        )
-
-        x = torch.rand(3, 5)
-        y = torch.rand(3, 6)
-        with self.assertRaisesRegex(
-            RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]"
-        ):
-            exported(x, y)
-
-        y = torch.rand(3, 5)
-        dynamo_result = exported(x, y)
-        real_result = m(x, y)
-        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
-
     def test_functionalize_inline_contraints(self) -> None:
         def f(x):
             a = x.item()
@@ -44,6 +44,7 @@ def get_filtered_export_db_tests():
         "dictionary",  # Graph output must be a tuple()
         "fn_with_kwargs",  # export doesn't support kwargs yet
         "scalar_output",  # Tracing through 'f' must produce a single graph
+        "user_input_mutation",  # TODO(zhxchen17) Support serializing user inputs mutation.
     }

     return [
@@ -277,11 +277,11 @@ class TestUnflatten(TestCase):
             return a

         export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
-        with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"):
+        with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
             export_module(torch.randn(6, 6))

         unflattened = export_module.module(flat=False)
-        with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"):
+        with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
             unflattened(torch.randn(6, 6))


@@ -30,29 +30,29 @@ from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.exc import UserError, UserErrorType
 from torch._dynamo.source import ConstantSource
 from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
-from torch._functorch.aot_autograd import aot_export_module
+from torch._functorch.aot_autograd import aot_export_module, GraphSignature
 from torch._functorch.eager_transforms import functionalize
 from torch._guards import detect_fake_mode
 from torch._ops import OpOverload
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 from torch.export import _create_constraint, _Dim, Constraint
-from torch.export.graph_signature import (
-    ExportGraphSignature,
-    _sig_to_specs,
-    ArgumentSpec,
-    ConstantArgument,
-    InputKind,
-    OutputKind,
-    OutputSpec,
-    SymIntArgument,
-    TensorArgument,
-    InputSpec
-)
 from torch.export.exported_program import (
     ExportedProgram,
     ModuleCallEntry,
     ModuleCallSignature,
 )
+from torch.export.graph_signature import (
+    _sig_to_specs,
+    ArgumentSpec,
+    ConstantArgument,
+    ExportGraphSignature,
+    InputKind,
+    InputSpec,
+    OutputKind,
+    OutputSpec,
+    SymIntArgument,
+    TensorArgument,
+)
 from torch.fx import traceback as fx_traceback
 from torch.fx._compatibility import compatibility
 from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@@ -559,6 +559,88 @@ def export(
         preserve_module_call_signature=preserve_module_call_signature,
     )


+def _prepare_module(
+    gm_torch_level: torch.fx.GraphModule,
+    aot_export_args
+) -> List[str]:
+    flat_args = pytree.tree_leaves(aot_export_args)
+    user_input_names = []
+    with gm_torch_level.graph.inserting_before():
+        for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)):
+            assert node.op == "placeholder"
+            user_input_names.append(node.name)
+            if isinstance(arg, torch.Tensor):
+                assert not hasattr(gm_torch_level, node.name)
+                gm_torch_level.register_buffer(node.name, arg)
+                get_attr = gm_torch_level.graph.get_attr(node.name)
+                node.replace_all_uses_with(get_attr)
+                get_attr.meta = copy.copy(node.meta)
+
+    for node in list(gm_torch_level.graph.nodes):
+        if node.op == "placeholder":
+            assert len(node.users) == 0
+            gm_torch_level.graph.erase_node(node)
+    gm_torch_level.recompile()
+    return user_input_names
+
+
+def _unwrap_user_inputs(
+    gm: torch.fx.GraphModule,
+    graph_signature: GraphSignature,
+    user_input_names: List[str]
+) -> Dict[str, str]:
+    assert len(graph_signature.user_inputs) == 0
+    assert graph_signature.backward_signature is None
+    names = set(user_input_names)
+
+    placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
+    # user inputs are always added in the end
+    start = len(graph_signature.parameters)
+    end = start + len(graph_signature.buffers)
+    buffer_nodes = placeholders[start:end]
+    last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None
+    old_nodes: Dict[str, torch.fx.Node] = {}
+    for node in buffer_nodes:
+        buffer_name = graph_signature.inputs_to_buffers[node.name]
+        if buffer_name not in names:
+            continue
+        old_nodes[buffer_name] = node
+    replaces = {}
+    new_node_names: Dict[str, str] = {}
+    with gm.graph.inserting_after(last_placeholder_node):
+        for name in reversed(user_input_names):
+            new_node = gm.graph.placeholder(name)
+            new_node.target = new_node.name
+            new_node_names[name] = new_node.name
+            if name in old_nodes:
+                old_node = old_nodes[name]
+                new_node.meta = copy.copy(old_node.meta)
+                old_node.replace_all_uses_with(new_node)
+                replaces[old_node.name] = new_node.name
+
+    for old_node in old_nodes.values():
+        gm.graph.erase_node(old_node)
+
+    gm.recompile()
+
+    graph_signature.buffers = [b for b in graph_signature.buffers if b not in names]
+    graph_signature.inputs_to_buffers = {
+        i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names
+    }
+    user_inputs_to_mutate = {
+        o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names
+    }
+    graph_signature.buffers_to_mutate = {
+        o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names
+    }
+    graph_signature.user_inputs = list(reversed(new_node_names.values()))  # type: ignore[arg-type]
+    graph_signature.user_outputs = [
+        replaces[o] if o in replaces else o for o in graph_signature.user_outputs
+    ]
+    return user_inputs_to_mutate  # type: ignore[return-value]
+
+
 def _disable_prexisiting_fake_mode(fn):

     @functools.wraps(fn)
@@ -703,6 +785,10 @@ def _export(
     if isinstance(f, torch.nn.Module):
         _normalize_nn_module_stack(gm_torch_level, type(f))

+    aot_export_args = (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values())
+
+    user_input_names = _prepare_module(gm_torch_level, aot_export_args)
+
     # Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
     # to follow the order in orig_args and correctly call gm_torch_level

@@ -712,9 +798,10 @@ def _export(
     with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers):
         gm, graph_signature = aot_export_module(
             gm_torch_level,
-            (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
+            (),
             trace_joint=False
         )
+        user_inputs_to_mutate = _unwrap_user_inputs(gm, graph_signature, user_input_names)

     def to_str_list(sig_component: List[Any]):
         return [str(v) for v in sig_component]
@@ -771,6 +858,7 @@ def _export(
     is_joint = graph_signature.backward_signature is not None

     def make_argument_spec(node) -> ArgumentSpec:
+        assert "val" in node.meta, f"{node} has no 'val' metadata field"
         val = node.meta["val"]
         if isinstance(val, FakeTensor):
             return TensorArgument(name=node.name)
@@ -784,6 +872,7 @@ def _export(
         inputs_to_buffers=graph_signature.inputs_to_buffers,  # type: ignore[arg-type]
         user_outputs=set(graph_signature.user_outputs),  # type: ignore[arg-type]
         buffer_mutations=graph_signature.buffers_to_mutate,  # type: ignore[arg-type]
+        user_input_mutations=user_inputs_to_mutate,  # type: ignore[arg-type]
         grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {},  # type: ignore[arg-type, union-attr]
         grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {},  # type: ignore[arg-type, union-attr]
         loss_output=graph_signature.backward_signature.loss_output if is_joint else None,  # type: ignore[arg-type, union-attr]
@@ -6,11 +6,11 @@ from torch._export.db.case import export_case, SupportLevel
 @export_case(
     example_inputs=(torch.ones(3, 2),),
     tags={"torch.mutation"},
-    support_level=SupportLevel.NOT_SUPPORTED_YET,
+    support_level=SupportLevel.SUPPORTED,
 )
 class UserInputMutation(torch.nn.Module):
     """
-    Can't directly mutate user input in forward
+    Directly mutate user input in forward
     """

     def forward(self, x):
@@ -230,12 +230,6 @@ def _verify_exported_program_signature(exported_program) -> None:
     # Check ExportedProgram signature matches
     gs = exported_program.graph_signature

-    bs_grad_to_param = {}
-    bs_grad_to_user_inputs = {}
-    if gs.backward_signature is not None:
-        bs_grad_to_param = gs.backward_signature.gradients_to_parameters
-        bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs
-
     # Check every node in the signature exists in the graph
     input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]

@@ -324,19 +318,28 @@ def _verify_exported_program_signature(exported_program) -> None:
             f"Number of user outputs: {len(gs.user_outputs)}. \n"
         )

-    buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)]
-    user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)]
+    end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate)
+    mutate_nodes: List[str] = output_nodes[:end]
+    user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]

-    for buffer_node in buffer_mutate_nodes:
-        if (
-            buffer_node not in gs.buffers_to_mutate or
-            gs.buffers_to_mutate[buffer_node] not in gs.buffers
-        ):
-            raise SpecViolationError(
-                f"Buffer output {buffer_node} is not in buffer mutation dictionary "
-                "or, it does not point to a buffer that exists. \n"
-                f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
-                f"Buffer nodes available: {gs.buffers} \n"
-            )
+    for mutation_node in mutate_nodes:
+        if mutation_node in gs.buffers_to_mutate:
+            if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
+                raise SpecViolationError(
+                    f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
+                    f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
+                    f"Buffer nodes available: {gs.buffers} \n"
+                )
+        elif mutation_node in gs.user_inputs_to_mutate:
+            if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
+                raise SpecViolationError(
+                    f"User input output {mutation_node} does not point to a user input that exists. \n"
+                    f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
+                    f"User input nodes available: {gs.user_inputs} \n")
+        else:
+            raise SpecViolationError(
+                f"Mutation node {mutation_node} is neither a buffer nor a user input. "
+                f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
+            )

     for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
@@ -277,9 +277,10 @@ class ExportedProgram:
         )

         if self.call_spec.out_spec is not None:
-            mutation = self.graph_signature.buffers_to_mutate
-            num_mutated = len(mutation)
-            mutated_buffers = res[:num_mutated]
+            buffer_mutation = self.graph_signature.buffers_to_mutate
+            user_input_mutation = self.graph_signature.user_inputs_to_mutate
+            num_mutated = len(buffer_mutation) + len(user_input_mutation)
+            mutated_values = res[:num_mutated]

             # Exclude dependency token from final result.
             assertion_dep_token = self.graph_signature.assertion_dep_token
@@ -299,10 +300,27 @@ class ExportedProgram:
                     f"{received_spec}"
                 )
             finally:
-                ix = 0
-                for buffer in self.graph_signature.buffers_to_mutate.values():
-                    self.state_dict[buffer] = mutated_buffers[ix]
-                    ix += 1
+                user_inputs = [
+                    spec
+                    for spec in self.graph_signature.input_specs
+                    if spec.kind == InputKind.USER_INPUT
+                ]
+                for i, value in enumerate(mutated_values):
+                    output_spec = self.graph_signature.output_specs[i]
+                    if output_spec.kind == OutputKind.BUFFER_MUTATION:
+                        assert output_spec.target is not None
+                        self.state_dict[output_spec.target] = value
+                    elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
+                        assert output_spec.target is not None
+                        index = next(
+                            i
+                            for i, spec in enumerate(user_inputs)
+                            if spec.arg.name == output_spec.target
+                        )
+                        args[index].copy_(value)
+                    else:
+                        raise AssertionError(f"Unexpected kind: {output_spec.kind}")
+
         return res

     def __str__(self) -> str:
@@ -365,7 +383,6 @@ class ExportedProgram:
         decomp_table = decomp_table or core_aten_decompositions()

         old_placeholders = _get_placeholders(self.graph_module)
-        old_outputs = list(self.graph.nodes)[-1].args[0]
         fake_args = [node.meta["val"] for node in old_placeholders]

         buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
@@ -57,6 +57,7 @@ class OutputKind(Enum):
     BUFFER_MUTATION = auto()
     GRADIENT_TO_PARAMETER = auto()
     GRADIENT_TO_USER_INPUT = auto()
+    USER_INPUT_MUTATION = auto()


 @dataclasses.dataclass
@@ -76,6 +77,7 @@ def _sig_to_specs(
     inputs_to_buffers: Mapping[str, str],
     user_outputs: Set[str],
     buffer_mutations: Mapping[str, str],
+    user_input_mutations: Mapping[str, str],
     grad_params: Mapping[str, str],
     grad_user_inputs: Mapping[str, str],
     loss_output: Optional[str],
@@ -101,37 +103,49 @@
         else:
             raise AssertionError(f"Unknown tensor input kind: {name}")

-    def to_output_spec(o: ArgumentSpec) -> OutputSpec:
+    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
         if not isinstance(o, TensorArgument):
             return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
         name = o.name
-        if name in user_outputs:
-            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
-        elif name in buffer_mutations:
-            return OutputSpec(
-                kind=OutputKind.BUFFER_MUTATION,
-                arg=o,
-                target=buffer_mutations[name],
-            )
-        elif name in grad_params:
-            return OutputSpec(
-                kind=OutputKind.GRADIENT_TO_PARAMETER,
-                arg=o,
-                target=grad_params[name],
-            )
-        elif name in grad_user_inputs:
-            return OutputSpec(
-                kind=OutputKind.GRADIENT_TO_USER_INPUT,
-                arg=o,
-                target=grad_user_inputs[name],
-            )
-        elif name == loss_output:
-            return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
+        if idx < len(buffer_mutations) + len(user_input_mutations):
+            if name in buffer_mutations:
+                return OutputSpec(
+                    kind=OutputKind.BUFFER_MUTATION,
+                    arg=o,
+                    target=buffer_mutations[name],
+                )
+            elif name in user_input_mutations:
+                return OutputSpec(
+                    kind=OutputKind.USER_INPUT_MUTATION,
+                    arg=o,
+                    target=user_input_mutations[name],
+                )
+            else:
+                raise AssertionError(f"Unknown tensor mutation kind: {name}")
         else:
-            raise AssertionError(f"Unknown tensor output kind: {name}")
+            if name in user_outputs:
+                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
+
+            elif name in grad_params:
+                return OutputSpec(
+                    kind=OutputKind.GRADIENT_TO_PARAMETER,
+                    arg=o,
+                    target=grad_params[name],
+                )
+            elif name in grad_user_inputs:
+                return OutputSpec(
+                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
+                    arg=o,
+                    target=grad_user_inputs[name],
+                )
+            elif name == loss_output:
+                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
+
+            else:
+                raise AssertionError(f"Unknown tensor output kind: {name}")

     input_specs = [to_input_spec(i) for i in inputs]
-    output_specs = [to_output_spec(o) for o in outputs]
+    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
     return input_specs, output_specs


@@ -304,6 +318,16 @@ class ExportGraphSignature:
             and isinstance(s.target, str)
         }

+    @property
+    def user_inputs_to_mutate(self) -> Mapping[str, str]:
+        return {
+            s.arg.name: s.target
+            for s in self.output_specs
+            if s.kind == OutputKind.USER_INPUT_MUTATION
+            and isinstance(s.arg, TensorArgument)
+            and isinstance(s.target, str)
+        }
+
     # A dictionary mapping graph input node names to lifted tensor constants.
     @property
     def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: