[2/n] Support module.to("cuda:0") in FakeTensorMode on cuda-less machine (#163433)

Summary:
This supports exporting a CUDA model on a CPU-only machine under fake tensor mode.
Users commonly need to move the module and sample inputs to the CUDA device with a .to("cuda:0") or .to("cuda") call; this diff makes that pattern work.

I expect the following pattern to work:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple(x.to("cuda:0") for x in sample_inputs)

    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```

Before: moving a module with module.to("cuda:0") under fake tensor mode left its parameters on the `meta` device.

After: the parameters are on `"cuda:0"`.

Test Plan: buck2 run  fbcode//caffe2/test:fake_tensor -- --r test_move_module

Reviewed By: mikaylagawarecki

Differential Revision: D80102876

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163433
Approved by: https://github.com/albanD
Author: Sherlock Huang
Date: 2025-09-22 20:16:32 +00:00
Committed by: PyTorch MergeBot
Parent: d15048493c
Commit: 6f9aef5fef
3 changed files with 44 additions and 13 deletions


```diff
@@ -929,8 +929,12 @@ class Module:
         for module in self.children():
             module._apply(fn)
 
+        from torch._subclasses.fake_tensor import FakeTensor
+
         def compute_should_use_set_data(tensor, tensor_applied) -> bool:
-            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
+            if torch._has_compatible_shallow_copy_type(
+                tensor, tensor_applied
+            ) and not isinstance(tensor_applied, FakeTensor):
                 # If the new tensor has compatible tensor type as the existing tensor,
                 # the current behavior is to change the tensor in-place using `.data =`,
                 # and the future behavior is to overwrite the existing tensor. However,
@@ -957,8 +961,6 @@ class Module:
             param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
 
-            from torch._subclasses.fake_tensor import FakeTensor
-
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = (
                 should_use_swap_tensors
```
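
The gist of the patch: `_apply` previously took the in-place `.data =` path whenever the applied tensor had a compatible shallow-copy type, and now explicitly excludes `FakeTensor` results so they go through the swap path instead. A minimal torch-free sketch of the updated predicate (the stub classes and the `has_compatible_shallow_copy_type` helper are illustrative stand-ins, not the real torch internals):

```python
class Tensor:
    """Illustrative stand-in for torch.Tensor; tracks only a device string."""
    def __init__(self, device="cpu"):
        self.device = device

class FakeTensor(Tensor):
    """Illustrative stand-in for torch._subclasses.fake_tensor.FakeTensor."""

def has_compatible_shallow_copy_type(tensor, tensor_applied):
    # Stand-in for torch._has_compatible_shallow_copy_type; assume
    # compatibility here so the FakeTensor exclusion is what decides.
    return True

def compute_should_use_set_data(tensor, tensor_applied) -> bool:
    # After the patch: a FakeTensor result never takes the `.data =`
    # shortcut, so Module._apply falls through to the swap path and the
    # parameter adopts the applied tensor (and its device, e.g. "cuda:0").
    return has_compatible_shallow_copy_type(
        tensor, tensor_applied
    ) and not isinstance(tensor_applied, FakeTensor)

param = Tensor("cpu")
print(compute_should_use_set_data(param, Tensor("cuda:0")))      # True: `.data =` path
print(compute_should_use_set_data(param, FakeTensor("cuda:0")))  # False: swap path
```

In the real code the `False` branch leads to `swap_tensors`, which is why parameters end up on `"cuda:0"` instead of `meta` after this change.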