Fix bug in torch.sparse.addmm on CUDA when beta != 0 or 1 (#56160)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/55917: on CUDA, `torch.sparse.addmm` produced incorrect results whenever `beta` was neither 0 nor 1.
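A minimal repro sketch of the failure mode (the shapes, values, and `allclose` check are illustrative, not taken from the issue; requires a CUDA build):

```python
import torch

# Hypothetical repro: compare the CUDA sparse path against the dense
# reference for a beta outside {0, 1}.
D1 = torch.randn(3, 4, device="cuda", dtype=torch.double)
D2 = torch.randn(5, 4, device="cuda", dtype=torch.double)
S = torch.randn(3, 5, device="cuda", dtype=torch.double).to_sparse()

out_sparse = torch.sparse.addmm(D1, S, D2, beta=0.5, alpha=1.0)
out_dense = torch.addmm(D1, S.to_dense(), D2, beta=0.5, alpha=1.0)

# Before this fix, the two results disagreed.
assert torch.allclose(out_sparse, out_dense)
```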

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56160

Reviewed By: ejguan

Differential Revision: D27825108

Pulled By: ngimel

fbshipit-source-id: 2ade5ea38c5322768dc4dffb40c65fcbb17ec201
Author: sorenrasmussenai
Date: 2021-04-26 02:56:39 -07:00
Committed by: Facebook GitHub Bot
Commit: f27513e951 (parent: f3743f097f)

2 changed files with 22 additions and 13 deletions

@@ -64,12 +64,8 @@ void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int
   Tensor r__;
   if (cast_beta == 0) {
     r_.zero_();
-  } else if (cast_beta == 1) {
-    if (!is_same_tensor(t, r_)) {
-      r_.copy_(t);
-    }
-  } else {
-    at::mul_out(r_, t, scalar_to_tensor(beta));
+  } else if (!is_same_tensor(t, r_)) {
+    r_.copy_(t);
   }
   if(r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
@@ -111,7 +107,9 @@ void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int
         r__.data_ptr<scalar_t>(),
         r__.stride(1));
   }
-  r_.copy_(r__);
+  if (!is_same_tensor(r__, r_)) {
+    r_.copy_(r__);
+  }
 }
 // --------------------------------------------------------------------
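Why removing the `at::mul_out` pre-scaling is the fix: the cuSPARSE gemm this worker dispatches to already computes C = alpha * op(A) * B + beta * C, so `t` only has to be copied into the output buffer once and the library applies `beta` itself. The old path wrote `beta * t` into `r_`, after which the gemm scaled it by `beta` a second time, yielding `alpha * (S @ D) + beta^2 * t` for any `beta` outside {0, 1}. The second hunk merely guards against a redundant self-copy when the temporary `r__` is the same tensor as `r_`. For reference, the intended semantics in Python (my own sketch, not code from the PR):

```python
import torch

def sparse_addmm_reference(t, S, D, beta=1.0, alpha=1.0):
    # out = beta * t + alpha * (S @ D), with beta applied exactly once;
    # the fixed CUDA worker gets this by copying t into the output and
    # letting the cuSPARSE gemm do the beta scaling.
    return beta * t + alpha * torch.mm(S.to_dense(), D)
```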

@@ -1224,7 +1224,12 @@ class TestSparse(TestCase):
     @coalescedonoff
     @dtypes(torch.double)
     def test_sparse_addmm(self, device, dtype, coalesced):
-        def test_shape(m, n, p, nnz, broadcast):
+        def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
+            if alpha_beta is None:
+                alpha = random.random()
+                beta = random.random()
+            else:
+                alpha, beta = alpha_beta
             if broadcast:
                 D1 = torch.randn((), dtype=dtype, device=device).requires_grad_(True)
             else:
@@ -1233,14 +1238,20 @@ class TestSparse(TestCase):
             S = self._gen_sparse(2, nnz, [n, m], dtype, device, coalesced)[0]
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
-            self.assertEqual(torch.sparse.addmm(D1, S, D2), torch.addmm(D1, S_dense, D2))
+            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
+            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
+            self.assertEqual(Y, Y_dense)
 
-            def fn(S, D1, D2):
-                return torch.sparse.addmm(D1, S, D2)
+            def fn(S, D1, D2, beta=beta, alpha=alpha):
+                return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
             gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
 
-        test_shape(7, 8, 9, 20, False)
-        test_shape(7, 8, 9, 20, True)
+        test_shape(7, 8, 9, 20, False, None)
+        test_shape(7, 8, 9, 20, True, None)
+        test_shape(7, 8, 9, 20, False, (1, 0))
+        test_shape(7, 8, 9, 20, True, (1, 0))
+        test_shape(7, 8, 9, 20, False, (1, 1))
+        test_shape(7, 8, 9, 20, True, (1, 1))
 
     @coalescedonoff
     @dtypes(torch.double)
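For readers who want the new gradient check outside the test harness, a standalone distillation (the tensor construction is mine; the real test builds `S` via `self._gen_sparse` and runs under `@coalescedonoff` on the target device):

```python
import torch
from torch.autograd import gradcheck

# Double precision, as gradcheck requires; a dense tensor converted to
# sparse so every entry is a materialized nonzero.
S = torch.randn(4, 3, dtype=torch.double).to_sparse().requires_grad_(True)
D1 = torch.randn(4, 2, dtype=torch.double, requires_grad=True)
D2 = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

def fn(S, D1, D2, beta=0.5, alpha=2.0):
    return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)

# check_sparse_nnz tells gradcheck to perturb only the materialized
# nonzeros of S, matching the test above.
gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
```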