mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
fix half grad assignment (#11781)
Summary: Currently, grad assignment for the half type fails with a misleading error: `RuntimeError: torch.cuda.sparse.HalfTensor is not enabled.` Pull Request resolved: https://github.com/pytorch/pytorch/pull/11781 Differential Revision: D9931884 Pulled By: soumith fbshipit-source-id: 03e946c3833d1339a99585c9aa2dbb670f8bf459
This commit is contained in:
committed by
Facebook Github Bot
parent
b46f1b8ca7
commit
8601b33c07
@@ -237,9 +237,15 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
       "can't assign Variable as its own grad");

   auto& grad = ((THPVariable*)py_grad)->cdata;
-  auto& sparseType = var.type().toBackend(var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU);
+  bool gradIsSparse = false;
+  auto backend = var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU;
+  auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
+  if (typeOpt) {
+    auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
+    gradIsSparse = grad.type() == sparseType;
+  }

-  THPUtils_assertRet(-1, grad.type() == var.type() || grad.type() == sparseType,
+  THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
       "assigned grad has data of a different type");
   if (var.type().is_cuda()) {
     THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
|
Reference in New Issue
Block a user