Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[2/n] Support module.to("cuda:0") in FakeTensorMode on cuda-less machine (#163433)
Summary:

Support exporting a CUDA model on a CPU-only machine under fake tensor mode. Users commonly need to move sample inputs to the CUDA device with a `.to("cuda:0")` or `.to("cuda")` call; this diff supports that. The following pattern is expected to work:

```python
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```

Before: calling `module.to("cuda:0")` under fake tensor mode would leave parameters on the `meta` device. After: parameters are on `"cuda:0"`.

Test Plan: `buck2 run fbcode//caffe2/test:fake_tensor -- --r test_move_module`

Reviewed By: mikaylagawarecki

Differential Revision: D80102876

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163433
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: d15048493c
Commit: 6f9aef5fef
```diff
@@ -929,8 +929,12 @@ class Module:
         for module in self.children():
             module._apply(fn)
 
+        from torch._subclasses.fake_tensor import FakeTensor
+
         def compute_should_use_set_data(tensor, tensor_applied) -> bool:
-            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
+            if torch._has_compatible_shallow_copy_type(
+                tensor, tensor_applied
+            ) and not isinstance(tensor_applied, FakeTensor):
                 # If the new tensor has compatible tensor type as the existing tensor,
                 # the current behavior is to change the tensor in-place using `.data =`,
                 # and the future behavior is to overwrite the existing tensor. However,
@@ -957,8 +961,6 @@ class Module:
             param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
 
-            from torch._subclasses.fake_tensor import FakeTensor
-
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = (
                 should_use_swap_tensors
```
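The gist of the change can be illustrated without torch: the `.data =` (set_data) path is now skipped whenever the applied result is a `FakeTensor`, so fake parameters get swapped in whole and keep their fake device (e.g. `"cuda:0"`) rather than degrading to `meta`. A minimal pure-Python sketch, where `Tensor`, `FakeTensor`, and the compatibility check are stand-ins for the real torch APIs, not the actual implementations:

```python
class Tensor:
    """Stand-in for a real torch.Tensor."""


class FakeTensor(Tensor):
    """Stand-in for torch._subclasses.fake_tensor.FakeTensor."""


def has_compatible_shallow_copy_type(tensor, tensor_applied):
    # Stand-in for torch._has_compatible_shallow_copy_type: assume any
    # Tensor result is shallow-copy compatible for this sketch.
    return isinstance(tensor_applied, Tensor)


def compute_should_use_set_data(tensor, tensor_applied):
    # Mirrors the condition after #163433: a FakeTensor result never
    # takes the in-place `.data =` path, so it is swapped in instead
    # and its (fake) device survives `.to("cuda:0")`.
    return has_compatible_shallow_copy_type(
        tensor, tensor_applied
    ) and not isinstance(tensor_applied, FakeTensor)


print(compute_should_use_set_data(Tensor(), Tensor()))      # real result: True
print(compute_should_use_set_data(Tensor(), FakeTensor()))  # fake result: False
```

Excluding `FakeTensor` here pushes `_apply` toward the swap-tensors path, which replaces the parameter object entirely instead of mutating `.data` in place.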