fix torch.sparse.log_softmax on CPU (#161959)

Fixes https://github.com/pytorch/pytorch/issues/152293: on CPU, `torch.sparse.log_softmax` produced incorrect results for float32 inputs, while the float64 path gave the expected values. The example below reproduces the discrepancy.

**Example:**
```
import torch
from torch.sparse import log_softmax as sparse_log_softmax

def test_bug():
    a = torch.rand(4, 3)
    # Shift the values far into the negative range to expose the precision issue.
    b = a - 10000000.0
    b_sparse = b.to_sparse()

    # float32 sparse path: produced wrong values on CPU before this fix.
    cpu_out_sparse = sparse_log_softmax(b_sparse, dim=1).to_dense()
    print('cpu_out_sparse =', cpu_out_sparse)

    # float64 sparse path: serves as the correct reference.
    b_sparse_double = b.double().to_sparse()
    cpu_out_sparse_double = sparse_log_softmax(b_sparse_double, dim=1).to_dense()
    print('cpu_out_sparse_double =', cpu_out_sparse_double)

if __name__ == '__main__':
    test_bug()
```

**Output:**

- before the fix
```
cpu_out_sparse = tensor([[-2., -1., -2.],
        [-1., -1., -1.],
        [-1., -2., -2.],
        [-1., -1., -2.]])
cpu_out_sparse_double = tensor([[-1.5514, -0.5514, -1.5514],
        [-1.0986, -1.0986, -1.0986],
        [-0.5514, -1.5514, -1.5514],
        [-0.8620, -0.8620, -1.8620]], dtype=torch.float64)
```

- after the fix
```
cpu_out_sparse = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]])
cpu_out_sparse_double = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]], dtype=torch.float64)
```
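
As a quick sanity check of the fixed behavior, the float32 sparse result can be compared against a dense double-precision reference computed with `torch.log_softmax`. This is a minimal sketch, not part of the original report, and the tolerance is only illustrative:

```
import torch

b = torch.rand(4, 3) - 10000000.0  # same construction as the repro above
out = torch.sparse.log_softmax(b.to_sparse(), dim=1).to_dense()

# Dense double-precision log_softmax as a reference; with the fix in place the
# float32 sparse result should agree with it up to single-precision tolerance.
ref = torch.log_softmax(b.double(), dim=1)
print(torch.allclose(out, ref.to(torch.float32), atol=1e-5))  # expected: True
```

All entries of `b` are nonzero, so the sparse tensor specifies every element and the sparse and dense log_softmax results are directly comparable.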
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161959
Approved by: https://github.com/Skylion007
Author: Sun, Jiayi
Date: 2025-09-03 14:42:17 +00:00
Committed by: PyTorch MergeBot
Commit: 002e59440a (parent 4840a1a591)
2 changed files with 12 additions and 2 deletions

```
@@ -3694,6 +3694,14 @@ class TestSparse(TestSparseBase):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
 
+    @dtypes(torch.float)
+    def test_log_softmax_float(self, device, dtype):
+        x = (torch.rand(4, 3, dtype=dtype, device=device) - 10000000.0).to_sparse()
+        out = torch.sparse.log_softmax(x, dim=1).to_dense()
+        x_double = x.double()
+        out_double = torch.sparse.log_softmax(x_double, dim=1).to_dense()
+        self.assertEqual(out, out_double.to(dtype=dtype))
+
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
```
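
For context on why the expected values stay in a small range even though the inputs sit near -1e7: log_softmax is shift-invariant, and the standard numerically stable formulation subtracts the per-row maximum before exponentiating. The sketch below (plain dense Python, not the code touched by this PR) illustrates that formulation:

```
import torch

def stable_log_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Subtract the row-wise max so exp() never sees large-magnitude inputs,
    # then apply log(sum(exp(...))) to the shifted values.
    m = x.max(dim=dim, keepdim=True).values
    shifted = x - m
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

x = torch.rand(4, 3, dtype=torch.float64) - 10000000.0
print(torch.allclose(stable_log_softmax(x, dim=1), torch.log_softmax(x, dim=1)))  # True
```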