mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Handle recursive tuple in clone_inputs (#102979)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/102979 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
4479e2fa19
commit
12cd1dbba0
@ -579,8 +579,11 @@ def clone_inputs(example_inputs):
|
||||
if type(example_inputs) is dict:
|
||||
res = dict(example_inputs)
|
||||
for key, value in res.items():
|
||||
assert isinstance(value, torch.Tensor)
|
||||
res[key] = clone_input(value)
|
||||
if isinstance(value, tuple):
|
||||
res[key] = clone_inputs(value)
|
||||
else:
|
||||
assert isinstance(value, torch.Tensor), type(value)
|
||||
res[key] = clone_input(value)
|
||||
return res
|
||||
|
||||
res = list(example_inputs)
|
||||
|
Reference in New Issue
Block a user