[export] Support user input mutation. [1/2] (#114496)

Summary:
Serialization not implemented yet. Will do in the next diff.

Resolving Github issues:
https://github.com/pytorch/pytorch/issues/112429
https://github.com/pytorch/pytorch/issues/114142

Test Plan:
buck2 run mode/opt caffe2/test:test_export -- -r test_export_
input_mutation

Differential Revision: D51556962

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114496
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2023-11-27 04:53:38 +00:00
committed by PyTorch MergeBot
parent 624f202522
commit b62c0d96bc
10 changed files with 263 additions and 118 deletions

View File

@ -277,9 +277,10 @@ class ExportedProgram:
)
if self.call_spec.out_spec is not None:
mutation = self.graph_signature.buffers_to_mutate
num_mutated = len(mutation)
mutated_buffers = res[:num_mutated]
buffer_mutation = self.graph_signature.buffers_to_mutate
user_input_mutation = self.graph_signature.user_inputs_to_mutate
num_mutated = len(buffer_mutation) + len(user_input_mutation)
mutated_values = res[:num_mutated]
# Exclude dependency token from final result.
assertion_dep_token = self.graph_signature.assertion_dep_token
@ -299,10 +300,27 @@ class ExportedProgram:
f"{received_spec}"
)
finally:
ix = 0
for buffer in self.graph_signature.buffers_to_mutate.values():
self.state_dict[buffer] = mutated_buffers[ix]
ix += 1
user_inputs = [
spec
for spec in self.graph_signature.input_specs
if spec.kind == InputKind.USER_INPUT
]
for i, value in enumerate(mutated_values):
output_spec = self.graph_signature.output_specs[i]
if output_spec.kind == OutputKind.BUFFER_MUTATION:
assert output_spec.target is not None
self.state_dict[output_spec.target] = value
elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
assert output_spec.target is not None
index = next(
i
for i, spec in enumerate(user_inputs)
if spec.arg.name == output_spec.target
)
args[index].copy_(value)
else:
raise AssertionError(f"Unexpected kind: {output_spec.kind}")
return res
def __str__(self) -> str:
@ -365,7 +383,6 @@ class ExportedProgram:
decomp_table = decomp_table or core_aten_decompositions()
old_placeholders = _get_placeholders(self.graph_module)
old_outputs = list(self.graph.nodes)[-1].args[0]
fake_args = [node.meta["val"] for node in old_placeholders]
buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]