mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Return NoOpDeviceGuardImpl in replace of CudaDeviceGuard when device is not available, or cpu-only build (#160532)
Summary: To support exporting a cuda model on a CPU-only machine under fake tensor mode. User commonly need to move sample inputs to the cuda device with .to("cuda:0") or .to("cuda") call. This diff supports this. I expect the following pattern to work ``` with FakeTensorMode(allow_non_fake_inputs=True): cuda_module = module.to("cuda:0") cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs]) with torch.no_grad(): ep = torch.export.export(cuda_module, cuda_sample_inputs) ``` Test Plan: CI Rollback Plan: Differential Revision: D80181887 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160532 Approved by: https://github.com/henryoier, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
0925c644ed
commit
a956c4ab1c
@ -26,6 +26,7 @@
|
||||
#include <ATen/native/Normalization.h>
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
#include <c10/util/AbortHandler.h>
|
||||
#include <c10/util/Backtrace.h>
|
||||
#include <c10/util/Logging.h>
|
||||
@ -1550,6 +1551,15 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* THCPModule_ensureCUDADeviceGuardSet(
|
||||
PyObject* self,
|
||||
PyObject* noargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
c10::impl::ensureCUDADeviceGuardSet();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
{"_initExtension", THPModule_initExtension, METH_O, nullptr},
|
||||
{"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
|
||||
@ -1845,7 +1855,13 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
|
||||
(PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
|
||||
METH_FASTCALL,
|
||||
nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}};
|
||||
{"_ensureCUDADeviceGuardSet",
|
||||
THCPModule_ensureCUDADeviceGuardSet,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{nullptr, nullptr, 0, nullptr}
|
||||
|
||||
};
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// NOLINTBEGIN(misc-use-internal-linkage)
|
||||
|
Reference in New Issue
Block a user