fix torch.sparse.log_softmax on CPU (#161959)

Fixes https://github.com/pytorch/pytorch/issues/152293: on CPU, `torch.sparse.log_softmax` returned wrong results for large-magnitude `float32` inputs because the kernel's per-row scratch buffers (`mx_row`, `exp_sums_row`) were kept in `scalar_t`. They now use the accumulation type (`double` for `float` on CPU), which matches the double-precision reference.

**Example:**
```
import torch
from torch.sparse import log_softmax as sparse_log_softmax

def test_bug():
    a = torch.rand(4, 3)
    # Shift to ~-1e7; adjacent float32 values there are 1.0 apart, so the
    # float32 inputs are already integer-quantized before the softmax runs.
    b = a - 10000000.0
    b_sparse = b.to_sparse()

    cpu_out_sparse = sparse_log_softmax(b_sparse, dim=1).to_dense()
    print('cpu_out_sparse =', cpu_out_sparse)

    # Same data in double precision as a reference.
    b_sparse_double = b.double().to_sparse()
    cpu_out_sparse_double = sparse_log_softmax(b_sparse_double, dim=1).to_dense()
    print('cpu_out_sparse_double =', cpu_out_sparse_double)

if __name__ == '__main__':
    test_bug()
```
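
As an extra check (an addition to the reproducer, not part of the original issue), the dense CPU kernel can be run on the same data inside `test_bug()`; at these magnitudes it appears unaffected and agrees with the double-precision sparse result:

```
    # Dense reference on the same tensor (assumed-correct baseline):
    cpu_out_dense = torch.log_softmax(b, dim=1)
    print('cpu_out_dense =', cpu_out_dense)
```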

**Output:**

- before (the reproducer is unseeded, so the sampled values differ between the before/after runs; the signal is whether the float32 output agrees with the float64 one within each run)
```
cpu_out_sparse = tensor([[-2., -1., -2.],
        [-1., -1., -1.],
        [-1., -2., -2.],
        [-1., -1., -2.]])
cpu_out_sparse_double = tensor([[-1.5514, -0.5514, -1.5514],
        [-1.0986, -1.0986, -1.0986],
        [-0.5514, -1.5514, -1.5514],
        [-0.8620, -0.8620, -1.8620]], dtype=torch.float64)
```

- after
```
cpu_out_sparse = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]])
cpu_out_sparse_double = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]], dtype=torch.float64)
```
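
The integer-valued entries in the broken float32 output are the tell: near 1e7, adjacent float32 values are 1.0 apart, and (as a reading of the kernel changed below) `cpu_sparse_coo_softmax` folds `log(exp_sums_row[j])` into the large-magnitude `mx_row[j]` before the final subtraction, so in float32 the fractional log-partition term is rounded away entirely. A minimal sketch of that rounding, with illustrative values:

```
import torch

mx = torch.tensor(-9999999.0)       # per-row max in float32, |mx| ~ 1e7
log_z = torch.tensor(2.3679).log()  # ~0.8620, a typical log-partition term

# float32 spacing at this magnitude is 1.0, so the fraction is lost:
print(mx + log_z)            # tensor(-9999998.)
# in double (the CPU accumulation type for float) it survives:
print(mx.double() + log_z)   # tensor(-9999998.1380, dtype=torch.float64)
```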
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161959
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/mingfeima
Author: Sun, Jiayi
Date: 2025-09-10 10:28:39 +00:00
Committed by: PyTorch MergeBot
Commit: 6b9b7ce6fe (parent: 1274297e06)
2 changed files with 14 additions and 3 deletions

**`aten/src/ATen/native/sparse/SoftMax.cpp`**

```
@@ -2,6 +2,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/Config.h>
 #include <ATen/Dispatch.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/native/sparse/ParamUtils.h>
 #include <ATen/native/SparseTensorUtils.h>
@@ -295,6 +296,7 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t dim)
      to exp functions as well as reuse of softmax implementation for
      log_softmax.
   */
+  using accscalar_t = at::acc_type<scalar_t, false>;
   auto sparse_dim = input.sparse_dim();
   auto indices = input._indices().contiguous();
   auto values = input._values().contiguous();
@@ -340,14 +342,14 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t dim)
         continue;
       /* Prepare scratch space */
-      std::vector<scalar_t> mx_row(nvalues, -std::numeric_limits<scalar_t>::infinity());
-      std::vector<scalar_t> exp_sums_row(nvalues, 0);
+      std::vector<accscalar_t> mx_row(nvalues, -std::numeric_limits<accscalar_t>::infinity());
+      std::vector<accscalar_t> exp_sums_row(nvalues, 0);
       /* Compute mx */
       for (int64_t i : pool_indices) {
         auto values_row = values_accessor[i];
         for (const auto j : c10::irange(nvalues)) {
-          mx_row[j] = std::max(mx_row[j], values_row[j]);
+          mx_row[j] = std::max(mx_row[j], accscalar_t(values_row[j]));
         }
       }
```

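`at::acc_type<scalar_t, /*is_cuda=*/false>` is the CPU accumulation type, which for `float` is `double`, so the scratch rows now carry enough precision for the final `values_row[j] - mx_row[j]` subtraction. A rough Python mirror of the patched per-pool computation (a hypothetical helper, with pool members stacked along dim 0; the real kernel walks `pool_indices` over `_values()`):

```
import torch

def pool_log_softmax(values: torch.Tensor) -> torch.Tensor:
    # Scratch math in the accumulation dtype (double for float32 input),
    # mirroring the accscalar_t change in cpu_sparse_coo_softmax.
    acc = values.double()
    mx_row = acc.amax(dim=0)                        # /* Compute mx */
    exp_sums_row = (acc - mx_row).exp().sum(dim=0)  # sum of exponents
    mx_row = mx_row + exp_sums_row.log()            # fold log-partition into mx
    return (acc - mx_row).to(values.dtype)          # cast back to scalar_t

vals = torch.rand(4, 3) - 10000000.0
print(torch.allclose(pool_log_softmax(vals),
                     torch.log_softmax(vals.double(), dim=0).to(vals.dtype)))  # True
```
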
**`test/test_sparse.py`**

```
@@ -3694,6 +3694,15 @@ class TestSparse(TestSparseBase):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)
 
+    @dtypes(torch.float)
+    @expectedFailureMPS
+    def test_log_softmax_float(self, device, dtype):
+        x = (torch.rand(4, 3, dtype=dtype, device=device) - 10000000.0).to_sparse()
+        out = torch.sparse.log_softmax(x, dim=1).to_dense()
+        x_double = x.double()
+        out_double = torch.sparse.log_softmax(x_double, dim=1).to_dense()
+        self.assertEqual(out, out_double.to(dtype=dtype))
+
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
```