mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
enable device index check for all device types (#126767)
enable device index check for all device types for grad setter. Pull Request resolved: https://github.com/pytorch/pytorch/pull/126767 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
0b7e8df7d8
commit
27a14405d3
@ -1084,15 +1084,13 @@ int THPVariable_set_grad(THPVariable* self, PyObject* py_grad, void* unused) {
|
||||
grad.options().type_equal(var.options()),
|
||||
"attempting to assign a gradient to a tensor that has data of a different type");
|
||||
}
|
||||
if (var.is_cuda()) {
|
||||
TORCH_CHECK(
|
||||
grad.get_device() == var.get_device(),
|
||||
"attempting to assign a gradient located on device with index '",
|
||||
grad.get_device(),
|
||||
"' to a tensor located on device with index '",
|
||||
var.get_device(),
|
||||
"'. Please ensure that the gradient and the tensor are on the same device");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
grad.get_device() == var.get_device(),
|
||||
"attempting to assign a gradient located on device with index '",
|
||||
grad.get_device(),
|
||||
"' to a tensor located on device with index '",
|
||||
var.get_device(),
|
||||
"'. Please ensure that the gradient and the tensor are on the same device");
|
||||
TORCH_CHECK(
|
||||
grad.sym_sizes().equals(var.sym_sizes()),
|
||||
"attempting to assign a gradient of size '",
|
||||
|
Reference in New Issue
Block a user