Handle recursive tuple in clone_inputs (#102979)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102979
Approved by: https://github.com/wconstab
This commit is contained in:
Edward Z. Yang
2023-06-05 09:22:04 -07:00
committed by PyTorch MergeBot
parent 4479e2fa19
commit 12cd1dbba0

View File

@ -579,8 +579,11 @@ def clone_inputs(example_inputs):
if type(example_inputs) is dict:
res = dict(example_inputs)
for key, value in res.items():
assert isinstance(value, torch.Tensor)
res[key] = clone_input(value)
if isinstance(value, tuple):
res[key] = clone_inputs(value)
else:
assert isinstance(value, torch.Tensor), type(value)
res[key] = clone_input(value)
return res
res = list(example_inputs)