[1/n] Support cpu_tensor.to("cuda:0") in FakeTensorMode on cuda-less machine (#160431)

Summary:
To support exporting a CUDA model on a CPU-only machine under fake tensor mode,
users commonly need to move sample inputs to the CUDA device with a .to("cuda:0") call.
This diff adds support for that.

Notice that .to("cuda") doesn't work yet, as it queries the current device index by calling the CUDA API.

I expect the following pattern to work

```
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])

    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)

```

Test Plan:
buck2 run  fbcode//caffe2/test:fake_tensor -- --r test_fake_gpu_no_init

Rollback Plan:

Differential Revision: D80101283

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160431
Approved by: https://github.com/henryoier, https://github.com/ezyang
This commit is contained in:
Sherlock Huang
2025-09-20 21:33:53 +00:00
committed by PyTorch MergeBot
parent d70c0babf5
commit 3938175ec1
2 changed files with 39 additions and 21 deletions

View File

@ -2327,11 +2327,28 @@ class FakeTensorMode(TorchDispatchMode):
converter = self.fake_tensor_converter
is_lift_func = func in self.lift_fns
# If we are trying to avoid device init, then we need to avoid constant
# prop on constant tensors for ops that change devices.
avoiding_device_init = False
if self.avoid_device_init:
if (
func == torch.ops.aten._to_copy.default
and "device" in kwargs
and kwargs["device"].type != "cpu" # type: ignore[attr-defined]
):
avoiding_device_init = True
if func == torch.ops.prims.device_put.default:
avoiding_device_init = True
# skip const prop for aten._to_copy if
# 1. input tensor is on "meta" device
# 2. destination device is unavailable, captured by `avoiding_device_init`
device_conversion_skip_const_prop = (
func is torch.ops.aten._to_copy.default
and isinstance(args[0], torch.Tensor)
and args[0].device.type == "meta"
)
) or avoiding_device_init
# To constant propagate through these functions:
# 1, If this is a lift due to a torch.tensor call,
@ -2377,19 +2394,6 @@ class FakeTensorMode(TorchDispatchMode):
if type(args[0]) is Tensor:
return converter.from_real_tensor(self, args[0])
# If we are trying to avoid device init, then we need to avoid constant
# prop on constant tensors for ops that change devices.
avoiding_device_init = False
if self.avoid_device_init:
if (
func == torch.ops.aten._to_copy.default
and "device" in kwargs
and kwargs["device"] != "cpu"
):
avoiding_device_init = True
if func == torch.ops.prims.device_put.default:
avoiding_device_init = True
# Recompute flat_arg_fake_tensors here again in case some of the inputs
# were real tensors and fakified in validate_and_convert_non_fake_tensors
(flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(