Make getting the dtype of a tensor work for backend extensions.

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17131

Differential Revision: D14093163

Pulled By: gchanan

fbshipit-source-id: 06638706e26505e3c741b7ae290000ca258599db
This commit is contained in:
Gregory Chanan
2019-02-15 13:44:18 -08:00
committed by Facebook Github Bot
parent 9b5d3f6f5e
commit 6454e3262d
3 changed files with 19 additions and 5 deletions

View File

@ -248,10 +248,10 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
auto& grad = ((THPVariable*)py_grad)->cdata;
bool gradIsSparse = false;
auto backend = var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU;
auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.scalar_type());
if (typeOpt) {
auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.type().scalarType());
auto& sparseType = at::globalContext().getNonVariableType(backend, var.scalar_type());
auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.scalar_type());
gradIsSparse = gradType == sparseType;
}
@ -387,7 +387,7 @@ static PyObject *THPVariable_dtype(THPVariable *self)
{
HANDLE_TH_ERRORS
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType()));
return torch::autograd::utils::wrap(torch::getDtype(self_.scalar_type()));
END_HANDLE_TH_ERRORS
}