mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/n] Support cpu_tensor.to("cuda:0") in FakeTensorMode on cuda-less machine (#160431)
Summary: To support exporting a cuda model on a CPU-only machine under fake tensor mode. Users commonly need to move sample inputs to the cuda device with a .to("cuda:0") call. This diff supports this. Notice that .to("cuda") doesn't work yet, as it queries the current device index by calling the CUDA API. I expect the following pattern to work ``` with FakeTensorMode(allow_non_fake_inputs=True): cuda_module = module.to("cuda:0") cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs]) with torch.no_grad(): ep = torch.export.export(cuda_module, cuda_sample_inputs) ``` Test Plan: buck2 run fbcode//caffe2/test:fake_tensor -- --r test_fake_gpu_no_init Rollback Plan: Differential Revision: D80101283 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160431 Approved by: https://github.com/henryoier, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
d70c0babf5
commit
3938175ec1
@ -1517,13 +1517,27 @@ class FakeTensorOperatorInvariants(TestCase):
|
||||
# it clearly will not work on CPU runner
|
||||
if torch._functorch.config.fake_tensor_propagate_real_tensors:
|
||||
return
|
||||
with FakeTensorMode():
|
||||
torch.empty(10, device=GPU_TYPE)
|
||||
torch.ones(10, device=GPU_TYPE)
|
||||
torch.zeros(10, device=GPU_TYPE)
|
||||
torch.rand(10, device=GPU_TYPE)
|
||||
torch.tensor(3.14, device=GPU_TYPE)
|
||||
torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE)
|
||||
|
||||
with FakeTensorMode(allow_non_fake_inputs=True):
|
||||
self.assertEqual(torch.empty(10, device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
self.assertEqual(torch.ones(10, device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
self.assertEqual(torch.zeros(10, device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
self.assertEqual(torch.rand(10, device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
self.assertEqual(torch.tensor(3.14, device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
self.assertEqual(
|
||||
torch.tensor([[3.14, 2], [1, 2]], device=GPU_TYPE).device.type, GPU_TYPE
|
||||
)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_move_meta_tensor(self):
|
||||
if torch._functorch.config.fake_tensor_propagate_real_tensors:
|
||||
return
|
||||
|
||||
meta_tensor = torch.ones(2, device="meta")
|
||||
gpu_device = torch.device(GPU_TYPE)
|
||||
with FakeTensorMode(allow_non_fake_inputs=True):
|
||||
self.assertEqual(meta_tensor.to(device="cpu").device.type, "cpu")
|
||||
self.assertEqual(meta_tensor.to(device=GPU_TYPE).device.type, GPU_TYPE)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_conv_c1_backward(self):
|
||||
|
@ -2327,11 +2327,28 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
converter = self.fake_tensor_converter
|
||||
|
||||
is_lift_func = func in self.lift_fns
|
||||
|
||||
# If we are trying to avoid device init, then we need to avoid constant
|
||||
# prop on constant tensors for ops that change devices.
|
||||
avoiding_device_init = False
|
||||
if self.avoid_device_init:
|
||||
if (
|
||||
func == torch.ops.aten._to_copy.default
|
||||
and "device" in kwargs
|
||||
and kwargs["device"].type != "cpu" # type: ignore[attr-defined]
|
||||
):
|
||||
avoiding_device_init = True
|
||||
if func == torch.ops.prims.device_put.default:
|
||||
avoiding_device_init = True
|
||||
|
||||
# skip const prop for aten._to_copy if
|
||||
# 1. input tensor is on "meta" device
|
||||
# 2. destination device is unavailable, captured by `avoiding_device_init`
|
||||
device_conversion_skip_const_prop = (
|
||||
func is torch.ops.aten._to_copy.default
|
||||
and isinstance(args[0], torch.Tensor)
|
||||
and args[0].device.type == "meta"
|
||||
)
|
||||
) or avoiding_device_init
|
||||
|
||||
# To constant propagate through these functions:
|
||||
# 1, If this is a lift due to a torch.tensor call,
|
||||
@ -2377,19 +2394,6 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
if type(args[0]) is Tensor:
|
||||
return converter.from_real_tensor(self, args[0])
|
||||
|
||||
# If we are trying to avoid device init, then we need to avoid constant
|
||||
# prop on constant tensors for ops that change devices.
|
||||
avoiding_device_init = False
|
||||
if self.avoid_device_init:
|
||||
if (
|
||||
func == torch.ops.aten._to_copy.default
|
||||
and "device" in kwargs
|
||||
and kwargs["device"] != "cpu"
|
||||
):
|
||||
avoiding_device_init = True
|
||||
if func == torch.ops.prims.device_put.default:
|
||||
avoiding_device_init = True
|
||||
|
||||
# Recompute flat_arg_fake_tensors here again in case some of the inputs
|
||||
# were real tensors and fakified in validate_and_convert_non_fake_tensors
|
||||
(flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
|
||||
|
Reference in New Issue
Block a user