Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Make getting the dtype of a tensor work for backend extensions.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17131

Differential Revision: D14093163

Pulled By: gchanan

fbshipit-source-id: 06638706e26505e3c741b7ae290000ca258599db
Committed by: Facebook Github Bot
Parent: 9b5d3f6f5e
Commit: 6454e3262d
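At a high level, the patch swaps the legacy `var.type().scalarType()` accessor, which routes through the per-backend `Type` object registered with the global context, for `var.scalar_type()`, which reads the scalar type directly off the tensor. A backend extension may have no `Type` registered, so the old path could fail even though the tensor itself knows its dtype. A minimal sketch of the distinction, assuming a standalone program linked against libtorch of roughly this vintage:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::ones({2, 3}, at::kFloat);

  // Old pattern (what the diff removes): route through the dispatch Type.
  // For a backend extension with no registered Type this lookup can fail.
  //   at::ScalarType st = t.type().scalarType();

  // New pattern (what the diff adds): ask the tensor directly. This works
  // for any tensor, whether or not a Type is registered for its backend.
  at::ScalarType st = t.scalar_type();

  std::cout << "scalar type: " << st << std::endl;  // prints: scalar type: Float
  return 0;
}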
@@ -248,10 +248,10 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad)
     auto& grad = ((THPVariable*)py_grad)->cdata;
     bool gradIsSparse = false;
     auto backend = var.is_cuda() ? Backend::SparseCUDA : Backend::SparseCPU;
-    auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.type().scalarType());
+    auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.scalar_type());
     if (typeOpt) {
-      auto& sparseType = at::globalContext().getNonVariableType(backend, var.type().scalarType());
-      auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.type().scalarType());
+      auto& sparseType = at::globalContext().getNonVariableType(backend, var.scalar_type());
+      auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.scalar_type());
       gradIsSparse = gradType == sparseType;
     }
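The net effect of this first hunk: the sparse-gradient check now reaches the `Type` registry only through the optional lookup, and falls back to `gradIsSparse = false` when no sparse `Type` exists for the backend. A condensed restatement of that control flow follows; the free function `grad_is_sparse` is hypothetical, and `getNonVariableTypeOpt`/`getNonVariableType` belong to the `Context` API of this era of the codebase, not current libtorch:

#include <ATen/ATen.h>

// Hypothetical helper mirroring the guarded check in THPVariable_set_grad.
static bool grad_is_sparse(const at::Tensor& var, const at::Tensor& grad) {
  auto backend = var.is_cuda() ? at::Backend::SparseCUDA : at::Backend::SparseCPU;
  // Optional lookup: yields nothing when no Type is registered for this
  // (backend, scalar type) pair, e.g. a backend extension without sparse
  // support, so the check degrades gracefully instead of throwing.
  auto typeOpt = at::globalContext().getNonVariableTypeOpt(backend, var.scalar_type());
  if (!typeOpt) {
    return false;
  }
  auto& sparseType = at::globalContext().getNonVariableType(backend, var.scalar_type());
  auto& gradType = at::globalContext().getNonVariableType(grad.type().backend(), grad.scalar_type());
  return gradType == sparseType;
}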
@@ -387,7 +387,7 @@ static PyObject *THPVariable_dtype(THPVariable *self)
 {
   HANDLE_TH_ERRORS
   auto& self_ = self->cdata;
-  return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType()));
+  return torch::autograd::utils::wrap(torch::getDtype(self_.scalar_type()));
   END_HANDLE_TH_ERRORS
 }
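The second hunk applies the same substitution inside `THPVariable_dtype`, the C++ getter that backs the Python-level `tensor.dtype` attribute. Since `torch::getDtype` only needs a scalar type, reading it via `self_.scalar_type()` lets a tensor from a backend extension report its dtype without a `Type` ever having been registered for it, which is exactly what the commit title promises.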