mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[export] Support user input mutation. [1/2] (#114496)
Summary: Serialization not implemented yet. Will do in the next diff. Resolving Github issues: https://github.com/pytorch/pytorch/issues/112429 https://github.com/pytorch/pytorch/issues/114142 Test Plan: buck2 run mode/opt caffe2/test:test_export -- -r test_export_ input_mutation Differential Revision: D51556962 Pull Request resolved: https://github.com/pytorch/pytorch/pull/114496 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
624f202522
commit
b62c0d96bc
@ -277,9 +277,10 @@ class ExportedProgram:
|
||||
)
|
||||
|
||||
if self.call_spec.out_spec is not None:
|
||||
mutation = self.graph_signature.buffers_to_mutate
|
||||
num_mutated = len(mutation)
|
||||
mutated_buffers = res[:num_mutated]
|
||||
buffer_mutation = self.graph_signature.buffers_to_mutate
|
||||
user_input_mutation = self.graph_signature.user_inputs_to_mutate
|
||||
num_mutated = len(buffer_mutation) + len(user_input_mutation)
|
||||
mutated_values = res[:num_mutated]
|
||||
|
||||
# Exclude dependency token from final result.
|
||||
assertion_dep_token = self.graph_signature.assertion_dep_token
|
||||
@ -299,10 +300,27 @@ class ExportedProgram:
|
||||
f"{received_spec}"
|
||||
)
|
||||
finally:
|
||||
ix = 0
|
||||
for buffer in self.graph_signature.buffers_to_mutate.values():
|
||||
self.state_dict[buffer] = mutated_buffers[ix]
|
||||
ix += 1
|
||||
user_inputs = [
|
||||
spec
|
||||
for spec in self.graph_signature.input_specs
|
||||
if spec.kind == InputKind.USER_INPUT
|
||||
]
|
||||
for i, value in enumerate(mutated_values):
|
||||
output_spec = self.graph_signature.output_specs[i]
|
||||
if output_spec.kind == OutputKind.BUFFER_MUTATION:
|
||||
assert output_spec.target is not None
|
||||
self.state_dict[output_spec.target] = value
|
||||
elif output_spec.kind == OutputKind.USER_INPUT_MUTATION:
|
||||
assert output_spec.target is not None
|
||||
index = next(
|
||||
i
|
||||
for i, spec in enumerate(user_inputs)
|
||||
if spec.arg.name == output_spec.target
|
||||
)
|
||||
args[index].copy_(value)
|
||||
else:
|
||||
raise AssertionError(f"Unexpected kind: {output_spec.kind}")
|
||||
|
||||
return res
|
||||
|
||||
def __str__(self) -> str:
|
||||
@ -365,7 +383,6 @@ class ExportedProgram:
|
||||
decomp_table = decomp_table or core_aten_decompositions()
|
||||
|
||||
old_placeholders = _get_placeholders(self.graph_module)
|
||||
old_outputs = list(self.graph.nodes)[-1].args[0]
|
||||
fake_args = [node.meta["val"] for node in old_placeholders]
|
||||
|
||||
buffers_to_remove = [name for name, _ in self.graph_module.named_buffers()]
|
||||
|
Reference in New Issue
Block a user