fix half grad assignment (#11781)

Summary:
currently grad assignment for half type fails with a misleading RuntimeError
```
RuntimeError: torch.cuda.sparse.HalfTensor is not enabled.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11781

Differential Revision: D9931884

Pulled By: soumith

fbshipit-source-id: 03e946c3833d1339a99585c9aa2dbb670f8bf459
This commit is contained in:
Natalia Gimelshein
2018-09-18 22:47:59 -07:00
committed by Facebook Github Bot
parent b46f1b8ca7
commit 8601b33c07
2 changed files with 10 additions and 4 deletions

View File

@ -237,9 +237,15 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
"can't assign Variable as its own grad");
auto& grad = ((THPVariable*)py_grad)->cdata;
auto& sparseType = var.type().toBackend(var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU);
bool gradIsSparse = false;
auto backend = var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU;
auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
if (typeOpt) {
auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
gradIsSparse = grad.type() == sparseType;
}
THPUtils_assertRet(-1, grad.type() == var.type() || grad.type() == sparseType,
THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
"assigned grad has data of a different type");
if (var.type().is_cuda()) {
THPUtils_assertRet(-1, grad.get_device() == var.get_device(),