mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Reland] Return NoOpDeviceGuardImpl in replace of CudaDeviceGuard when device is not available, or cpu-only build (#163016)
Reland of #160532 Summary: To support exporting a cuda model on a CPU-only machine under fake tensor mode. User commonly need to move sample inputs to the cuda device with .to("cuda:0") or .to("cuda") call. This diff supports this. I expect the following pattern to work ``` with FakeTensorMode(allow_non_fake_inputs=True): cuda_module = module.to("cuda:0") cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs]) with torch.no_grad(): ep = torch.export.export(cuda_module, cuda_sample_inputs) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/163016 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
bb635a11f8
commit
f1eb99e2e4
@ -1387,6 +1387,12 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# See NOTE: [torch.tensor, lift_fresh, and device movement]
|
||||
prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
|
||||
torch._C._set_only_lift_cpu_tensors(True)
|
||||
|
||||
# In the case of CPU-only build or cuda device unavailable,
|
||||
# we patch the cuda device guard to use NoOpDeviceGuardImpl.
|
||||
# This enables us to trace over cuda kernels under FakeTensorMode.
|
||||
torch._C._ensureCUDADeviceGuardSet()
|
||||
|
||||
maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
|
||||
if self is not maybe_prev_fake_mode:
|
||||
self.enter_stack.append(
|
||||
@ -1397,6 +1403,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# no-op (still need to re-set the fake mode though since we unset it)
|
||||
torch._C._set_dispatch_mode(self)
|
||||
self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
|
||||
Reference in New Issue
Block a user