[export] Support user input mutation. [1/2] (#114496)

Summary:
Adds support for mutating user inputs in torch.export. Serialization is not implemented yet; it will be added in the next diff.

Resolves GitHub issues:
https://github.com/pytorch/pytorch/issues/112429
https://github.com/pytorch/pytorch/issues/114142
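
Concretely, this allows exporting a module whose forward mutates one of its user inputs in place. A minimal sketch of the newly supported pattern, mirroring the test_export_input_mutation_static_shape test added below (the class name is illustrative):

import torch

class MutationModel(torch.nn.Module):
    def forward(self, x, y):
        # In-place update of the user input x.
        x.view(3, 2, -1).add_(y)
        return x

inputs = (torch.randn(12), 2.0)
ep = torch.export.export(MutationModel(), inputs)
# Calling ep applies the same in-place update to the caller's tensor,
# matching eager-mode behavior.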

Test Plan:
buck2 run mode/opt caffe2/test:test_export -- -r test_export_input_mutation

Differential Revision: D51556962

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114496
Approved by: https://github.com/tugsbayasgalan
Author: Zhengxu Chen
Date: 2023-11-27 04:53:38 +00:00
Committed by: PyTorch MergeBot
Commit: b62c0d96bc (parent: 624f202522)
10 changed files with 263 additions and 118 deletions

View File

@ -1,14 +1,15 @@
# Owner(s): ["module: dynamo"]
import copy
import unittest
import torch._dynamo as torchdynamo
from torch.export import export
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import (
filter_examples_by_support_level,
get_rewrite_cases,
)
from torch.export import export
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@ -28,18 +29,19 @@ class ExampleTests(TestCase):
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
model = case.model
inputs = normalize_inputs(case.example_inputs)
inputs_export = normalize_inputs(case.example_inputs)
inputs_model = copy.deepcopy(inputs_export)
exported_program = export(
model,
inputs.args,
inputs.kwargs,
inputs_export.args,
inputs_export.kwargs,
dynamic_shapes=case.dynamic_shapes,
)
exported_program.graph_module.print_readable()
self.assertEqual(
exported_program(*inputs.args, **inputs.kwargs),
model(*inputs.args, **inputs.kwargs),
exported_program(*inputs_export.args, **inputs_export.kwargs),
model(*inputs_model.args, **inputs_model.kwargs),
)
if case.extra_inputs is not None:

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
@ -1092,13 +1093,13 @@ class TestExport(TestCase):
torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5))
)
with self.assertRaisesRegex(
RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time"
RuntimeError, "is specialized to be 5 at tracing time"
):
_ = exported(torch.ones(8, 5), 6)
exported = torch.export.export(f, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes)
with self.assertRaisesRegex(
RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time"
RuntimeError, "is specialized to be 5.0 at tracing time"
):
_ = exported(torch.ones(7, 5), 6.0)
@ -1109,7 +1110,7 @@ class TestExport(TestCase):
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = torch._export.export(g, inps)
with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"):
with self.assertRaisesRegex(RuntimeError, "is specialized to be trunc at"):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))
@ -1190,7 +1191,7 @@ class TestExport(TestCase):
dim0_x = torch.export.Dim("dim0_x")
exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})
reexported = torch.export.export(exported, (inp,))
with self.assertRaisesRegex(RuntimeError, "Input arg2_1\.shape\[0\] is specialized at 5"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 5"):
reexported(torch.ones(7, 5))
reexported = torch.export.export(exported, (inp,), dynamic_shapes=({0: dim0_x},))
@ -1199,7 +1200,7 @@ class TestExport(TestCase):
# can't retrace with invalid inputs with respect to the original ExportedProgram
dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3)
exported_v2 = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}})
with self.assertRaisesRegex(RuntimeError, "Input arg2_1"):
with self.assertRaisesRegex(RuntimeError, "shape\[1\] is specialized at 5"):
torch.export.export(exported_v2, (torch.randn(2, 2),))
def test_retrace_graph_level_meta_preservation(self):
@ -1472,8 +1473,8 @@ class TestExport(TestCase):
ep = export(f, (torch.tensor([3]),))
self.assertExpectedInline(str(ep.graph_module.code).strip(), """\
def forward(self, arg0_1):
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
def forward(self, l_x_):
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(l_x_); l_x_ = None
ge = _local_scalar_dense >= 0
scalar_tensor = torch.ops.aten.scalar_tensor.default(ge); ge = None
_assert_async = torch.ops.aten._assert_async.msg(scalar_tensor, '_local_scalar_dense is outside of inline constraint [0, inf].'); scalar_tensor = None
@ -1492,7 +1493,7 @@ def forward(self, arg0_1):
self.assertEqual(ep(*test_inp), foo(*test_inp))
ep_v2 = torch.export.export(foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, None))
with self.assertRaisesRegex(RuntimeError, "Input arg1_1.shape\[0\] is specialized at 4"):
with self.assertRaisesRegex(RuntimeError, "shape\[0\] is specialized at 4"):
ep_v2(*test_inp)
def test_non_arg_name_dynamic_shapes_api_with_kwarg(self):
@ -1540,7 +1541,7 @@ def forward(self, arg0_1):
test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
with self.assertRaisesRegex(
RuntimeError,
"Input arg1_1.shape\[0\] is outside of specified dynamic range \[3, inf\]"
"shape\[0\] is outside of specified dynamic range \[3, inf\]"
):
ep(*test_inp)
@ -1721,6 +1722,42 @@ def forward(self, arg0_1):
optimized_model = torch.compile(exported_model)
optimized_model(tensor_cpu, mask_cpu)
def test_export_input_mutation_static_shape(self):
class MutationModel(torch.nn.Module):
def forward(self, x, y):
x.view(3, 2, -1).add_(y)
return x
inputs = (torch.randn(12), 2.0)
model = MutationModel()
ep = torch.export.export(model, inputs)
inputs_export = copy.deepcopy(inputs)
inputs_model = copy.deepcopy(inputs)
self.assertEqual(ep(*inputs_export), model(*inputs_model))
self.assertEqual(inputs[0] + 2.0, inputs_model[0])
self.assertEqual(inputs[0] + 2.0, inputs_export[0])
def test_export_input_mutation_dynamic_shape(self):
class MutationModel(torch.nn.Module):
def forward(self, x, y):
x[0].mul_(y)
return x
inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0)
model = MutationModel()
ep = torch.export.export(
model,
inputs,
dynamic_shapes={'x': ({0: torch.export.Dim("dim")}, None), "y": None}
)
nodes = list(ep.graph.nodes)
self.assertEqual(nodes[0].op, "placeholder")
self.assertIsInstance(nodes[0].meta['val'], torch.Tensor)
self.assertIsInstance(nodes[0].meta['val'].shape[0], torch.SymInt)
inputs_export = copy.deepcopy(inputs)
inputs_model = copy.deepcopy(inputs)
self.assertEqual(ep(*inputs_export), model(*inputs_model))
self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0])
self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0])
if __name__ == '__main__':
run_tests()

View File

@ -76,7 +76,7 @@ class TestPasses(TestCase):
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
ep(torch.zeros(2, 7, 3))
self.assertEqual(ep(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3)))
@ -99,10 +99,10 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
with self.assertRaisesRegex(RuntimeError, "Input arg1_1"):
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
def test_runtime_assert_some_dims_not_specified(self) -> None:
@ -123,12 +123,12 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
)
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, "is outside of specified dynamic range"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
RuntimeError, r"shape\[0\] is specialized at 5"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
@ -152,12 +152,12 @@ class TestPasses(TestCase):
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
ep = torch.export.export(M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}})
with self.assertRaisesRegex(RuntimeError, "Input arg0_1"):
with self.assertRaisesRegex(RuntimeError, r"shape\[1\] is specialized at 2"):
ep(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError, r"Input arg1_1.shape\[0\] is specialized at 5"
RuntimeError, r"shape\[0\] is specialized at 5"
):
ep(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
@ -302,34 +302,6 @@ class TestPasses(TestCase):
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))
def test_runtime_assert_equality_constraint(self):
class Adder(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
m = Adder()
x = torch.rand(3, 4)
y = torch.rand(3, 4)
dim1 = torch.export.Dim("dim1")
exported = torch.export.export(
m, (x, y), dynamic_shapes={"x": {1: dim1}, "y": {1: dim1}}
)
x = torch.rand(3, 5)
y = torch.rand(3, 6)
with self.assertRaisesRegex(
RuntimeError, r"Input arg0_1.shape\[1\] is not equal to input arg1_1.shape\[1\]"
):
exported(x, y)
y = torch.rand(3, 5)
dynamo_result = exported(x, y)
real_result = m(x, y)
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_functionalize_inline_contraints(self) -> None:
def f(x):
a = x.item()

View File

@ -44,6 +44,7 @@ def get_filtered_export_db_tests():
"dictionary", # Graph output must be a tuple()
"fn_with_kwargs", # export doesn't support kwargs yet
"scalar_output", # Tracing through 'f' must produce a single graph
"user_input_mutation", # TODO(zhxchen17) Support serializing user inputs mutation.
}
return [

View File

@ -277,11 +277,11 @@ class TestUnflatten(TestCase):
return a
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"):
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
export_module(torch.randn(6, 6))
unflattened = export_module.module(flat=False)
with self.assertRaisesRegex(RuntimeError, "Input arg4_1.shape"):
with self.assertRaisesRegex(RuntimeError, ".shape\[1\] is specialized at 3"):
unflattened(torch.randn(6, 6))

View File

@ -30,29 +30,29 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.exc import UserError, UserErrorType
from torch._dynamo.source import ConstantSource
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
from torch._functorch.eager_transforms import functionalize
from torch._guards import detect_fake_mode
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import _create_constraint, _Dim, Constraint
from torch.export.graph_signature import (
ExportGraphSignature,
_sig_to_specs,
ArgumentSpec,
ConstantArgument,
InputKind,
OutputKind,
OutputSpec,
SymIntArgument,
TensorArgument,
InputSpec
)
from torch.export.exported_program import (
ExportedProgram,
ModuleCallEntry,
ModuleCallSignature,
)
from torch.export.graph_signature import (
_sig_to_specs,
ArgumentSpec,
ConstantArgument,
ExportGraphSignature,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
SymIntArgument,
TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@ -559,6 +559,88 @@ def export(
preserve_module_call_signature=preserve_module_call_signature,
)
def _prepare_module(
gm_torch_level: torch.fx.GraphModule,
aot_export_args
) -> List[str]:
flat_args = pytree.tree_leaves(aot_export_args)
user_input_names = []
with gm_torch_level.graph.inserting_before():
for i, (arg, node) in enumerate(zip(flat_args, gm_torch_level.graph.nodes)):
assert node.op == "placeholder"
user_input_names.append(node.name)
if isinstance(arg, torch.Tensor):
assert not hasattr(gm_torch_level, node.name)
gm_torch_level.register_buffer(node.name, arg)
get_attr = gm_torch_level.graph.get_attr(node.name)
node.replace_all_uses_with(get_attr)
get_attr.meta = copy.copy(node.meta)
for node in list(gm_torch_level.graph.nodes):
if node.op == "placeholder":
assert len(node.users) == 0
gm_torch_level.graph.erase_node(node)
gm_torch_level.recompile()
return user_input_names
def _unwrap_user_inputs(
gm: torch.fx.GraphModule,
graph_signature: GraphSignature,
user_input_names: List[str]
) -> Dict[str, str]:
assert len(graph_signature.user_inputs) == 0
assert graph_signature.backward_signature is None
names = set(user_input_names)
placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
# user inputs are always added in the end
start = len(graph_signature.parameters)
end = start + len(graph_signature.buffers)
buffer_nodes = placeholders[start:end]
last_placeholder_node = placeholders[-1] if len(placeholders) > 0 else None
old_nodes: Dict[str, torch.fx.Node] = {}
for node in buffer_nodes:
buffer_name = graph_signature.inputs_to_buffers[node.name]
if buffer_name not in names:
continue
old_nodes[buffer_name] = node
replaces = {}
new_node_names: Dict[str, str] = {}
with gm.graph.inserting_after(last_placeholder_node):
for name in reversed(user_input_names):
new_node = gm.graph.placeholder(name)
new_node.target = new_node.name
new_node_names[name] = new_node.name
if name in old_nodes:
old_node = old_nodes[name]
new_node.meta = copy.copy(old_node.meta)
old_node.replace_all_uses_with(new_node)
replaces[old_node.name] = new_node.name
for old_node in old_nodes.values():
gm.graph.erase_node(old_node)
gm.recompile()
graph_signature.buffers = [b for b in graph_signature.buffers if b not in names]
graph_signature.inputs_to_buffers = {
i: b for i, b in graph_signature.inputs_to_buffers.items() if b not in names
}
user_inputs_to_mutate = {
o: b for o, b in graph_signature.buffers_to_mutate.items() if b in names
}
graph_signature.buffers_to_mutate = {
o: b for o, b in graph_signature.buffers_to_mutate.items() if b not in names
}
graph_signature.user_inputs = list(reversed(new_node_names.values())) # type: ignore[arg-type]
graph_signature.user_outputs = [
replaces[o] if o in replaces else o for o in graph_signature.user_outputs
]
return user_inputs_to_mutate # type: ignore[return-value]
def _disable_prexisiting_fake_mode(fn):
@functools.wraps(fn)
@ -703,6 +785,10 @@ def _export(
if isinstance(f, torch.nn.Module):
_normalize_nn_module_stack(gm_torch_level, type(f))
aot_export_args = (*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values())
user_input_names = _prepare_module(gm_torch_level, aot_export_args)
# Note: aot_export_module doesn't accept kwargs, we'd like to reorder the kwargs as an OrderedDict
# to follow the order in orig_args and correctly call gm_torch_level
@ -712,9 +798,10 @@ def _export(
with torch.nn.utils.stateless._reparametrize_module(gm_torch_level, fake_params_buffers):
gm, graph_signature = aot_export_module(
gm_torch_level,
(*fake_args, *_reorder_kwargs_by_names(orig_args, fake_args, fake_kwargs).values()),
(),
trace_joint=False
)
user_inputs_to_mutate = _unwrap_user_inputs(gm, graph_signature, user_input_names)
def to_str_list(sig_component: List[Any]):
return [str(v) for v in sig_component]
@ -771,6 +858,7 @@ def _export(
is_joint = graph_signature.backward_signature is not None
def make_argument_spec(node) -> ArgumentSpec:
assert "val" in node.meta, f"{node} has no 'val' metadata field"
val = node.meta["val"]
if isinstance(val, FakeTensor):
return TensorArgument(name=node.name)
@ -784,6 +872,7 @@ def _export(
inputs_to_buffers=graph_signature.inputs_to_buffers, # type: ignore[arg-type]
user_outputs=set(graph_signature.user_outputs), # type: ignore[arg-type]
buffer_mutations=graph_signature.buffers_to_mutate, # type: ignore[arg-type]
user_input_mutations=user_inputs_to_mutate, # type: ignore[arg-type]
grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {}, # type: ignore[arg-type, union-attr]
grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {}, # type: ignore[arg-type, union-attr]
loss_output=graph_signature.backward_signature.loss_output if is_joint else None, # type: ignore[arg-type, union-attr]
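
The two helpers above implement the core trick: _prepare_module temporarily registers every tensor user argument as a buffer on the Dynamo-level graph module (replacing its placeholder with a get_attr), so that aot_export_module, which does not handle user-input mutation directly, only sees buffer mutations; _unwrap_user_inputs then turns those lifted buffers back into user-input placeholders and splits the mutation bookkeeping. A simplified, standalone sketch of that final split (the dict contents are made up for illustration):

# After aot_export_module, every mutation is recorded as a buffer mutation:
# graph output name -> name of the mutated target.
buffers_to_mutate = {"add": "arg0_1", "copy_": "my_buffer"}
lifted_user_inputs = {"arg0_1"}  # user inputs temporarily registered as buffers

# Mutations whose target is a lifted user input become user-input mutations;
# the rest remain genuine buffer mutations.
user_inputs_to_mutate = {o: t for o, t in buffers_to_mutate.items() if t in lifted_user_inputs}
buffers_to_mutate = {o: t for o, t in buffers_to_mutate.items() if t not in lifted_user_inputs}

print(user_inputs_to_mutate)  # {'add': 'arg0_1'}
print(buffers_to_mutate)      # {'copy_': 'my_buffer'}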

View File

@ -6,11 +6,11 @@ from torch._export.db.case import export_case, SupportLevel
@export_case(
example_inputs=(torch.ones(3, 2),),
tags={"torch.mutation"},
support_level=SupportLevel.NOT_SUPPORTED_YET,
support_level=SupportLevel.SUPPORTED,
)
class UserInputMutation(torch.nn.Module):
"""
Can't directly mutate user input in forward
Directly mutate user input in forward
"""
def forward(self, x):

View File

@ -230,12 +230,6 @@ def _verify_exported_program_signature(exported_program) -> None:
# Check ExportedProgram signature matches
gs = exported_program.graph_signature
bs_grad_to_param = {}
bs_grad_to_user_inputs = {}
if gs.backward_signature is not None:
bs_grad_to_param = gs.backward_signature.gradients_to_parameters
bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs
# Check every node in the signature exists in the graph
input_node_names = [node.name for node in exported_program.graph.nodes if node.op == "placeholder"]
@ -324,19 +318,28 @@ def _verify_exported_program_signature(exported_program) -> None:
f"Number of user outputs: {len(gs.user_outputs)}. \n"
)
buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)]
user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)]
end = len(gs.buffers_to_mutate) + len(gs.user_inputs_to_mutate)
mutate_nodes: List[str] = output_nodes[:end]
user_output_nodes = output_nodes[end:end + len(gs.user_outputs)]
for buffer_node in buffer_mutate_nodes:
if (
buffer_node not in gs.buffers_to_mutate or
gs.buffers_to_mutate[buffer_node] not in gs.buffers
):
for mutation_node in mutate_nodes:
if mutation_node in gs.buffers_to_mutate:
if gs.buffers_to_mutate[mutation_node] not in gs.buffers:
raise SpecViolationError(
f"Buffer output {mutation_node} does not point to a buffer that exists. \n"
f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
f"Buffer nodes available: {gs.buffers} \n"
)
elif mutation_node in gs.user_inputs_to_mutate:
if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs:
raise SpecViolationError(
f"User input output {mutation_node} does not point to a user input that exists. \n"
f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n"
f"User input nodes available: {gs.user_inputs} \n")
else:
raise SpecViolationError(
f"Buffer output {buffer_node} is not in buffer mutation dictionary "
"or, it does not point to a buffer that exists. \n"
f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
f"Buffer nodes available: {gs.buffers} \n"
f"Mutation node {mutation_node} is neither a buffer nor a user input. "
f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}"
)
for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):

View File

@ -277,9 +277,10 @@ class ExportedProgram:
)
if self.call_spec.out_spec is not None:
mutation = self.graph_signature.buffers_to_mutate
num_mutated = len(mutation)
mutated_buffers = res[:num_mutated]
buffer_mutation = self.graph_signature.buffers_to_mutate
user_input_mutation = self.graph_signature.user_inputs_to_mutate
num_mutated = len(buffer_mutation) + len(user_input_mutation)
mutated_values = res[:num_mutated]
# Exclude dependency token from final result.
assertion_dep_token = self.graph_signature.assertion_dep_token
@ -299,10 +300,27 @@ class ExportedProgram:
f"{received_spec}"
)
finally:
ix = 0
for buffer in self.graph_signature.buffers_to_mutate.values():
self.state_dict[buffer] = mutated_buffers[ix]
ix += 1
user_inputs = [
spec
for spec in self.graph_signature.input_specs
if spec.kind == InputKind.USER_INPUT
]
for i, value in enumerate(mutated_values):
output_spec = self.graph_signature.output_specs[i]
if output_spec.kind == OutputKind.BUFFER_MUTATION:
assert output_spec.target is not None
self.state_dict[output_spec.target] = value
elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
assert output_spec.target is not None
index = next(
i
for i, spec in enumerate(user_inputs)
if spec.arg.name == output_spec.target
)
args[index].copy_(value)
else:
raise AssertionError(f"Unexpected kind: {output_spec.kind}")
return res
def __str__(self) -> str:
@ -365,7 +383,6 @@ class ExportedProgram:
decomp_table = decomp_table or core_aten_decompositions()
old_placeholders = _get_placeholders(self.graph_module)
old_outputs = list(self.graph.nodes)[-1].args[0]
fake_args = [node.meta["val"] for node in old_placeholders]
buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
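
On the calling side, ExportedProgram.__call__ now treats the leading outputs as mutation results: BUFFER_MUTATION values are written into state_dict, and USER_INPUT_MUTATION values are copied into the corresponding positional argument via Tensor.copy_. Reusing the MutationModel sketch from the summary, the observable behavior is roughly:

x = torch.randn(12)
ep = torch.export.export(MutationModel(), (x, 2.0))

x2 = torch.randn(12)
out = ep(x2, 2.0)
# x2 has been updated in place by the copy_ write-back, so it matches the
# returned tensor and what eager mode would have produced.
assert torch.allclose(out, x2)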

View File

@ -57,6 +57,7 @@ class OutputKind(Enum):
BUFFER_MUTATION = auto()
GRADIENT_TO_PARAMETER = auto()
GRADIENT_TO_USER_INPUT = auto()
USER_INPUT_MUTATION = auto()
@dataclasses.dataclass
@ -76,6 +77,7 @@ def _sig_to_specs(
inputs_to_buffers: Mapping[str, str],
user_outputs: Set[str],
buffer_mutations: Mapping[str, str],
user_input_mutations: Mapping[str, str],
grad_params: Mapping[str, str],
grad_user_inputs: Mapping[str, str],
loss_output: Optional[str],
@ -101,37 +103,49 @@ def _sig_to_specs(
else:
raise AssertionError(f"Unknown tensor input kind: {name}")
def to_output_spec(o: ArgumentSpec) -> OutputSpec:
def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
if not isinstance(o, TensorArgument):
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
name = o.name
if name in user_outputs:
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
elif name in buffer_mutations:
return OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=o,
target=buffer_mutations[name],
)
elif name in grad_params:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_PARAMETER,
arg=o,
target=grad_params[name],
)
elif name in grad_user_inputs:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_USER_INPUT,
arg=o,
target=grad_user_inputs[name],
)
elif name == loss_output:
return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
if idx < len(buffer_mutations) + len(user_input_mutations):
if name in buffer_mutations:
return OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=o,
target=buffer_mutations[name],
)
elif name in user_input_mutations:
return OutputSpec(
kind=OutputKind.USER_INPUT_MUTATION,
arg=o,
target=user_input_mutations[name],
)
else:
raise AssertionError(f"Unknown tensor mutation kind: {name}")
else:
raise AssertionError(f"Unknown tensor output kind: {name}")
if name in user_outputs:
return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
elif name in grad_params:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_PARAMETER,
arg=o,
target=grad_params[name],
)
elif name in grad_user_inputs:
return OutputSpec(
kind=OutputKind.GRADIENT_TO_USER_INPUT,
arg=o,
target=grad_user_inputs[name],
)
elif name == loss_output:
return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
else:
raise AssertionError(f"Unknown tensor output kind: {name}")
input_specs = [to_input_spec(i) for i in inputs]
output_specs = [to_output_spec(o) for o in outputs]
output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
return input_specs, output_specs
@ -304,6 +318,16 @@ class ExportGraphSignature:
and isinstance(s.target, str)
}
@property
def user_inputs_to_mutate(self) -> Mapping[str, str]:
return {
s.arg.name: s.target
for s in self.output_specs
if s.kind == OutputKind.USER_INPUT_MUTATION
and isinstance(s.arg, TensorArgument)
and isinstance(s.target, str)
}
# A dictionary mapping graph input node names to lifted tensor constants.
@property
def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
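
Finally, the USER_INPUT_MUTATION output kind and the user_inputs_to_mutate property expose the mutation on the graph signature. A quick way to inspect it (the output node and placeholder names shown are illustrative; the actual names depend on the tracer):

ep = torch.export.export(MutationModel(), (torch.randn(12), 2.0))
# Maps the graph output carrying the mutated value to the user-input
# placeholder it is written back to, e.g. {"add": "arg0_1"}.
print(ep.graph_signature.user_inputs_to_mutate)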