enable device index check for all device types (#126767)

enable device index check for all device types for grad setter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126767
Approved by: https://github.com/albanD
This commit is contained in:
garfield1997
2024-06-27 01:09:53 +00:00
committed by PyTorch MergeBot
parent 0b7e8df7d8
commit 27a14405d3

View File

@ -1084,15 +1084,13 @@ int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) {
grad.options().type_equal(var.options()),
"attempting to assign a gradient to a tensor that has data of a different type");
}
if (var.is_cuda()) {
TORCH_CHECK(
grad.get_device() == var.get_device(),
"attempting to assign a gradient located on device with index '",
grad.get_device(),
"' to a tensor located on device with index '",
var.get_device(),
"'. Please ensure that the gradient and the tensor are on the same device");
}
TORCH_CHECK(
grad.get_device() == var.get_device(),
"attempting to assign a gradient located on device with index '",
grad.get_device(),
"' to a tensor located on device with index '",
var.get_device(),
"'. Please ensure that the gradient and the tensor are on the same device");
TORCH_CHECK(
grad.sym_sizes().equals(var.sym_sizes()),
"attempting to assign a gradient of size '",