Fix a crash in sparse compressed tensor invariants check when nnz == 0 (#115825)

Fixes the Python crash example from https://github.com/pytorch/pytorch/issues/115755

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115825
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson
2023-12-14 18:39:24 +02:00
committed by PyTorch MergeBot
parent eafeba71c1
commit 419f2ca3e3
2 changed files with 15 additions and 4 deletions

View File

@ -126,8 +126,8 @@ INVARIANT_CHECK_FUNC_API _check_idx_sorted_distinct_vals_slices_with_cidx(
// Note that ptr_idx_batch = &idx[batch_idx] and is contiguous.
const auto* RESTRICT slice_begin = ptr_idx_batch + cidx;
const auto* RESTRICT slice_end = ptr_idx_batch + cidx_next;
for (auto* RESTRICT curr = slice_begin + 1; curr < slice_end; ++curr) {
const auto invariant = *(curr - 1) < *curr;
for (auto* RESTRICT curr = slice_begin; (slice_begin < slice_end) && (curr + 1 < slice_end); ++curr) {
const auto invariant = *curr < *(curr + 1);
if (cdim_name == CDimName::CRow) {
_assert(
invariant,
@ -335,10 +335,16 @@ void _validate_compressed_sparse_indices_kernel(
// NOTE: the implementation below is sync-less, but,
// unfortunately, work is not guaranteed to be well-balanced
// between different threads.
// Note: 5.6 should not be tested when
// nnz==0. Fortunately, the code below is a no-op when
// nnz==0.
int64_t idx_offset = 0;
// assuming idx contiguity per batch:
int64_t tmp = batch_idx * idx_sizes[idx_ndims - 1];
for (int i = idx_ndims - 1; i >= 0; i--) {
int64_t tmp = batch_idx * nnz;
// `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)`
for (int i = idx_ndims - 1;
i >= 0 && nnz > 0; // break early when nnz==0
i--) {
int64_t div = tmp / idx_sizes[i];
idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
tmp = div;

View File

@ -590,6 +590,11 @@ class TestSparseCompressed(TestCase):
layout
)
compressed_indices = torch.tensor([0, 0], dtype=index_dtype)
plain_indices = torch.tensor([], dtype=index_dtype)
torch._validate_compressed_sparse_indices(layout in {torch.sparse_csr, torch.sparse_bsr},
compressed_indices, plain_indices, 1, 1, 0)
def _generate_invalid_input(self, layout, device):
from functools import partial