Faster index_select for sparse COO tensors on CPU. (#72710)

Fixes https://github.com/pytorch/pytorch/issues/72212.

This PR improves on the previous algorithm's asymptotic complexity. It also exploits the structure of the problem and parallelizes the computation where possible.
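
For context, the operation being optimized is `Tensor.index_select` on a sparse COO tensor, along either a sparse or a dense dimension. A minimal example:

```python
import torch

# sparse COO tensor with 2 sparse dimensions
i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3., 4., 5.])
s = torch.sparse_coo_tensor(i, v, (2, 3)).coalesce()

idx = torch.tensor([2, 0])
out = s.index_select(1, idx)  # select columns 2 and 0, result stays sparse
assert torch.equal(out.to_dense(), s.to_dense().index_select(1, idx))
```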

Benchmark results.

<details>

<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)
ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    # Helper for the torch_sparse baseline (only exercised when timing torch_sparse).
    import torch_sparse
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []

for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        stmt = "SparseX.index_select(d, index)"
        timer = Timer(stmt,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)

```

</details>

<details>

<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
        "PR",
        "torch_sparse",
        "master"
        ]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()

```

</details>

<details>

<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                                    |    PR   |  torch_sparse  |   master
32 threads: -----------------------------------------------------------------------------
      n=10000, nnz=100, index_len=100, dim=0        |     14  |        140     |       10
      n=10000, nnz=100, index_len=100, dim=1        |     14  |        200     |       10
      n=10000, nnz=100, index_len=1000, dim=0       |     30  |        180     |       38
      n=10000, nnz=100, index_len=1000, dim=1       |     34  |        240     |       38
      n=10000, nnz=100, index_len=10000, dim=0      |    278  |        460     |      330
      n=10000, nnz=100, index_len=10000, dim=1      |    275  |        516     |      330
      n=100000, nnz=1000, index_len=100, dim=0      |     16  |        290     |       31
      n=100000, nnz=1000, index_len=100, dim=1      |     26  |        390     |       31
      n=100000, nnz=1000, index_len=1000, dim=0     |     45  |        405     |      263
      n=100000, nnz=1000, index_len=1000, dim=1     |     73  |        500     |      261
      n=100000, nnz=1000, index_len=10000, dim=0    |    444  |        783     |     2570
      n=100000, nnz=1000, index_len=10000, dim=1    |    470  |        890     |     2590
      n=1000000, nnz=10000, index_len=100, dim=0    |     25  |       2400     |      270
      n=1000000, nnz=10000, index_len=100, dim=1    |    270  |       4000     |      269
      n=1000000, nnz=10000, index_len=1000, dim=0   |     74  |       2600     |     2620
      n=1000000, nnz=10000, index_len=1000, dim=1   |    464  |       3600     |     2640
      n=1000000, nnz=10000, index_len=10000, dim=0  |    635  |       3300     |    26400
      n=1000000, nnz=10000, index_len=10000, dim=1  |   1000  |       3960     |    26400
      n=10, nnz=100, index_len=100, dim=0           |     16  |        137     |       16
      n=10, nnz=100, index_len=100, dim=1           |     16  |        220     |       16
      n=10, nnz=100, index_len=1000, dim=0          |     63  |        238     |       81
      n=10, nnz=100, index_len=1000, dim=1          |     60  |        698     |       78
      n=10, nnz=100, index_len=10000, dim=0         |    480  |        940     |      862
      n=10, nnz=100, index_len=10000, dim=1         |    330  |       4930     |     1070
      n=10, nnz=1000, index_len=100, dim=0          |     60  |        200     |       73
      n=10, nnz=1000, index_len=100, dim=1          |     56  |        683     |       70
      n=10, nnz=1000, index_len=1000, dim=0         |    480  |        530     |     1050
      n=10, nnz=1000, index_len=1000, dim=1         |    330  |       4550     |     1368
      n=10, nnz=1000, index_len=10000, dim=0        |   3100  |       2900     |     9300
      n=10, nnz=1000, index_len=10000, dim=1        |   3400  |      46000     |     9100
      n=10, nnz=10000, index_len=100, dim=0         |    400  |        453     |      857
      n=10, nnz=10000, index_len=100, dim=1         |    400  |       4070     |     1730
      n=10, nnz=10000, index_len=1000, dim=0        |   2840  |       2600     |    13900
      n=10, nnz=10000, index_len=1000, dim=1        |   3700  |      40600     |    16000
      n=10, nnz=10000, index_len=10000, dim=0       |  83200  |      67400     |   160000
      n=10, nnz=10000, index_len=10000, dim=1       |  68000  |     528000     |   190000
      n=100, nnz=1000, index_len=100, dim=0         |     46  |        148     |       31
      n=100, nnz=1000, index_len=100, dim=1         |     45  |        242     |       37
      n=100, nnz=1000, index_len=1000, dim=0        |     68  |        248     |      240
      n=100, nnz=1000, index_len=1000, dim=1        |     66  |        755     |      290
      n=100, nnz=1000, index_len=10000, dim=0       |    370  |        802     |     2250
      n=100, nnz=1000, index_len=10000, dim=1       |    372  |       5430     |     2770
      n=100, nnz=10000, index_len=100, dim=0        |     82  |        210     |      224
      n=100, nnz=10000, index_len=100, dim=1        |     74  |        986     |      270
      n=100, nnz=10000, index_len=1000, dim=0       |    350  |        618     |     2600
      n=100, nnz=10000, index_len=1000, dim=1       |    370  |       4660     |     4560
      n=100, nnz=10000, index_len=10000, dim=0      |   3000  |       3400     |    41680
      n=100, nnz=10000, index_len=10000, dim=1      |   5000  |      47500     |    30400
      n=1000, nnz=10000, index_len=100, dim=0       |     71  |        160     |      185
      n=1000, nnz=10000, index_len=100, dim=1       |     64  |        516     |      190
      n=1000, nnz=10000, index_len=1000, dim=0      |    100  |        249     |     1740
      n=1000, nnz=10000, index_len=1000, dim=1      |     98  |       1030     |     1770
      n=1000, nnz=10000, index_len=10000, dim=0     |    600  |        808     |    18300
      n=1000, nnz=10000, index_len=10000, dim=1     |    663  |       5300     |    18500
      n=1000, nnz=100000, index_len=100, dim=0      |    160  |        258     |     1890
      n=1000, nnz=100000, index_len=100, dim=1      |    200  |       3620     |     2050
      n=1000, nnz=100000, index_len=1000, dim=0     |    500  |        580     |    18700
      n=1000, nnz=100000, index_len=1000, dim=1     |    640  |       7550     |    30000
      n=1000, nnz=100000, index_len=10000, dim=0    |   3400  |       3260     |   186000
      n=1000, nnz=100000, index_len=10000, dim=1    |   3600  |      49600     |   194000
      n=1000, nnz=1000000, index_len=100, dim=0     |    517  |        957     |    18700
      n=1000, nnz=1000000, index_len=100, dim=1     |    680  |      39600     |    37600
      n=1000, nnz=1000000, index_len=1000, dim=0    |   3600  |       4500     |   186000
      n=1000, nnz=1000000, index_len=1000, dim=1    |   5800  |      76400     |   190000
      n=1000, nnz=1000000, index_len=10000, dim=0   |  50000  |      67900     |  1800000
      n=1000, nnz=1000000, index_len=10000, dim=1   |  45000  |     570000     |  1900000

Times are in microseconds (us).

```

</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72710
Approved by: https://github.com/pearu, https://github.com/cpuhrsch
Author: Nikita Vedeneev
Date: 2022-05-09 19:59:37 +00:00
Committed by: PyTorch MergeBot
Parent: 3e4bff7285
Commit: ce3857e73c
4 changed files with 706 additions and 63 deletions


@@ -1356,7 +1356,7 @@ Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim,
return grad_input;
}
Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
/*
Algorithm:
index - a 1-D tensor of indices with shape (n,)
@@ -1369,79 +1369,622 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
new_values - shape is (new_nnz,) + dense_shape
if dim < len(sparse_shape):
for i, idx in enumerate(index):
for j, jdx in enumerate(indices[dim]):
if idx == jdx:
icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
new_indices.add_column(icol)
new_values.add_row(values[j])
# Find new_indices[dim] of the output sparse tensor and
# indices at which to select values/indices.
# The CPP code uses (binary/in a count table) search to find matches and may
# swap the loop order for better algorithmic complexity.
new_dim_indices = []
selected_dim_indices = []
# This is a brute-force algorithm to convey the main idea.
# The CPP code below is more efficient but more complicated.
for i, i_idx in enumerate(indices[dim]):
for j, j_idx in enumerate(index):
if i_idx == j_idx:
new_dim_indices.append(j)
selected_dim_indices.append(i)
new_indices = indices.index_select(1, selected_dim_indices)
new_values = values.index_select(0, selected_dim_indices)
new_indices[dim] = new_dim_indices
else:
new_indices = indices
new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
new_values = values.index_select(dim - sparse_dim + 1, index);
*/
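
The comment above maps, roughly, to the following runnable Python sketch of the brute-force variant (illustrative only; the C++ below replaces the nested loop with binary-search or count-table intersections and parallelizes the work):

```python
import torch

def index_select_sparse_bruteforce(self, dim, index):
    # Brute-force sparse-dim branch: pair up matching entries of indices[dim] and index.
    indices, values = self._indices(), self._values()
    new_dim_indices, selected_dim_indices = [], []
    for i, i_idx in enumerate(indices[dim].tolist()):
        for j, j_idx in enumerate(index.tolist()):
            if i_idx == j_idx:
                new_dim_indices.append(j)
                selected_dim_indices.append(i)
    sel = torch.tensor(selected_dim_indices, dtype=torch.long)
    new_indices = indices.index_select(1, sel)
    new_indices[dim] = torch.tensor(new_dim_indices, dtype=torch.long)
    new_values = values.index_select(0, sel)
    new_sizes = list(self.shape)
    new_sizes[dim] = index.numel()
    return torch.sparse_coo_tensor(new_indices, new_values, new_sizes)

# sanity check against the dense equivalent
s = torch.sparse_coo_tensor(torch.tensor([[0, 1, 1], [2, 0, 2]]),
                            torch.tensor([3., 4., 5.]), (2, 3)).coalesce()
idx = torch.tensor([2, 2, 0])
assert torch.equal(index_select_sparse_bruteforce(s, 1, idx).to_dense(),
                   s.to_dense().index_select(1, idx))
```
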
auto ndim = self.dim();
if (ndim == 0) {
TORCH_CHECK_INDEX(false, "index_select() cannot be applied to a 0-dim tensor.");
}
if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
TORCH_CHECK_INDEX(false, "index_select() argument index must be 1-D long-tensor.");
}
const auto ndim = self.dim();
TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
TORCH_CHECK_INDEX(
index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
"index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
dim = maybe_wrap_dim(dim, ndim);
auto size = self.size(dim);
auto sparse_dim = self.sparse_dim();
auto dense_dim = self.dense_dim();
auto indices = self._indices();
auto values = self._values();
auto nnz = values.size(0);
auto new_sizes = self.sizes().vec();
new_sizes[dim] = index.size(0);
const auto size = self.size(dim);
const auto sparse_dim = self.sparse_dim();
const auto dense_dim = self.dense_dim();
const auto indices = self._indices();
const auto values = self._values();
const auto nnz = values.size(0);
const auto index_len = index.size(0);
auto res_sizes = self.sizes().vec();
res_sizes[dim] = index_len;
// Equivalent to t.index_select(dim, idx), but vanilla index_select is not parallel,
// so we use gather instead.
// We use this method to select relevant indices/values
// from the intersection between indices[dim] and the index.
const auto index_select = [](const Tensor& t, int64_t dim, const Tensor& idx) -> Tensor {
const auto idx_len = idx.numel();
auto out_shape = t.sizes().vec();
out_shape[dim] = idx_len;
auto idx_shape = std::vector<int64_t>(t.dim(), 1);
idx_shape[dim] = idx_len;
return t.gather(dim, idx.view(idx_shape).expand(out_shape));
};
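
In Python terms, the gather-based helper above is equivalent to the following sketch (illustrative only):

```python
import torch

def parallel_index_select(t, dim, idx):
    # Same trick as the C++ lambda: express index_select via gather,
    # which is parallelized on CPU, by broadcasting idx over all other dims.
    out_shape = list(t.shape)
    out_shape[dim] = idx.numel()
    idx_shape = [1] * t.dim()
    idx_shape[dim] = idx.numel()
    return t.gather(dim, idx.view(idx_shape).expand(out_shape))

t = torch.arange(12).reshape(3, 4)
idx = torch.tensor([3, 0])
assert torch.equal(parallel_index_select(t, 1, idx), t.index_select(1, idx))
```
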
// If indexing into sparse dimensions
if (dim < sparse_dim) {
// short-circuit if index is empty
if (!index_len) {
auto res_indices = index_select(indices, 1, index);
res_indices[dim] = index;
const auto res_values = index_select(values, 0, index);
auto cpu_dim_indices = indices[dim].to(c10::kCPU).contiguous();
int64_t* cpu_dim_indices_ptr = cpu_dim_indices.data_ptr<int64_t>();
auto cpu_index = index.to(c10::kCPU).contiguous();
int64_t* cpu_index_ptr = cpu_index.data_ptr<int64_t>();
std::vector<int64_t> zindices;
std::vector<int64_t> iindices;
int64_t new_nnz = 0;
for (const auto i : c10::irange(new_sizes[dim])) {
int64_t idx = cpu_index_ptr[i];
if (idx < -size || idx >= size) {
TORCH_CHECK_INDEX(false, "index_select(): index contains ", idx, " that is out of range for tensor of size ",
self.sizes(), " at dimension ", dim);
return _sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
}
const auto nneg_index = [&index, index_len, &self, size, dim](void) -> Tensor {
const auto index_contiguous = index.contiguous();
auto nneg_index = at::empty_like(index_contiguous);
// nneg_index = (index < 0) * (index + size) + (index >= 0) * index
auto* ptr_index = index_contiguous.data_ptr<int64_t>();
auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
at::parallel_for(0, index_len, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
const auto* src = ptr_index + start;
auto* dst = ptr_nneg_index + start;
for (C10_UNUSED const auto _ : c10::irange(start, end)) {
auto idx = *src++;
if (idx < -size || idx >= size) {
TORCH_CHECK_INDEX(false,
"index_select(): index contains ", idx, " that is out of range for tensor of size ",
self.sizes(), " at dimension ", dim
);
}
if (idx < 0) {
idx += size;
}
*dst++ = idx;
}
});
return nneg_index;
}();
const auto dim_indices = indices[dim].contiguous();
// If nnz is smaller than size, then either indices[dim] or index gets sorted,
// then this is followed by a binary search to find interesections.
const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple<Tensor, Tensor> {
const auto grain_size = at::internal::GRAIN_SIZE;
const auto n_threads_nnz = std::max<int64_t>(
1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto n_threads_index = std::max<int64_t>(
1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto search_in_dim_indices
// if either dim_indices or index requires sorting, we compare
// the cost of sort + binary search, which is comparing
// (len(dim_indices) + len(index)) * log(len(index)) to
// (len(dim_indices) + len(index)) * log(len(dim_indices)).
// That simplifies to comparing len(dim_indices) to len(index).
// Additionally, we take into consideration potential parallel
// speedup.
= (nnz / n_threads_nnz <= index_len / n_threads_index)
// if self is coalesced and dim is 0, then we compare
// index_len * log(len(dim_indices)), which is binary search into dim_indices,
// to (len(index_len) + len(dim_indices)) * log(index_len).
// Additionally, we take into consideration potential parallel
// speedup.
|| (self.is_coalesced() && dim == 0
&& (index_len * std::log2(nnz) / n_threads_index
<= (nnz / n_threads_nnz + index_len) * std::log2(index_len)))
? true : false;
// src is a source of indices to binary search in sorted
Tensor sorted, sorted_idx, src;
std::tie(sorted, sorted_idx, src) = [
&dim_indices, &nneg_index, &self,
search_in_dim_indices, dim, nnz
](void) -> std::tuple<Tensor, Tensor, Tensor> {
// sort dim_indices to binary search into it
if (search_in_dim_indices) {
// dim_indices is already sorted if self is coalesced and dim == 0
if (self.is_coalesced() && dim == 0) {
return std::make_tuple(dim_indices, at::arange(nnz, dim_indices.options()), nneg_index);
}
else {
Tensor sorted_dim_indices, sorted_dim_indices_idx;
std::tie(sorted_dim_indices, sorted_dim_indices_idx) = dim_indices.sort();
return std::make_tuple(sorted_dim_indices, sorted_dim_indices_idx, nneg_index);
}
}
// sort nneg_index to binary search into it
else {
Tensor sorted_nneg_index, sorted_nneg_index_idx;
std::tie(sorted_nneg_index, sorted_nneg_index_idx) = nneg_index.sort();
return std::make_tuple(sorted_nneg_index, sorted_nneg_index_idx, dim_indices);
}
}();
const auto src_grain_size = at::internal::GRAIN_SIZE;
const auto src_len = src.numel();
const auto n_threads_src = std::max<int64_t>(
// 1 <= n_threads_src <= std::min(ceil(src.numel() / src_grain_size), max_threads)
1, std::min<int64_t>((src_len + src_grain_size - 1) / src_grain_size, at::get_num_threads())
);
const auto chunk_size_src = (src_len + n_threads_src - 1) / n_threads_src;
const std::vector<int64_t> src_n_threads_shape = {
n_threads_src, (src_len + n_threads_src - 1) / n_threads_src
};
// src_int_idx and sorted_int_idx store "i" and "j" indices indicating
// intersections such that src_int_idx[i] == sorted_int_idx[j].
// These intersections are found with binary search and in parallel.
auto src_int_idx = at::empty(src_n_threads_shape, src.options());
auto sorted_int_idx = at::empty_like(src_int_idx);
// For each element "i" from src, int_counts define how many
// elements there are in sorted, i.e. "j" indices, corresponding
// to "i", i.e.:
// |{j : src_int_idx[i] == sorted_int_idx[j]}| for each i in src_int_idx.
auto int_counts = at::zeros_like(src_int_idx);
// fill in src_int_idx, sorted_int_idx, int_counts
{
const auto sorted_len = sorted.numel();
const auto* ptr_sorted = sorted.data_ptr<int64_t>();
const auto* ptr_sorted_start = ptr_sorted;
const auto* ptr_sorted_end = ptr_sorted + sorted_len;
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size_src;
const auto end = std::min(start + chunk_size_src, src_len);
auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
const auto* ptr_src = src.data_ptr<int64_t>() + start;
for (const auto i : c10::irange(start, end)) {
const auto src_val = *ptr_src++;
const auto src_val_lb = std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val);
// We cannot just use *src_val_lb != src_val because when
// src_val_lb == ptr_sorted_end, dereferencing past-the-end value
// is not well-defined.
if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) {
++ptr_tid_src_int_idx;
++ptr_tid_sorted_int_idx;
++ptr_tid_int_counts;
continue;
}
const auto src_val_ub = std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val);
const int64_t count = src_val_ub - src_val_lb;
const int64_t j = src_val_lb - ptr_sorted_start;
*ptr_tid_src_int_idx++ = i;
*ptr_tid_sorted_int_idx++ = j;
*ptr_tid_int_counts++ = count;
}
});
}
if (idx < 0) {
idx += size;
const auto compressed_int_counts = int_counts.sum(-1);
const auto res_len = compressed_int_counts.sum().item<int64_t>();
// Short-circuit if empty intersection
if (!res_len) {
auto empty_idx = at::empty({0}, src.options());
return std::make_tuple(empty_idx, empty_idx);
}
for (const auto j : c10::irange(nnz)) {
int64_t jdx = cpu_dim_indices_ptr[j];
if (idx == jdx) {
new_nnz++;
iindices.push_back(i);
zindices.push_back(j);
// Now that we know "i", "j" and the counts, we "unflatten"
// them into two arrays of intersection indices such that
// selected_src = repeat_interleave(src_int_idx, int_counts),
// and selected_sorted is obtained as follows:
// offsets = int_counts.cumsum(0).sub_(int_counts)
// for ii, (j, c) in enumerate(zip(sorted_int_idx, int_counts)):
// out_slice = slice(offsets[ii], offsets[ii] + c)
// src_slice = slice(j, j + c)
// selected_sorted[out_slice] = sorted_int_idx[src_slice]
auto selected_sorted = at::empty({res_len}, sorted.options());
auto selected_src = at::empty({res_len}, src.options());
// fill in selected_sorted, selected_src
{
auto* ptr_selected_sorted = selected_sorted.data_ptr<int64_t>();
auto* ptr_selected_src = selected_src.data_ptr<int64_t>();
const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts);
const auto* ptr_sorted_idx = sorted_idx.data_ptr<int64_t>();
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size_src;
const auto end = std::min(start + chunk_size_src, src_len);
const auto tid_offset = thread_offsets.data_ptr<int64_t>()[tid];
const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
const auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
for (C10_UNUSED const auto _ : c10::irange(start, end)) {
const auto count = *ptr_tid_int_counts++;
const auto i = *ptr_tid_src_int_idx++;
const auto j = *ptr_tid_sorted_int_idx++;
if (!count) continue;
std::fill_n(ptr_tid_selected_src, count, i);
std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted);
ptr_tid_selected_sorted += count;
ptr_tid_selected_src += count;
}
});
}
return search_in_dim_indices
? std::make_tuple(selected_sorted, selected_src)
: std::make_tuple(selected_src, selected_sorted);
};
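
A compact Python sketch of what this `nnz <= size` path computes: sort one of the two index arrays, binary-search the other into it, and expand each matched range into pairs (illustrative only; the C++ does the searches in parallel chunks with `std::lower_bound`/`std::upper_bound`):

```python
import torch

def intersect_by_binary_search(src, sorted_vals, sorted_idx):
    # For every value in `src`, find its range of matches in `sorted_vals`,
    # then expand the hits into (position in original order, position in src) pairs.
    lb = torch.searchsorted(sorted_vals, src, right=False)
    ub = torch.searchsorted(sorted_vals, src, right=True)
    counts = ub - lb
    selected_src = torch.repeat_interleave(torch.arange(src.numel()), counts)
    selected_sorted = torch.cat([sorted_idx[l:u] for l, u in zip(lb.tolist(), ub.tolist())])
    return selected_sorted, selected_src

dim_indices = torch.tensor([5, 1, 5, 3])     # indices[dim] of the sparse tensor
index = torch.tensor([5, 0, 3])              # the index argument
sorted_vals, sorted_idx = dim_indices.sort()
selected_nnz, res_dim = intersect_by_binary_search(index, sorted_vals, sorted_idx)
print(selected_nnz, res_dim)                 # tensor([0, 2, 3]) tensor([0, 0, 2])
```
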
// Converts a 1d sorted idx to a compressed 1d compressed idx,
// aka crow in the CSR format. Useful to get a count table in
// a parallelized and no-sync manner.
// TODO: this function is equivalent to _convert_indices_from_coo_to_csr.
// The mentioned function is not public yet.
const auto sorted_idx_to_cidx = [](
const Tensor& idx,
int64_t len,
bool run_in_parallel = true) -> Tensor {
auto cidx = at::empty({len + 1}, idx.options());
const auto* ptr_idx = idx.data_ptr<int64_t>();
auto* ptr_cidx = cidx.data_ptr<int64_t>();
const auto idx_len = idx.numel();
std::fill_n(ptr_cidx, ptr_idx[0] + 1, 0);
std::fill_n(ptr_cidx + ptr_idx[idx_len - 1] + 1, len - ptr_idx[idx_len - 1], idx_len);
const auto grain_size = run_in_parallel ? at::internal::GRAIN_SIZE : idx_len;
at::parallel_for(0, idx_len, grain_size, [&](int64_t start, int64_t end) {
auto* ptr_curr_cidx = ptr_cidx + ptr_idx[start] + 1;
for (int64_t i = start; i < std::min(end, idx_len - 1); ++i) {
const auto diff = ptr_idx[i + 1] - ptr_idx[i];
std::fill_n(ptr_curr_cidx, diff, i + 1);
ptr_curr_cidx += diff;
}
});
return cidx;
};
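
In Python terms, `sorted_idx_to_cidx` produces the same offsets as a bincount followed by a cumulative sum (illustrative sketch):

```python
import torch

def sorted_idx_to_cidx(idx, length):
    # Compressed (CSR-crow-style) offsets for a sorted 1-D index tensor:
    # cidx[b] .. cidx[b + 1] is the range of positions in `idx` holding value b.
    counts = torch.bincount(idx, minlength=length)
    return torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])

idx = torch.tensor([0, 0, 2, 2, 2, 4])
print(sorted_idx_to_cidx(idx, 5))   # tensor([0, 2, 2, 5, 5, 6])
```
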
// If nnz is (much) larger than size, then both indices[dim] and index get sorted
// with a count sort (faster, and no huge nnz-sized chunk memory allocations).
// The element-wise product between the count tables gives us all the intersections.
const auto get_selected_indices_large_nnz_small_size = [&]() -> std::tuple<Tensor, Tensor> {
const auto get_counts = [&sorted_idx_to_cidx](
// Writes into counts (must be preallocated and zero)
// and allows to use external buffers.
Tensor& counts,
const Tensor& t,
int64_t bins,
bool is_sorted = false,
bool run_in_parallel = true) -> void {
if (is_sorted) {
const auto cidx = sorted_idx_to_cidx(t, bins, run_in_parallel);
at::sub_out(counts, cidx.slice(0, 1, bins + 1), cidx.slice(0, 0, bins));
}
else {
auto* ptr_counts = counts.data_ptr<int64_t>();
const auto* ptr_vals = t.data_ptr<int64_t>();
for (C10_UNUSED const auto _ : c10::irange(t.numel())) {
++ptr_counts[*ptr_vals++];
}
}
};
const auto counts_per_thread = [&get_counts, size](
const Tensor& idx,
bool is_sorted = false,
int64_t grain_size = at::internal::GRAIN_SIZE
) -> Tensor {
const auto idx_len = idx.numel();
// 1 <= n_threads <= min(ceil(len / grain_size), max_threads)
const auto n_threads = std::max<int64_t>(
1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto chunk_size = (idx_len + n_threads - 1) / n_threads;
const auto run_in_parallel = (n_threads == 1);
auto counts_per_thread = at::zeros({n_threads, size}, idx.options());
at::parallel_for(0, n_threads, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size;
const auto end = std::min(start + chunk_size, idx_len);
const auto tid_idx = idx.slice(0, start, end);
auto tid_counts = counts_per_thread.select(0, tid);
get_counts(tid_counts, tid_idx, /*bins=*/size,
/*is_sorted=*/is_sorted, /*run_in_parallel=*/run_in_parallel);
});
return counts_per_thread;
};
auto dim_indices_counts_per_thread = counts_per_thread(
dim_indices,
/*is_sorted=*/self.is_coalesced() && dim == 0
/*grain_size = at::internal::GRAIN_SIZE*/
);
auto dim_indices_offset_counts_per_thread = dim_indices_counts_per_thread.cumsum(0);
auto index_counts_per_thread = counts_per_thread(
nneg_index,
/*is_sorted=*/false
/*grain_size = at::internal::GRAIN_SIZE*/
);
auto index_offset_counts_per_thread = index_counts_per_thread.cumsum(0);
const auto index_counts = index_offset_counts_per_thread.select(0, -1);
const auto dim_indices_counts = dim_indices_offset_counts_per_thread.select(0, -1);
const auto intersection_counts = index_counts.mul(dim_indices_counts);
const auto res_len = intersection_counts.sum().item<int64_t>();
// Short-circuit if empty intersection
if (!res_len) {
auto empty_idx = at::empty({0}, index.options());
return std::make_tuple(empty_idx, empty_idx);
}
const auto intersection_offsets = intersection_counts.cumsum(0);
const auto search_in_dim_indices = [&]() -> bool {
const auto grain_size = at::internal::GRAIN_SIZE;
const auto n_threads_index = std::max<int64_t>(
1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto n_threads_dim_indices = std::max<int64_t>(
1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto index_max_copy_work_per_thread =
index_counts_per_thread.mul(dim_indices_counts).sum(-1).max().item<int64_t>();
const auto dim_indices_max_copy_work_per_thread
= dim_indices_counts_per_thread.mul(index_counts).sum(-1).max().item<int64_t>();
const auto index_max_work_per_thread = index_max_copy_work_per_thread * index_len / n_threads_index;
const auto dim_indices_max_work_per_thread = dim_indices_max_copy_work_per_thread * nnz / n_threads_dim_indices;
return index_max_work_per_thread <= dim_indices_max_work_per_thread
? true
: false;
}();
Tensor idx, idx_counts_per_thread, idx_offset_counts_per_thread;
Tensor src, src_counts_per_thread, src_offset_counts_per_thread;
std::tie(
idx, idx_counts_per_thread, idx_offset_counts_per_thread,
src, src_counts_per_thread, src_offset_counts_per_thread
) = [&]() {
return search_in_dim_indices
? std::make_tuple(
nneg_index, index_counts_per_thread, index_offset_counts_per_thread,
dim_indices, dim_indices_counts_per_thread, dim_indices_offset_counts_per_thread
)
: std::make_tuple(
dim_indices, dim_indices_counts_per_thread, dim_indices_counts_per_thread.cumsum(0),
nneg_index, index_counts_per_thread, index_counts_per_thread.cumsum(0)
);
}();
const auto idx_counts = idx_offset_counts_per_thread.select(0, -1);
const auto src_counts = src_offset_counts_per_thread.select(0, -1);
Tensor src_idx, src_idx_offsets;
std::tie(src_idx, src_idx_offsets) = [&](
int64_t grain_size = at::internal::GRAIN_SIZE
) -> std::tuple<Tensor, Tensor> {
const auto src_intersection_counts = src_counts.mul(idx_counts > 0);
const auto src_intersection_offsets = src_intersection_counts.cumsum(0);
const auto src_idx_len = src_intersection_offsets.data_ptr<int64_t>()[size - 1];
auto src_idx = at::empty({src_idx_len}, src.options());
const auto* ptr_src = src.data_ptr<int64_t>();
const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
const auto* ptr_src_intersection_counts = src_intersection_counts.data_ptr<int64_t>();
const auto* ptr_src_intersection_offsets = src_intersection_offsets.data_ptr<int64_t>();
auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
const auto src_len = src.numel();
const auto n_threads_src = std::max<int64_t>(
1, std::min<int64_t>((src_len + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto chunk_size = (src_len + n_threads_src - 1) / n_threads_src;
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size;
const auto end = std::min(start + chunk_size, src_len);
auto* ptr_src_tid = ptr_src + start;
const auto* ptr_src_counts_per_thread
= src_counts_per_thread.select(0, tid).data_ptr<int64_t>();
const auto* ptr_src_offset_counts_per_thread
= src_offset_counts_per_thread.select(0, tid).data_ptr<int64_t>();
auto tid_counts = at::zeros({size}, src.options());
auto* ptr_tid_counts = tid_counts.data_ptr<int64_t>();
for (const auto i : c10::irange(start, end)) {
const auto idx_val = *ptr_src_tid++;
// skip idx value if not in the intersection
if (!ptr_intersection_counts[idx_val]) continue;
const auto idx_val_offset
= ptr_src_intersection_offsets[idx_val]
- ptr_src_intersection_counts[idx_val];
const auto idx_val_tid_offset
= ptr_src_offset_counts_per_thread[idx_val]
- ptr_src_counts_per_thread[idx_val];
auto& idx_val_local_tid_count = ptr_tid_counts[idx_val];
ptr_src_idx[idx_val_offset + idx_val_tid_offset + idx_val_local_tid_count] = i;
++idx_val_local_tid_count;
}
});
const auto src_idx_offsets = src_intersection_offsets.sub_(src_intersection_counts);
return std::make_tuple(src_idx, src_idx_offsets);
}();
Tensor idx_selected, src_selected;
std::tie(idx_selected, src_selected) = [&](
int64_t grain_size = at::internal::GRAIN_SIZE
) -> std::tuple<Tensor, Tensor> {
const auto thread_offset = [&]() {
// we do not need idx_counts_per_thread anymore,
// so it is safe to do in-place intersection.
auto counts_per_thread = idx_counts_per_thread.mul_(src_counts).sum(-1);
return counts_per_thread.cumsum(0).sub_(counts_per_thread);
}();
const auto* ptr_thread_offset = thread_offset.data_ptr<int64_t>();
auto idx_selected = at::empty({res_len}, idx.options());
auto src_selected = at::empty({res_len}, src.options());
const auto* ptr_idx = idx.data_ptr<int64_t>();
const auto* ptr_src_counts = src_counts.data_ptr<int64_t>();
const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
const auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
const auto* ptr_src_idx_offsets = src_idx_offsets.data_ptr<int64_t>();
auto* ptr_idx_selected = idx_selected.data_ptr<int64_t>();
auto* ptr_src_selected = src_selected.data_ptr<int64_t>();
const auto idx_len = idx.numel();
const auto n_threads_idx = std::max<int64_t>(
1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
);
const auto chunk_size = (idx_len + n_threads_idx - 1) / n_threads_idx;
at::parallel_for(0, n_threads_idx, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
const auto start = tid * chunk_size;
const auto end = std::min(start + chunk_size, idx_len);
const auto tid_offset = ptr_thread_offset[tid];
const auto* ptr_idx_tid = ptr_idx + start;
auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset;
auto* ptr_src_selected_tid = ptr_src_selected + tid_offset;
for (const auto i : c10::irange(start, end)) {
const auto idx_val = *ptr_idx_tid++;
// skip if idx_val is not in the intersection
if (!ptr_intersection_counts[idx_val]) continue;
const auto count = ptr_src_counts[idx_val];
const auto j = ptr_src_idx_offsets[idx_val];
std::fill_n(ptr_idx_selected_tid, count, i);
std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid);
ptr_idx_selected_tid += count;
ptr_src_selected_tid += count;
}
});
return std::make_tuple(idx_selected, src_selected);
}();
return search_in_dim_indices
? std::make_tuple(src_selected, idx_selected)
: std::make_tuple(idx_selected, src_selected);
};
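
A Python sketch of the count-table intersection this `nnz > size` path computes (illustrative only; the C++ additionally splits the counting across threads via per-thread offset tables and avoids any Python-level loops):

```python
import torch

def intersect_by_counting(dim_indices, index, size):
    # Count-sort both sides into `size` bins; the per-bin product gives the
    # number of (nnz entry, index position) pairs contributed by each value.
    dim_counts = torch.bincount(dim_indices, minlength=size)
    idx_counts = torch.bincount(index, minlength=size)
    res_len = int((dim_counts * idx_counts).sum())
    # Bucket nnz positions by value, then emit the cross product per bin.
    dim_pos = [(dim_indices == v).nonzero(as_tuple=True)[0] for v in range(size)]
    selected_dim, res_dim = [], []
    for j, v in enumerate(index.tolist()):
        for i in dim_pos[v].tolist():
            selected_dim.append(i)
            res_dim.append(j)
    assert len(selected_dim) == res_len
    return torch.tensor(selected_dim, dtype=torch.long), torch.tensor(res_dim, dtype=torch.long)

dim_indices = torch.tensor([2, 0, 2, 1])   # indices[dim], all values < size
index = torch.tensor([2, 2, 3])
print(intersect_by_counting(dim_indices, index, size=4))
# (tensor([0, 2, 0, 2]), tensor([0, 0, 1, 1]))
```
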
const auto make_output = [&](
const Tensor& selected_dim_indices,
const Tensor& res_dim_indices) -> Tensor {
auto res_indices = index_select(indices, 1, selected_dim_indices);
res_indices[dim] = res_dim_indices;
const auto res_values = index_select(values, 0, selected_dim_indices);
return _sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
};
// Brute-force solution for small values of nnz and index_len
const auto get_result_small_nnz_small_index = [&]()
-> Tensor {
const auto dim_indices_in_inner_loop = nnz >= index_len;
Tensor outer, inner;
std::tie(outer, inner) = [&]() -> std::tuple<Tensor, Tensor> {
if (dim_indices_in_inner_loop) {
return std::make_tuple(nneg_index, dim_indices);
}
else {
return std::make_tuple(dim_indices, nneg_index);
}
}();
const auto* ptr_outer = outer.data_ptr<int64_t>();
const auto* ptr_inner = inner.data_ptr<int64_t>();
// NOTE: if very critical, replace std::vector with
// a data structure that operates on stack up to some limit.
auto outer_selected_idx = std::vector<int64_t>();
auto inner_selected_idx = std::vector<int64_t>();
int64_t res_len = 0;
for (const auto i : c10::irange(outer.numel())) {
for (const auto j : c10::irange(inner.numel())) {
if (ptr_outer[i] == ptr_inner[j]) {
++res_len;
outer_selected_idx.push_back(i);
inner_selected_idx.push_back(j);
}
}
}
const auto outer_selected_idx_tensor = at::from_blob(
outer_selected_idx.data(), {res_len}, at::kLong
);
const auto inner_selected_idx_tensor = at::from_blob(
inner_selected_idx.data(), {res_len}, at::kLong
);
return dim_indices_in_inner_loop
? make_output(inner_selected_idx_tensor, outer_selected_idx_tensor)
: make_output(outer_selected_idx_tensor, inner_selected_idx_tensor);
};
constexpr int64_t BRUTE_FORCE_SIZE_LIMIT = 2 << 14; // 32768
// NOTE: such a condition to avoid overflows in (nnz * index_len)
if (nnz <= BRUTE_FORCE_SIZE_LIMIT && index_len <= BRUTE_FORCE_SIZE_LIMIT
&& (nnz * index_len) <= BRUTE_FORCE_SIZE_LIMIT) {
return get_result_small_nnz_small_index();
}
auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
auto new_indices = indices.index_select(1, zIndices);
new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
auto new_values = values.index_select(0, zIndices);
return _sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());
else {
Tensor selected_dim_indices;
Tensor res_dim_indices;
} else {
// A more precise decision could be of the form:
// `nnz < C(nnz, size) * size`, but it requires heavy benchmarking.
// We choose `nnz <= size`, which measures theoretical complexity
// and does not rely on runtime performance.
// TODO: perform this analysis and find better C(nnz, size).
if (nnz <= size) {
std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_small_nnz_large_size();
}
else {
std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_large_nnz_small_size();
}
auto vsize = values.sizes().vec();
vsize[dim + 1 - sparse_dim] = index.size(0);
auto new_values = at::empty(vsize, values.options());
for (const auto k : c10::irange(nnz)) {
new_values[k] = values[k].index_select(dim - sparse_dim, index);
return make_output(selected_dim_indices, res_dim_indices);
}
return _sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());
}
// If indexing into dense dimensions
else {
// It is sufficient to just perform `index_select` on values
// if `dim` refers to dense dimensions.
const auto res_values = index_select(values, dim - sparse_dim + 1, index);
return _sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
}
}
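
When `dim` falls into the dense dimensions, the sparsity pattern is unchanged and only `values` is indexed, as a quick Python check illustrates (illustrative only):

```python
import torch

# Hybrid sparse tensor: 1 sparse dim (size 3) and 1 dense dim (size 4).
i = torch.tensor([[0, 2]])
v = torch.randn(2, 4)
t = torch.sparse_coo_tensor(i, v, (3, 4)).coalesce()

idx = torch.tensor([3, 0, 0])
out = t.index_select(1, idx)   # dim 1 is a dense dim
# Equivalent: keep the indices, select inside values along dim - sparse_dim + 1 = 1.
ref = torch.sparse_coo_tensor(t._indices(), t._values().index_select(1, idx), (3, 3))
assert torch.equal(out.to_dense(), ref.to_dense())
```
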
Tensor index_select_sparse_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
auto res = index_select_sparse_cpu(self.to(at::kCPU), dim, index.to(at::kCPU));
return res.to(self.device());
}
Tensor slice(


@@ -7056,8 +7056,8 @@
QuantizedCPU: index_select_quantized_cpu_
CUDA: index_select_cuda
QuantizedCUDA: index_select_quantized_cuda
SparseCPU: index_select_sparse
SparseCUDA: index_select_sparse
SparseCPU: index_select_sparse_cpu
SparseCUDA: index_select_sparse_cuda
- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)


@@ -157,7 +157,7 @@ class TestSparse(TestCase):
self.assertEqual(i, x._indices())
self.assertEqual(v, x._values())
self.assertEqual(x.ndimension(), len(with_size))
self.assertEqual(x.coalesce()._nnz(), nnz)
self.assertEqual(x.coalesce()._nnz(), nnz if x.is_coalesced() else nnz // 2)
self.assertEqual(list(x.size()), with_size)
# Test .indices() and .values()
@@ -999,6 +999,105 @@ class TestSparse(TestCase):
test_shape(len(sizes) // 2, 10, sizes, d, index)
test_shape(len(sizes), 10, sizes, d, index)
def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coalesced):
t = make_tensor(sizes, dtype=dtype, device=device)
t_sparse = t.to_sparse().coalesce() if coalesced else t.to_sparse()
t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
t_small = t_small_sparse.to_dense()
for d in dims:
# NOTE: indices are negative
idx_dim_d_range = list(range(-sizes[d], 0))
for idx_len in range(sizes[d], sizes[d] + 1):
# creates all possible valid indices into dim d of length idx_len
for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)):
t_idx = torch.tensor(idx, dtype=torch.long, device=device)
# NOTE: index_select for dense does not support negative indices,
# hence + sizes[d]. See https://github.com/pytorch/pytorch/issues/76347
# tests the nnz > sizes[d] branch
dense_result = t.index_select(d, t_idx + sizes[d])
sparse_result = t_sparse.index_select(d, t_idx)
self.assertEqual(dense_result, sparse_result)
# tests the nnz <= sizes[d] branch
small_dense_result = t_small.index_select(d, t_idx + sizes[d])
small_sparse_result = t_small_sparse.index_select(d, t_idx)
self.assertEqual(small_dense_result, small_sparse_result)
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_exhaustive_index_small(self, device, dtype, coalesced):
# will trigger brute-force algo
self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced)
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_exhaustive_index_large(self, device, dtype, coalesced):
# will trigger more sophisticated algos
self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced)
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced):
# empty index
idx_empty = torch.tensor([], dtype=torch.long, device=device)
t = make_tensor((5, 5), dtype=dtype, device=device)
res_dense = t.index_select(0, idx_empty)
res_sparse = t.to_sparse().index_select(0, idx_empty)
self.assertEqual(res_dense, res_sparse)
# non-contiguous index
idx = torch.randint(low=0, high=5, size=(10, 2), device=device)[:, 0]
def run_test(sizes):
# case nnz > size[d]
t = make_tensor(sizes, dtype=dtype, device=device)
res_dense = t.index_select(0, idx)
res_sparse = t.to_sparse().index_select(0, idx)
self.assertEqual(res_dense, res_sparse)
# case nnz <= size[d]
t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
res_sparse = t_small_sparse.index_select(0, idx)
res_dense = t_small_sparse.to_dense().index_select(0, idx)
self.assertEqual(res_dense, res_sparse)
# brute-force
run_test((10, 10))
# more sophisticated algos
run_test((10, 100, 100))
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_index_select_parallelization(self, device, dtype, coalesced):
"""
Test with sizes that will trigger parallelization (i.e. with sizes
that are >= at::internal::GRAIN_SIZE)
"""
def run_test(nnz, size):
t_sparse, _, _ = self._gen_sparse(1, nnz, (size,), dtype, device, coalesced)
t_dense = t_sparse.to_dense()
# idx_small to (sort) and (binary) search into t_sparse
idx_small = torch.randint(size, (nnz // 2,), device=device)
# idx_large to (sort) and (binary) search into idx_large
# NOTE: when coalesced=True, the (binary) search will be
# done over t_sparse anyway, as it is already sorted.
idx_large = torch.randint(size, (nnz * 2,), device=device)
for idx in (idx_small, idx_large):
res_dense = t_dense.index_select(0, idx)
res_sparse = t_sparse.index_select(0, idx)
self.assertEqual(res_dense, res_sparse)
# NOTE: GRAIN_SIZE = 32768
# case nnz <= size[d]
tlen = 70000 # > 2 * GRAIN_SIZE
run_test(tlen, tlen)
# case nnz > size[d]
run_test(tlen, tlen // 2)
@onlyCPU
@coalescedonoff
@dtypes(torch.double, torch.cdouble)


@@ -2087,8 +2087,9 @@ class TestCase(expecttest.TestCase):
i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
i = i.to(torch.long)
if is_uncoalesced:
v = torch.cat([v, torch.randn_like(v)], 0)
i = torch.cat([i, i], 1)
i1 = i[:, :(nnz // 2), ...]
i2 = i[:, :((nnz + 1) // 2), ...]
i = torch.cat([i1, i2], 1)
x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
if not is_uncoalesced: