[Reland] Return NoOpDeviceGuardImpl in replace of CudaDeviceGuard when device is not available, or cpu-only build (#163016)

Reland of #160532

Summary:

To support exporting a cuda model on a CPU-only machine under fake tensor mode.
User commonly need to move sample inputs to the cuda device with .to("cuda:0") or .to("cuda") call.
This diff supports this.
I expect the following pattern to work
```
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163016
Approved by: https://github.com/huydhn
This commit is contained in:
Sherlock Huang
2025-09-17 05:01:33 +00:00
committed by PyTorch MergeBot
parent bb635a11f8
commit f1eb99e2e4
6 changed files with 207 additions and 4 deletions

View File

@ -26,6 +26,7 @@
#include <ATen/native/Normalization.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Logging.h>
@ -1550,6 +1551,15 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
END_HANDLE_TH_ERRORS
}
static PyObject* THCPModule_ensureCUDADeviceGuardSet(
PyObject* self,
PyObject* noargs) {
HANDLE_TH_ERRORS
c10::impl::ensureCUDADeviceGuardSet();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static std::initializer_list<PyMethodDef> TorchMethods = {
{"_initExtension", THPModule_initExtension, METH_O, nullptr},
{"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
@ -1845,7 +1855,13 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
(PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
METH_FASTCALL,
nullptr},
{nullptr, nullptr, 0, nullptr}};
{"_ensureCUDADeviceGuardSet",
THCPModule_ensureCUDADeviceGuardSet,
METH_NOARGS,
nullptr},
{nullptr, nullptr, 0, nullptr}
};
#ifdef USE_CUDA
// NOLINTBEGIN(misc-use-internal-linkage)