fix torch.sparse.log_softmax on CPU (#161959)
Fix https://github.com/pytorch/pytorch/issues/152293.

**Example:**

```python
import torch
from torch.sparse import log_softmax as sparse_log_softmax


def test_bug():
    a = torch.rand(4, 3)
    b = a - 10000000.0

    b_sparse = b.to_sparse()
    cpu_out_sparse = sparse_log_softmax(b_sparse, dim=1).to_dense()
    print('cpu_out_sparse =', cpu_out_sparse)

    b_sparse_double = b.double().to_sparse()
    cpu_out_sparse_double = sparse_log_softmax(b_sparse_double, dim=1).to_dense()
    print('cpu_out_sparse_double =', cpu_out_sparse_double)


if __name__ == '__main__':
    test_bug()
```

**Output:**

- before

```
cpu_out_sparse = tensor([[-2., -1., -2.],
        [-1., -1., -1.],
        [-1., -2., -2.],
        [-1., -1., -2.]])
cpu_out_sparse_double = tensor([[-1.5514, -0.5514, -1.5514],
        [-1.0986, -1.0986, -1.0986],
        [-0.5514, -1.5514, -1.5514],
        [-0.8620, -0.8620, -1.8620]], dtype=torch.float64)
```

- after

```
cpu_out_sparse = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]])
cpu_out_sparse_double = tensor([[-0.8620, -1.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986],
        [-1.8620, -0.8620, -0.8620],
        [-1.0986, -1.0986, -1.0986]], dtype=torch.float64)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161959
Approved by: https://github.com/Skylion007
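As a quick post-fix sanity check (a sketch, not part of the PR), the sparse result can also be compared against a dense float64 reference: every entry of the example tensor is materialized in the sparse layout, so `torch.sparse.log_softmax` should agree with the dense op up to float32 precision.

```python
# Hypothetical sanity check: sparse float32 log_softmax vs. dense float64 reference.
import torch

b = torch.rand(4, 3) - 10000000.0                      # float32 values near -1e7
sparse_out = torch.sparse.log_softmax(b.to_sparse(), dim=1).to_dense()
dense_ref = torch.log_softmax(b.double(), dim=1)       # dense float64 reference

print(torch.allclose(sparse_out.double(), dense_ref, atol=1e-4))  # expected: True after the fix
```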
Committed by: PyTorch MergeBot
Parent: 4840a1a591
Commit: 002e59440a
Kernel change (`cpu_sparse_coo_softmax`):

```diff
@@ -2,6 +2,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/Config.h>
 #include <ATen/Dispatch.h>
+#include <ATen/AccumulateType.h>
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/native/sparse/ParamUtils.h>
 #include <ATen/native/SparseTensorUtils.h>
@@ -295,6 +296,7 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t di
      to exp functions as well as reuse of softmax implementation for
      log_softmax.
    */
+  using accscalar_t = at::acc_type<scalar_t, false>;
   auto sparse_dim = input.sparse_dim();
   auto indices = input._indices().contiguous();
   auto values = input._values().contiguous();
@@ -340,14 +342,14 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t di
       continue;

     /* Prepare scratch space */
-    std::vector<scalar_t> mx_row(nvalues, -std::numeric_limits<scalar_t>::infinity());
+    std::vector<accscalar_t> mx_row(nvalues, -std::numeric_limits<accscalar_t>::infinity());
     std::vector<scalar_t> exp_sums_row(nvalues, 0);

     /* Compute mx */
     for (int64_t i : pool_indices) {
       auto values_row = values_accessor[i];
       for (const auto j : c10::irange(nvalues)) {
-        mx_row[j] = std::max(mx_row[j], values_row[j]);
+        mx_row[j] = std::max(mx_row[j], accscalar_t(values_row[j]));
       }
     }

```
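The kernel now keeps the per-row maximum in `at::acc_type<scalar_t, false>`, which widens `float` to `double` on the CPU path, instead of in `scalar_t`. As a rough illustration of why a wider intermediate matters at this magnitude (a Python sketch, not the kernel code): near 1e7 a float32 has a spacing (ulp) of 1.0, so an O(1) log-sum-exp style correction applied to a value of that size can be rounded to a whole ulp.

```python
# Illustration only: float32 quantizes an O(1) correction near magnitude 1e7.
import torch

mx = torch.tensor(-10000000.0, dtype=torch.float32)        # row maximum, ~ -1e7
corr = torch.log(torch.tensor(3.0, dtype=torch.float32))    # ~1.0986

print((mx + corr) - mx)                                      # tensor(1.) -> 1.0986 rounded to the nearest ulp
print((mx.double() + corr.double()) - mx.double())           # ~1.0986 when accumulated in double
```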
Test change (`TestSparse`):

```diff
@@ -3694,6 +3694,14 @@ class TestSparse(TestSparseBase):
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 1, device, dtype)
         self._check_zero_nnz_softmax_op(torch.sparse.log_softmax, 10, device, dtype)

+    @dtypes(torch.float)
+    def test_log_softmax_float(self, device, dtype):
+        x = (torch.rand(4, 3, dtype=dtype, device=device) - 10000000.0).to_sparse()
+        out = torch.sparse.log_softmax(x, dim=1).to_dense()
+        x_double = x.double()
+        out_double = torch.sparse.log_softmax(x_double, dim=1).to_dense()
+        self.assertEqual(out, out_double.to(dtype=dtype))
+
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @coalescedonoff
     @dtypes(*floating_and_complex_types())
```
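The added test mirrors the repro from the commit message: it runs `torch.sparse.log_softmax` on the same input in float32 and float64 and asserts the results match. Since `TestSparse` is instantiated per device and dtype, the new case should be selectable with a name filter along the lines of `pytest test/test_sparse.py -k log_softmax_float` (a usage sketch, not part of the change).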