Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix bug in torch.sparse.addmm on CUDA when beta != 0 or 1 (#56160)

Summary: Fixes https://github.com/pytorch/pytorch/issues/55917, which caused `torch.sparse.addmm` to fail on CUDA whenever `beta` was different from 0 or 1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56160

Reviewed By: ejguan
Differential Revision: D27825108
Pulled By: ngimel
fbshipit-source-id: 2ade5ea38c5322768dc4dffb40c65fcbb17ec201
Committed by: Facebook GitHub Bot
Parent: f3743f097f
Commit: f27513e951
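For context, a minimal repro sketch of the failure described in the summary; the shapes and values below are illustrative assumptions, not taken from the issue. Before this fix, the beta=0 and beta=1 calls worked on CUDA while any other beta did not:

import torch

device = "cuda"
S = torch.eye(3, device=device).to_sparse()    # sparse (3, 3) matrix
D1 = torch.randn(3, 4, device=device)          # dense term added to the product
D2 = torch.randn(3, 4, device=device)          # dense right-hand operand

torch.sparse.addmm(D1, S, D2, beta=1.0, alpha=1.0)   # fine before the fix
torch.sparse.addmm(D1, S, D2, beta=0.5, alpha=0.3)   # broke before the fix; expected
                                                      # result: 0.5 * D1 + 0.3 * (S @ D2)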
@@ -64,12 +64,8 @@ void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int
   Tensor r__;
   if (cast_beta == 0) {
     r_.zero_();
-  } else if (cast_beta == 1) {
-    if (!is_same_tensor(t, r_)) {
-      r_.copy_(t);
-    }
-  } else {
-    at::mul_out(r_, t, scalar_to_tensor(beta));
+  } else if (!is_same_tensor(t, r_)) {
+    r_.copy_(t);
   }

   if(r_.stride(0) == 1 && r_.stride(1) == r_.size(0)) {
@@ -111,7 +107,9 @@ void s_addmm_out_sparse_dense_cuda_worker(int64_t nnz, int64_t m, int64_t n, int
       r__.data_ptr<scalar_t>(),
       r__.stride(1));
   }
-  r_.copy_(r__);
+  if (!is_same_tensor(r__, r_)) {
+    r_.copy_(r__);
+  }
 }

 // --------------------------------------------------------------------
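Reading the two hunks above: for a general beta the worker no longer pre-scales the output with at::mul_out; it just copies t into r_ (unless they already alias) and leaves the beta/alpha scaling to the code further down the worker, presumably the cuSPARSE SpMM call that takes alpha and beta directly. The second hunk guards the final copy-back so it only runs when r__ is a distinct temporary. A hedged sketch of the contract the updated test below verifies, using torch.addmm on the densified matrix as the reference (the helper name is illustrative):

import torch

def sparse_addmm_reference(inp, sparse_mat, dense_mat, beta=1.0, alpha=1.0):
    # Dense reference for beta * inp + alpha * (sparse_mat @ dense_mat),
    # i.e. the contract torch.sparse.addmm shares with torch.addmm.
    return torch.addmm(inp, sparse_mat.to_dense(), dense_mat, beta=beta, alpha=alpha)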
@@ -1224,7 +1224,12 @@ class TestSparse(TestCase):
     @coalescedonoff
     @dtypes(torch.double)
     def test_sparse_addmm(self, device, dtype, coalesced):
-        def test_shape(m, n, p, nnz, broadcast):
+        def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
+            if alpha_beta is None:
+                alpha = random.random()
+                beta = random.random()
+            else:
+                alpha, beta = alpha_beta
             if broadcast:
                 D1 = torch.randn((), dtype=dtype, device=device).requires_grad_(True)
             else:
@@ -1233,14 +1238,20 @@ class TestSparse(TestCase):
             S = self._gen_sparse(2, nnz, [n, m], dtype, device, coalesced)[0]
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
-            self.assertEqual(torch.sparse.addmm(D1, S, D2), torch.addmm(D1, S_dense, D2))
+            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
+            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
+            self.assertEqual(Y, Y_dense)

-            def fn(S, D1, D2):
-                return torch.sparse.addmm(D1, S, D2)
+            def fn(S, D1, D2, beta=beta, alpha=alpha):
+                return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
             gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)

-        test_shape(7, 8, 9, 20, False)
-        test_shape(7, 8, 9, 20, True)
+        test_shape(7, 8, 9, 20, False, None)
+        test_shape(7, 8, 9, 20, True, None)
+        test_shape(7, 8, 9, 20, False, (1, 0))
+        test_shape(7, 8, 9, 20, True, (1, 0))
+        test_shape(7, 8, 9, 20, False, (1, 1))
+        test_shape(7, 8, 9, 20, True, (1, 1))

     @coalescedonoff
     @dtypes(torch.double)
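A hypothetical standalone equivalent of the new test coverage (illustrative names and shapes; assumes a CUDA build): the random (alpha, beta) pair exercises the general path fixed above, while (1, 0) and (1, 1) keep the special-cased beta == 0 and plain-copy branches covered.

import random
import torch

device = "cuda"
dtype = torch.double
S = torch.randn(8, 7, device=device, dtype=dtype).relu().to_sparse()   # sparse (8, 7)
D1 = torch.randn(8, 9, device=device, dtype=dtype)
D2 = torch.randn(7, 9, device=device, dtype=dtype)

for alpha, beta in [(random.random(), random.random()), (1, 0), (1, 1)]:
    out = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
    ref = torch.addmm(D1, S.to_dense(), D2, beta=beta, alpha=alpha)
    assert torch.allclose(out, ref)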