mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix a crash in sparse compressed tensor invariants check when nnz == 0 (#115825)
Fixes python crash example from https://github.com/pytorch/pytorch/issues/115755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115825 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
eafeba71c1
commit
419f2ca3e3
@ -126,8 +126,8 @@ INVARIANT_CHECK_FUNC_API _check_idx_sorted_distinct_vals_slices_with_cidx(
|
||||
// Note that ptr_idx_batch = &idx[batch_idx] and is contiguous.
|
||||
const auto* RESTRICT slice_begin = ptr_idx_batch + cidx;
|
||||
const auto* RESTRICT slice_end = ptr_idx_batch + cidx_next;
|
||||
for (auto* RESTRICT curr = slice_begin + 1; curr < slice_end; ++curr) {
|
||||
const auto invariant = *(curr - 1) < *curr;
|
||||
for (auto* RESTRICT curr = slice_begin; (slice_begin < slice_end) && (curr + 1 < slice_end); ++curr) {
|
||||
const auto invariant = *curr < *(curr + 1);
|
||||
if (cdim_name == CDimName::CRow) {
|
||||
_assert(
|
||||
invariant,
|
||||
@ -335,10 +335,16 @@ void _validate_compressed_sparse_indices_kernel(
|
||||
// NOTE: the implementation below is sync-less, but,
|
||||
// unfortunately, work is not guaranteed to be well-balanced
|
||||
// between different threads.
|
||||
// Note: 5.6 should not be tested when
|
||||
// nnz==0. Fortunately, the code below is no-op when
|
||||
// nnz==0.
|
||||
int64_t idx_offset = 0;
|
||||
// assuming idx contiguity per batch:
|
||||
int64_t tmp = batch_idx * idx_sizes[idx_ndims - 1];
|
||||
for (int i = idx_ndims - 1; i >= 0; i--) {
|
||||
int64_t tmp = batch_idx * nnz;
|
||||
// `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)`
|
||||
for (int i = idx_ndims - 1;
|
||||
i >= 0 && nnz > 0; // break early when nnz==0
|
||||
i--) {
|
||||
int64_t div = tmp / idx_sizes[i];
|
||||
idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
|
||||
tmp = div;
|
||||
|
@ -590,6 +590,11 @@ class TestSparseCompressed(TestCase):
|
||||
layout
|
||||
)
|
||||
|
||||
compressed_indices = torch.tensor([0, 0], dtype=index_dtype)
|
||||
plain_indices = torch.tensor([], dtype=index_dtype)
|
||||
torch._validate_compressed_sparse_indices(layout in {torch.sparse_csr, torch.sparse_bsr},
|
||||
compressed_indices, plain_indices, 1, 1, 0)
|
||||
|
||||
def _generate_invalid_input(self, layout, device):
|
||||
from functools import partial
|
||||
|
||||
|
Reference in New Issue
Block a user