Use swap_tensors path in nn.Module.to for FakeTensor (#152539)
Fixes https://github.com/pytorch/pytorch/issues/148977

Differential Revision: [D76458023](https://our.internmc.facebook.com/intern/diff/D76458023)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152539
Approved by: https://github.com/albanD
commit 38bfd462b8
parent db01f1032f
committed by PyTorch MergeBot
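For context, torch.utils.swap_tensors exchanges the entire Python-level state of two tensor objects (storage, dtype, subclass identity), rather than reassigning .data, which is why it is the safe path for wrapper subclasses and FakeTensor. A minimal sketch of the primitive, using only the public torch.utils.swap_tensors API; the tensors here are illustrative and not taken from the patched code:

import torch

t1 = torch.randn(3)
t2 = torch.zeros(3, dtype=torch.float64)

# swap_tensors exchanges the full tensor state of the two Python objects,
# including storage and dtype, rather than reassigning t1.data.
torch.utils.swap_tensors(t1, t2)

print(t1.dtype)  # torch.float64 -- t1 now holds what was t2's state
print(t2.dtype)  # torch.float32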
@@ -954,9 +954,13 @@ class Module:
                 param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)

+            from torch._subclasses.fake_tensor import FakeTensor
+
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = (
-                should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
+                should_use_swap_tensors
+                or is_traceable_wrapper_subclass(param_applied)
+                or isinstance(param, FakeTensor)
             )

             param_grad = param.grad
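A minimal sketch of the scenario this change targets, assuming the FakeTensorMode context manager from torch._subclasses.fake_tensor (the same internal module the patch imports FakeTensor from); the module shape and target dtype are illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Parameters created under FakeTensorMode are FakeTensors.
    m = torch.nn.Linear(4, 4)
    # With this change, Module.to() detects the FakeTensor parameters via
    # isinstance(param, FakeTensor) and routes them through the
    # torch.utils.swap_tensors path instead of assigning param.data.
    m.to(torch.float64)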