Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[Reland] Return NoOpDeviceGuardImpl in replace of CudaDeviceGuard when device is not available, or cpu-only build (#163016)"
This reverts commit f1eb99e2e4363f20eb5896433e1eb7f7500aadea.
Reverted https://github.com/pytorch/pytorch/pull/163016 on behalf of https://github.com/jeffdaily due to broke rocm CI, see export/test_export_opinfo.py::TestExportOnFakeCudaCUDA::test_fake_export_nonzero_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/17787208381/job/50564369696) [HUD commit link](f1eb99e2e4) ([comment](https://github.com/pytorch/pytorch/pull/163016#issuecomment-3303707552))
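
For context, the reverted change let FakeTensorMode and torch.export trace CUDA ops on machines where no CUDA device is visible (CPU-only builds, or `CUDA_VISIBLE_DEVICES=""`). The sketch below is distilled from the `test_fake_export` test removed in the diff; it assumes a pre-revert CUDA build with no visible devices, and the `NonzeroModule` wrapper is only illustrative (the `nonzero` op is the one named in the failing ROCm test).

```python
# Hypothetical sketch of the reverted behavior, adapted from the removed
# test_fake_export below. Assumes a pre-revert CUDA build of PyTorch running
# with no visible CUDA devices (e.g. CUDA_VISIBLE_DEVICES="").
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


class NonzeroModule(torch.nn.Module):  # illustrative module, not from the PR
    def forward(self, x):
        return torch.nonzero(x)


mode = FakeTensorMode(allow_non_fake_inputs=True)
with mode:
    # Fake-ify a CPU tensor and pretend it lives on CUDA.
    x = mode.from_tensor(torch.randn(4))
    x.fake_device = torch.device("cuda")

# With the (now reverted) no-op CUDA guard installed, export could trace the
# graph as if it ran on CUDA, without touching a real device.
ep = torch.export.export(NonzeroModule(), (x,))
for node in ep.graph.nodes:
    val = node.meta.get("val", None)
    if isinstance(val, FakeTensor):
        assert val.device.type == "cuda"
```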
@@ -1,5 +1,4 @@
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/FakeGuardImpl.h>
#include <array>

namespace c10::impl {
@@ -15,26 +14,4 @@ DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
  device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
}

namespace {
thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
    nullptr;
}

void ensureCUDADeviceGuardSet() {
  constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA);

  const DeviceGuardImplInterface* p =
      device_guard_impl_registry[cuda_idx].load();

  // A non-null `ptr` indicates that the CUDA guard is already set up,
  // implying this is using cuda build
  if (p && p->deviceCount() == 0) {
    // In following cases, we override CUDA guard interface with a no-op
    // device guard. When p->deviceCount() == 0, cuda build is enabled, but no
    // cuda devices available.
    tls_fake_device_guard = std::make_unique<FakeGuardImpl<DeviceType::CUDA>>();
    device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get());
  }
}

} // namespace c10::impl
@@ -6,7 +6,6 @@
#include <c10/util/Exception.h>

// Just for C10_ANONYMOUS_VARIABLE
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Registry.h>

#include <array>
@@ -252,7 +251,7 @@ struct C10_API DeviceGuardImplInterface {
// for devices that don't actually have a concept of device index. Prominent
// examples are CPU and Meta.
template <DeviceType D>
struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface {
struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
  NoOpDeviceGuardImpl() = default;
  DeviceType type() const override {
    return D;
@@ -372,7 +371,5 @@ inline bool hasDeviceGuardImpl(DeviceType type) {
  return device_guard_impl_registry[static_cast<size_t>(type)].load();
}

void C10_API ensureCUDADeviceGuardSet();

} // namespace impl
} // namespace c10
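
The two C++ hunks above remove `ensureCUDADeviceGuardSet`, which only swapped in a `FakeGuardImpl` when a CUDA guard was registered (a CUDA build) but it reported zero devices. That precondition is also observable from Python; a small sketch using only public APIs that appear in the removed tests, with illustrative variable names:

```python
# Sketch of the condition guarded by the removed C++ helper: a CUDA build
# (guard compiled in) whose visible device count is zero.
import torch

cuda_build = torch.backends.cuda.is_built()   # CUDA support compiled in at all?
no_devices = torch.cuda.device_count() == 0   # mirrors p->deviceCount() == 0
print("no-op CUDA guard would have been installed:", cuda_build and no_devices)
```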
@@ -3,9 +3,6 @@
# flake8: noqa

import itertools
import subprocess
import sys
import unittest

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -14,7 +11,6 @@ from torch.testing._internal.common_device_type import (
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    onlyCUDA,
    op_db,
    skip,
    skipOps,
@@ -133,157 +129,8 @@ class TestExportOpInfo(TestCase):
        _test_export_helper(self, dtype, op)


instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")


selected_ops = {
    "__getitem__",
    # "nn.functional.batch_norm", # needs to fix
    "nn.functional.instance_norm",
    "nn.functional.multi_margin_loss",
    "nonzero",
}
selected_op_db = [op for op in op_db if op.name in selected_ops]


class TestExportOnFakeCuda(TestCase):
    # In CI, this test runs on a CUDA machine with cuda build
    # We set CUDA_VISIBLE_DEVICES="" to simulate a CPU machine with cuda build
    # Running this on all ops in op_db is too slow, so we only run on a selected subset
    @onlyCUDA
    @ops(selected_op_db, allowed_dtypes=(torch.float,))
    def test_fake_export(self, device, dtype, op):
        test_script = f"""\
import torch
import itertools
from torch.testing._internal.common_methods_invocations import op_db
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.utils import _pytree as pytree

ops = [op for op in op_db if op.name == "{op.name}"]
assert len(ops) == 1
op = ops[0]

sample_inputs_itr = op.sample_inputs("cpu", torch.float, requires_grad=False)

mode = FakeTensorMode(allow_non_fake_inputs=True)
converter = mode.fake_tensor_converter
# intentionally avoid cuda:0 to flush out some bugs
target_device = "cuda:1"

def to_fake_device(x):
    x = converter.from_real_tensor(mode, x)
    x.fake_device = torch.device(target_device)
    return x

# Limit to first 100 inputs so tests don't take too long
for sample_input in itertools.islice(sample_inputs_itr, 100):
    args = tuple([sample_input.input] + list(sample_input.args))
    kwargs = sample_input.kwargs

    # hack to skip non-tensor in args, as export doesn't support it
    if any(not isinstance(arg, torch.Tensor) for arg in args):
        continue

    if "device" in kwargs:
        kwargs["device"] = target_device

    with mode:
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, to_fake_device, (args, kwargs)
        )

    class Module(torch.nn.Module):
        def forward(self, *args):
            return op.op(*args, **kwargs)

    m = Module()

    ep = torch.export.export(m, args)

    for node in ep.graph.nodes:
        if node.op == "call_function":
            fake_tensor = node.meta.get("val", None)
            if isinstance(fake_tensor, FakeTensor):
                assert fake_tensor.device == torch.device(target_device)
"""
        r = (
            (
                subprocess.check_output(
                    [sys.executable, "-c", test_script],
                    env={"CUDA_VISIBLE_DEVICES": ""},
                )
            )
            .decode("ascii")
            .strip()
        )
        self.assertEqual(r, "")

    @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build")
    def test_preserve_original_behavior(self):
        test_script = f"""\
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def cuda_calls_behavior_unchanged():
    exception_count = 0

    try:
        cpu_x = torch.randn(2)
        cuda_x = cpu_x.to("cuda")
    except Exception as e:
        exception_count += 1

    try:
        torch.randn(2, device="cuda")
    except Exception as e:
        exception_count += 1

    try:
        torch.cuda.get_device_capability()
    except Exception as e:
        exception_count += 1

    try:
        torch.cuda.set_device(1)
    except Exception as e:
        exception_count += 1

    try:
        torch.cuda.current_device()
    except Exception as e:
        exception_count += 1

    assert torch.cuda.is_available() == False
    assert torch.cuda.device_count() == 0
    assert exception_count == 5

cuda_calls_behavior_unchanged()

cpu_x = torch.randn(2)
with FakeTensorMode(allow_non_fake_inputs=True) as mode:
    cuda_x = mode.from_tensor(cpu_x)
    cuda_x.fake_device = torch.device("cuda")
    cuda_y = cuda_x + cuda_x
    assert cuda_y.device.type == "cuda"

# should fail again after exiting the fake mode, with the identical error message
cuda_calls_behavior_unchanged()
"""
        r = (
            (
                subprocess.check_output(
                    [sys.executable, "-c", test_script],
                    env={"CUDA_VISIBLE_DEVICES": ""},
                )
            )
            .decode("ascii")
            .strip()
        )
        self.assertEqual(r, "")


instantiate_device_type_tests(TestExportOnFakeCuda, globals(), only_for="cuda")
only_for = "cpu"
instantiate_device_type_tests(TestExportOpInfo, globals(), only_for=only_for)


if __name__ == "__main__":
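
Both removed tests run their payload in a fresh subprocess with `CUDA_VISIBLE_DEVICES` cleared, so a CUDA CI machine behaves like a CPU-only one. A stand-alone sketch of that isolation pattern follows; the payload string is illustrative, and unlike the removed tests it keeps the rest of the parent environment.

```python
# Illustrative sketch of the subprocess isolation used by the removed tests:
# hide all CUDA devices from a child interpreter, then assert the child sees
# no CUDA. The payload string is a stand-in, not code from the PR.
import os
import subprocess
import sys

payload = "import torch; assert not torch.cuda.is_available()"
out = subprocess.check_output(
    [sys.executable, "-c", payload],
    # Keep the parent environment but blank out CUDA_VISIBLE_DEVICES; the
    # removed tests passed a fully replaced env containing only this key.
    env={**os.environ, "CUDA_VISIBLE_DEVICES": ""},
)
print(out.decode("ascii").strip())
```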
@@ -1379,7 +1379,6 @@ def _get_linalg_preferred_backend() -> _LinalgBackend: ...
def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
def _ensureCUDADeviceGuardSet() -> None: ...

class _LinalgBackend:
    Default: _LinalgBackend
@@ -1387,12 +1387,6 @@ class FakeTensorMode(TorchDispatchMode):
        # See NOTE: [torch.tensor, lift_fresh, and device movement]
        prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
        torch._C._set_only_lift_cpu_tensors(True)

        # In the case of CPU-only build or cuda device unavailable,
        # we patch the cuda device guard to use NoOpDeviceGuardImpl.
        # This enables us to trace over cuda kernels under FakeTensorMode.
        torch._C._ensureCUDADeviceGuardSet()

        maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
        if self is not maybe_prev_fake_mode:
            self.enter_stack.append(
@@ -1403,7 +1397,6 @@ class FakeTensorMode(TorchDispatchMode):
            # no-op (still need to re-set the fake mode though since we unset it)
            torch._C._set_dispatch_mode(self)
            self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))

        return self

    def __exit__(
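
The fake_tensor.py hunk above removes the call that made `FakeTensorMode.__enter__` install the no-op CUDA guard. Its user-visible effect is what the removed `test_preserve_original_behavior` checks: fake "cuda" tensors work inside the mode, while real CUDA calls keep failing outside it. A minimal sketch, assuming the pre-revert build on a machine without visible CUDA devices:

```python
# Sketch of the pre-revert behavior enabled by the removed __enter__ hook,
# condensed from test_preserve_original_behavior above. Assumes a pre-revert
# CUDA build with no visible CUDA devices.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

assert not torch.cuda.is_available()

with FakeTensorMode(allow_non_fake_inputs=True) as mode:
    # Entering the mode called torch._C._ensureCUDADeviceGuardSet(), so the
    # fake "cuda" arithmetic below did not trip the real CUDA device guard.
    x = mode.from_tensor(torch.randn(2))
    x.fake_device = torch.device("cuda")
    y = x + x
    assert y.device.type == "cuda"

# Outside the mode, real CUDA calls still fail exactly as before the PR.
try:
    torch.randn(2, device="cuda")
except Exception:
    pass
```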
@@ -26,7 +26,6 @@
#include <ATen/native/Normalization.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Logging.h>
@@ -1551,15 +1550,6 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
  END_HANDLE_TH_ERRORS
}

static PyObject* THCPModule_ensureCUDADeviceGuardSet(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  c10::impl::ensureCUDADeviceGuardSet();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static std::initializer_list<PyMethodDef> TorchMethods = {
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
@@ -1855,13 +1845,7 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
    {"_ensureCUDADeviceGuardSet",
     THCPModule_ensureCUDADeviceGuardSet,
     METH_NOARGS,
     nullptr},
    {nullptr, nullptr, 0, nullptr}

};
    {nullptr, nullptr, 0, nullptr}};

#ifdef USE_CUDA
// NOLINTBEGIN(misc-use-internal-linkage)