Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
fix torch.sparse.log_softmax on CPU (#161959)
Fix https://github.com/pytorch/pytorch/issues/152293.

**Example:**

```
import torch
from torch.sparse import log_softmax as sparse_log_softmax


def test_bug():
    a = torch.rand(4, 3)
    b = a - 10000000.0

    b_sparse = b.to_sparse()
    cpu_out_sparse = sparse_log_softmax(b_sparse, dim=1).to_dense()
    print('cpu_out_sparse =', cpu_out_sparse)

    b_sparse_double = b.double().to_sparse()
    cpu_out_sparse_double = sparse_log_softmax(b_sparse_double, dim=1).to_dense()
    print('cpu_out_sparse_double =', cpu_out_sparse_double)


if __name__ == '__main__':
    test_bug()
```

**Output:**

- before

```
cpu_out_sparse = tensor([[-2., -1., -2.],
        [-1., -1., -1.],
        [-1., -2., -2.],
        [-1., -1., -2.]])
cpu_out_sparse_double = tensor([[-1.5514, -0.5514, -1.5514],
        [-1.0986, -1.0986, -1.0986],
        [-0.5514, -1.5514, -1.5514],
        [-0.8620, -0.8620, -1.8620]], dtype=torch.float64)
```

- after

```
cpu_out_sparse = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]])
cpu_out_sparse_double = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]], dtype=torch.float64)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161959
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 4840a1a591
Commit: 002e59440a
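For context on the numerics (background only, not the kernel change made by this PR): a numerically robust log-softmax is usually computed with the max-subtraction identity `log_softmax(x)[i] = (x[i] - m) - log(sum_j exp(x[j] - m))` with `m = max_j x[j]`, which keeps the exponentials well-scaled even when every input is around -1e7. A minimal dense sketch of that identity, checked against `torch.log_softmax`:

```
import torch

def stable_log_softmax(x, dim):
    # Subtract the per-row maximum so exp() only ever sees values <= 0;
    # mathematically this is identical to log(softmax(x)).
    m = x.amax(dim=dim, keepdim=True)
    shifted = x - m
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

# Large-magnitude negative inputs, as in the reproducer above.
b = torch.rand(4, 3) - 10000000.0
print(torch.allclose(stable_log_softmax(b, dim=1), torch.log_softmax(b, dim=1)))
```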
The PR adds a regression test to TestSparse:

```
@@ -3694,6 +3694,14 @@ class TestSparse(TestSparseBase):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
 
+    @dtypes(torch.float)
+    def test_log_softmax_float(self, device, dtype):
+        x = (torch.rand(4, 3, dtype=dtype, device=device) - 10000000.0).to_sparse()
+        out = torch.sparse.log_softmax(x, dim=1).to_dense()
+        x_double = x.double()
+        out_double = torch.sparse.log_softmax(x_double, dim=1).to_dense()
+        self.assertEqual(out, out_double.to(dtype=dtype))
+
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
```