Avoid cuda init to FakeTensorMode (#124413)

Also partially fixes #122109

This PR:
- We add a C++ flag (only_lift_cpu_tensors) to toggle the
  torch.tensor(1, device='cuda') ctor strategy.
  When false (default), it does the current PyTorch behavior
  of unconditionally constructing a concrete CUDA tensor then calling
  lift_fresh on it. When true, we instead construct a concrete CPU
  tensor, call lift_fresh, and then call Tensor.to(device) (under any ambient
  modes).
- FakeTensorMode flips this flag depending on if CUDA is available or
  not. We don't unconditionally set the flag to True because that is
  likely BC-breaking.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124413
Approved by: https://github.com/eellison
This commit is contained in:
rzou
2024-04-18 10:45:17 -07:00
committed by PyTorch MergeBot
parent e620c3e814
commit 889e3eeed3
7 changed files with 109 additions and 12 deletions

View File

@ -16,6 +16,7 @@
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
@ -903,6 +904,9 @@ void initDispatchBindings(PyObject* module) {
->set_warn_deprecated_on_mutable_data_ptr();
});
m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
using c10::impl::TorchDispatchModeKey;
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)