Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Faster index_select for sparse COO tensors on CPU. (#72710)
Fixes https://github.com/pytorch/pytorch/issues/72212.

This PR improves the previous algorithm in complexity. It also utilizes the structure of the problem and parallelizes computations when possible.

Benchmark results.

<details>
<summary>Testing script</summary>

```python
import torch
import math
from IPython import get_ipython
from itertools import product
import pickle
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(13)
#torch.set_num_threads(1)

ipython = get_ipython()

index_sizes = (100, 1000, 10000)
# specifies (n, nnz)
problem_dims = (
    # n > nnz
    (10000, 100),
    (100000, 1000),
    (1000000, 10000),
    # n < nnz
    (10, 100),
    (10, 1000),
    (10, 10000),
    (100, 1000),
    (100, 10000),
    (1000, 10000),
    (1000, 100000),
    (1000, 1000000),
    #(1000000, 1000000000),
)

def f(t, d, index):
    s = torch_sparse.SparseTensor.from_torch_sparse_coo_tensor(t)
    ss = s.index_select(d, index)
    return ss.coo()

name = "PR"
results = []
for (n, nnz), m in product(problem_dims, index_sizes):
    for d in (0, 1):
        if nnz < n:
            shape = (n, n)
        else:
            shape = (n, nnz // n) if d == 0 else (nnz // n, n)
        nrows, ncols = shape
        rowidx = torch.randint(low=0, high=nrows, size=(nnz,))
        colidx = torch.randint(low=0, high=ncols, size=(nnz,))
        itemidx = torch.vstack((rowidx, colidx))
        xvalues = torch.randn(nnz)
        index = torch.randint(low=0, high=n, size=(m,))

        SparseX = torch.sparse_coo_tensor(itemidx, xvalues, size=shape).coalesce()
        smtp = "SparseX.index_select(d, index)"
        timer = Timer(smtp,
                      globals=globals(),
                      label="coo.index_select",
                      description=f"{name}: coo.index_select",
                      sub_label=f"n={n}, nnz={nnz}, index_len={m}, dim={d}",
                      num_threads=torch.get_num_threads())
        results.append(timer.blocked_autorange())

compare = Compare(results)
compare.trim_significant_figures()
compare.print()

with open(f"{name}_index_select.pickle", 'wb') as f:
    pickle.dump(results, f)
```
</details>

<details>
<summary>Gather results</summary>

```python
import pickle
from torch.utils.benchmark import Timer, Compare

files = [
    "PR",
    "torch_sparse",
    "master"
]

timers = []
for name in files:
    with open("{}_index_select.pickle".format(name), 'rb') as f:
        timers += pickle.load(f)

compare = Compare(timers)
compare.trim_significant_figures()
compare.print()
```
</details>

<details>
<summary>PR/torch_sparse/master runtime comparison</summary>

```
[----------------------------------- coo.index_select ----------------------------------]
                                            | PR | torch_sparse | master
32 threads: -----------------------------------------------------------------------------
n=10000, nnz=100, index_len=100, dim=0 | 14 | 140 | 10
n=10000, nnz=100, index_len=100, dim=1 | 14 | 200 | 10
n=10000, nnz=100, index_len=1000, dim=0 | 30 | 180 | 38
n=10000, nnz=100, index_len=1000, dim=1 | 34 | 240 | 38
n=10000, nnz=100, index_len=10000, dim=0 | 278 | 460 | 330
n=10000, nnz=100, index_len=10000, dim=1 | 275 | 516 | 330
n=100000, nnz=1000, index_len=100, dim=0 | 16 | 290 | 31
n=100000, nnz=1000, index_len=100, dim=1 | 26 | 390 | 31
n=100000, nnz=1000, index_len=1000, dim=0 | 45 | 405 | 263
n=100000, nnz=1000, index_len=1000, dim=1 | 73 | 500 | 261
n=100000, nnz=1000, index_len=10000, dim=0 | 444 | 783 | 2570
n=100000, nnz=1000, index_len=10000, dim=1 | 470 | 890 | 2590
n=1000000, nnz=10000, index_len=100, dim=0 | 25 | 2400 | 270
n=1000000, nnz=10000, index_len=100, dim=1 | 270 | 4000 | 269
n=1000000, nnz=10000, index_len=1000, dim=0 | 74 | 2600 | 2620
n=1000000, nnz=10000, index_len=1000, dim=1 | 464 | 3600 | 2640
n=1000000, nnz=10000, index_len=10000, dim=0 | 635 | 3300 | 26400
n=1000000, nnz=10000, index_len=10000, dim=1 | 1000 | 3960 | 26400
n=10, nnz=100, index_len=100, dim=0 | 16 | 137 | 16
n=10, nnz=100, index_len=100, dim=1 | 16 | 220 | 16
n=10, nnz=100, index_len=1000, dim=0 | 63 | 238 | 81
n=10, nnz=100, index_len=1000, dim=1 | 60 | 698 | 78
n=10, nnz=100, index_len=10000, dim=0 | 480 | 940 | 862
n=10, nnz=100, index_len=10000, dim=1 | 330 | 4930 | 1070
n=10, nnz=1000, index_len=100, dim=0 | 60 | 200 | 73
n=10, nnz=1000, index_len=100, dim=1 | 56 | 683 | 70
n=10, nnz=1000, index_len=1000, dim=0 | 480 | 530 | 1050
n=10, nnz=1000, index_len=1000, dim=1 | 330 | 4550 | 1368
n=10, nnz=1000, index_len=10000, dim=0 | 3100 | 2900 | 9300
n=10, nnz=1000, index_len=10000, dim=1 | 3400 | 46000 | 9100
n=10, nnz=10000, index_len=100, dim=0 | 400 | 453 | 857
n=10, nnz=10000, index_len=100, dim=1 | 400 | 4070 | 1730
n=10, nnz=10000, index_len=1000, dim=0 | 2840 | 2600 | 13900
n=10, nnz=10000, index_len=1000, dim=1 | 3700 | 40600 | 16000
n=10, nnz=10000, index_len=10000, dim=0 | 83200 | 67400 | 160000
n=10, nnz=10000, index_len=10000, dim=1 | 68000 | 528000 | 190000
n=100, nnz=1000, index_len=100, dim=0 | 46 | 148 | 31
n=100, nnz=1000, index_len=100, dim=1 | 45 | 242 | 37
n=100, nnz=1000, index_len=1000, dim=0 | 68 | 248 | 240
n=100, nnz=1000, index_len=1000, dim=1 | 66 | 755 | 290
n=100, nnz=1000, index_len=10000, dim=0 | 370 | 802 | 2250
n=100, nnz=1000, index_len=10000, dim=1 | 372 | 5430 | 2770
n=100, nnz=10000, index_len=100, dim=0 | 82 | 210 | 224
n=100, nnz=10000, index_len=100, dim=1 | 74 | 986 | 270
n=100, nnz=10000, index_len=1000, dim=0 | 350 | 618 | 2600
n=100, nnz=10000, index_len=1000, dim=1 | 370 | 4660 | 4560
n=100, nnz=10000, index_len=10000, dim=0 | 3000 | 3400 | 41680
n=100, nnz=10000, index_len=10000, dim=1 | 5000 | 47500 | 30400
n=1000, nnz=10000, index_len=100, dim=0 | 71 | 160 | 185
n=1000, nnz=10000, index_len=100, dim=1 | 64 | 516 | 190
n=1000, nnz=10000, index_len=1000, dim=0 | 100 | 249 | 1740
n=1000, nnz=10000, index_len=1000, dim=1 | 98 | 1030 | 1770
n=1000, nnz=10000, index_len=10000, dim=0 | 600 | 808 | 18300
n=1000, nnz=10000, index_len=10000, dim=1 | 663 | 5300 | 18500
n=1000, nnz=100000, index_len=100, dim=0 | 160 | 258 | 1890
n=1000, nnz=100000, index_len=100, dim=1 | 200 | 3620 | 2050
n=1000, nnz=100000, index_len=1000, dim=0 | 500 | 580 | 18700
n=1000, nnz=100000, index_len=1000, dim=1 | 640 | 7550 | 30000
n=1000, nnz=100000, index_len=10000, dim=0 | 3400 | 3260 | 186000
n=1000, nnz=100000, index_len=10000, dim=1 | 3600 | 49600 | 194000
n=1000, nnz=1000000, index_len=100, dim=0 | 517 | 957 | 18700
n=1000, nnz=1000000, index_len=100, dim=1 | 680 | 39600 | 37600
n=1000, nnz=1000000, index_len=1000, dim=0 | 3600 | 4500 | 186000
n=1000, nnz=1000000, index_len=1000, dim=1 | 5800 | 76400 | 190000
n=1000, nnz=1000000, index_len=10000, dim=0 | 50000 | 67900 | 1800000
n=1000, nnz=1000000, index_len=10000, dim=1 | 45000 | 570000 | 1900000

Times are in microseconds (us).
```
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72710
Approved by: https://github.com/pearu, https://github.com/cpuhrsch
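As a quick illustration of the operation being benchmarked (a sketch, not part of the PR; shapes and values are arbitrary), sparse COO `index_select` should agree with its dense counterpart:

```python
import torch

# A small 2-D sparse COO tensor with a few non-zeros.
indices = torch.tensor([[0, 1, 4, 4],
                        [2, 0, 1, 3]])
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
sparse = torch.sparse_coo_tensor(indices, values, size=(5, 5)).coalesce()

# index_select along a sparse dimension; negative indices wrap around
# (dense index_select requires non-negative indices, hence the modulo below).
idx = torch.tensor([4, 0, -1])
result = sparse.index_select(0, idx)
expected = sparse.to_dense().index_select(0, idx % sparse.size(0))
assert torch.equal(result.to_dense(), expected)
```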
Committed by: PyTorch MergeBot
Parent: 3e4bff7285
Commit: ce3857e73c
@ -1356,7 +1356,7 @@ Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim,
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
|
||||
Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
|
||||
/*
|
||||
Algorithm:
|
||||
index - a 1-D tensor of indices with shape (n,)
|
||||
@ -1369,79 +1369,622 @@ Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index)
|
||||
new_values - shape is (new_nnz,) + dense_shape
|
||||
|
||||
if dim < len(sparse_shape):
|
||||
for i, idx in enumerate(index):
|
||||
for j, jdx in enumerate(indices[dim]):
|
||||
if idx == jdx:
|
||||
icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
|
||||
new_indices.add_column(icol)
|
||||
new_values.add_row(values[j])
|
||||
# Find new_indices[dim] of the output sparse tensor and
|
||||
# indices at which to select values/indices.
|
||||
# The CPP code uses a search (binary, or via a count table) to find matches and may
|
||||
# swap the loop order for better algorithmic complexity.
|
||||
new_dim_indices = []
|
||||
selected_dim_indices = []
|
||||
# This is a brute-force algorithm to convey the main idea.
|
||||
# The CPP code below is more efficient but more complicated.
|
||||
for i, i_idx in enumerate(indices[dim]):
|
||||
for j, j_idx in enumerate(index):
|
||||
if i_idx == j_idx:
|
||||
new_dim_indices.append(j)
|
||||
selected_dim_indices.append(i)
|
||||
new_indices = indices.index_select(1, selected_dim_indices)
|
||||
new_values = values.index_select(0, selected_dim_indices)
|
||||
new_indices[dim] = new_dim_indices
|
||||
else:
|
||||
new_indices = indices
|
||||
new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
|
||||
new_values = values.index_select(dim - sparse_dim + 1, index);
|
||||
*/
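As an illustration (not part of the diff), the pseudocode above for the sparse-dimension branch can be transcribed into a runnable Python reference; the C++ below replaces the double loop with binary or count-table searches:

```python
import torch

def index_select_sparse_dim_reference(t, dim, index):
    # Brute-force reference for dim < sparse_dim: pair every entry of
    # indices[dim] with every matching position in `index`.
    t = t.coalesce()
    indices, values = t.indices(), t.values()
    selected_dim_indices = []  # positions into indices[dim] / values
    res_dim_indices = []       # positions into `index`, i.e. output coordinates
    for i, i_idx in enumerate(indices[dim].tolist()):
        for j, j_idx in enumerate(index.tolist()):
            if i_idx == j_idx % t.size(dim):  # wrap negative indices
                selected_dim_indices.append(i)
                res_dim_indices.append(j)
    sel = torch.tensor(selected_dim_indices, dtype=torch.long)
    res_indices = indices.index_select(1, sel)
    res_indices[dim] = torch.tensor(res_dim_indices, dtype=torch.long)
    res_values = values.index_select(0, sel)
    res_sizes = list(t.shape)
    res_sizes[dim] = index.numel()
    return torch.sparse_coo_tensor(res_indices, res_values, res_sizes)
```

On small inputs this agrees with `t.to_dense().index_select(dim, index % t.size(dim))`.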
|
||||
auto ndim = self.dim();
|
||||
if (ndim == 0) {
|
||||
TORCH_CHECK_INDEX(false, "index_select() cannot be applied to a 0-dim tensor.");
|
||||
}
|
||||
if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
|
||||
TORCH_CHECK_INDEX(false, "index_select() argument index must be 1-D long-tensor.");
|
||||
}
|
||||
const auto ndim = self.dim();
|
||||
TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
|
||||
TORCH_CHECK_INDEX(
|
||||
index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
|
||||
"index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
|
||||
dim = maybe_wrap_dim(dim, ndim);
|
||||
auto size = self.size(dim);
|
||||
auto sparse_dim = self.sparse_dim();
|
||||
auto dense_dim = self.dense_dim();
|
||||
auto indices = self._indices();
|
||||
auto values = self._values();
|
||||
auto nnz = values.size(0);
|
||||
auto new_sizes = self.sizes().vec();
|
||||
new_sizes[dim] = index.size(0);
|
||||
const auto size = self.size(dim);
|
||||
const auto sparse_dim = self.sparse_dim();
|
||||
const auto dense_dim = self.dense_dim();
|
||||
const auto indices = self._indices();
|
||||
const auto values = self._values();
|
||||
const auto nnz = values.size(0);
|
||||
const auto index_len = index.size(0);
|
||||
auto res_sizes = self.sizes().vec();
|
||||
res_sizes[dim] = index_len;
|
||||
|
||||
// Equivalent to t.index_select(dim, idx), but vanilla index_select is not parallel,
|
||||
// so we use gather instead.
|
||||
// We use this method to select relevant indices/values
|
||||
// from the intersection between indices[dim] and the index.
|
||||
const auto index_select = [](const Tensor& t, int64_t dim, const Tensor& idx) -> Tensor {
|
||||
const auto idx_len = idx.numel();
|
||||
auto out_shape = t.sizes().vec();
|
||||
out_shape[dim] = idx_len;
|
||||
auto idx_shape = std::vector<int64_t>(t.dim(), 1);
|
||||
idx_shape[dim] = idx_len;
|
||||
return t.gather(dim, idx.view(idx_shape).expand(out_shape));
|
||||
};
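The lambda above emulates `index_select` with `gather` because the gather kernel is parallelized on CPU; a hedged PyTorch-level sketch of the same reshape-and-expand trick:

```python
import torch

def index_select_via_gather(t, dim, idx):
    # View idx so it broadcasts along every dimension except `dim`,
    # expand it to the output shape, and gather.
    out_shape = list(t.shape)
    out_shape[dim] = idx.numel()
    idx_shape = [1] * t.dim()
    idx_shape[dim] = idx.numel()
    return t.gather(dim, idx.view(idx_shape).expand(out_shape))

t = torch.arange(12.).reshape(3, 4)
idx = torch.tensor([2, 0, 2])
assert torch.equal(index_select_via_gather(t, 1, idx), t.index_select(1, idx))
```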
|
||||
|
||||
// If indexing into sparse dimensions
|
||||
if (dim < sparse_dim) {
|
||||
// short-circuit if index is empty
|
||||
if (!index_len) {
|
||||
auto res_indices = index_select(indices, 1, index);
|
||||
res_indices[dim] = index;
|
||||
const auto res_values = index_select(values, 0, index);
|
||||
|
||||
auto cpu_dim_indices = indices[dim].to(c10::kCPU).contiguous();
|
||||
int64_t* cpu_dim_indices_ptr = cpu_dim_indices.data_ptr<int64_t>();
|
||||
auto cpu_index = index.to(c10::kCPU).contiguous();
|
||||
int64_t* cpu_index_ptr = cpu_index.data_ptr<int64_t>();
|
||||
std::vector<int64_t> zindices;
|
||||
std::vector<int64_t> iindices;
|
||||
int64_t new_nnz = 0;
|
||||
for (const auto i : c10::irange(new_sizes[dim])) {
|
||||
int64_t idx = cpu_index_ptr[i];
|
||||
if (idx < -size || idx >= size) {
|
||||
TORCH_CHECK_INDEX(false, "index_select(): index contains ", idx, " that is out of range for tensor of size ",
|
||||
self.sizes(), " at dimension ", dim);
|
||||
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
|
||||
}
|
||||
|
||||
const auto nneg_index = [&index, index_len, &self, size, dim](void) -> Tensor {
|
||||
const auto index_contiguous = index.contiguous();
|
||||
auto nneg_index = at::empty_like(index_contiguous);
|
||||
// nneg_index = (index < 0) * (index + size) + (index >= 0) * index
|
||||
auto* ptr_index = index_contiguous.data_ptr<int64_t>();
|
||||
auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
|
||||
at::parallel_for(0, index_len, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
|
||||
const auto* src = ptr_index + start;
|
||||
auto* dst = ptr_nneg_index + start;
|
||||
for (C10_UNUSED const auto _ : c10::irange(start, end)) {
|
||||
auto idx = *src++;
|
||||
if (idx < -size || idx >= size) {
|
||||
TORCH_CHECK_INDEX(false,
|
||||
"index_select(): index contains ", idx, " that is out of range for tensor of size ",
|
||||
self.sizes(), " at dimension ", dim
|
||||
);
|
||||
}
|
||||
if (idx < 0) {
|
||||
idx += size;
|
||||
}
|
||||
*dst++ = idx;
|
||||
}
|
||||
});
|
||||
|
||||
return nneg_index;
|
||||
}();
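The `nneg_index` lambda above validates the index and wraps negative entries in parallel; a vectorized Python sketch of the same normalization (illustrative only):

```python
import torch

def normalize_index(index, size, dim, shape):
    # Reject out-of-range entries, then map negatives into [0, size).
    if ((index < -size) | (index >= size)).any():
        raise IndexError(
            f"index_select(): index out of range for tensor of size {shape} at dimension {dim}")
    return torch.where(index < 0, index + size, index)

print(normalize_index(torch.tensor([-1, 0, 2]), size=5, dim=0, shape=(5, 5)))
# tensor([4, 0, 2])
```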
|
||||
|
||||
const auto dim_indices = indices[dim].contiguous();
|
||||
|
||||
// If nnz is smaller than size, then either indices[dim] or index gets sorted,
|
||||
// then this is followed by a binary search to find intersections.
|
||||
const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple<Tensor, Tensor> {
|
||||
const auto grain_size = at::internal::GRAIN_SIZE;
|
||||
const auto n_threads_nnz = std::max<int64_t>(
|
||||
1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto n_threads_index = std::max<int64_t>(
|
||||
1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto search_in_dim_indices
|
||||
// if either dim_indices or index requires sorting, we compare
|
||||
// the cost of sort + binary search, which is comparing
|
||||
// (len(dim_indices) + len(index)) * log(len(index)) to
|
||||
// (len(dim_indices) + len(index)) * log(len(dim_indices)).
|
||||
// That simplifies to comparing len(dim_indices) to len(index).
|
||||
// Additionally, we take into consideration potential parallel
|
||||
// speedup.
|
||||
= (nnz / n_threads_nnz <= index_len / n_threads_index)
|
||||
// if self is coalesced and dim is 0, then we compare
|
||||
// index_len * log(len(dim_indices)), which is binary search into dim_indices,
|
||||
// to (index_len + len(dim_indices)) * log(index_len).
|
||||
// Additionally, we take into consideration potential parallel
|
||||
// speedup.
|
||||
|| (self.is_coalesced() && dim == 0
|
||||
&& (index_len * std::log2(nnz) / n_threads_index
|
||||
<= (nnz / n_threads_nnz + index_len) * std::log2(index_len)))
|
||||
? true : false;
|
||||
|
||||
// src is a source of indices to binary search in sorted
|
||||
Tensor sorted, sorted_idx, src;
|
||||
std::tie(sorted, sorted_idx, src) = [
|
||||
&dim_indices, &nneg_index, &self,
|
||||
search_in_dim_indices, dim, nnz
|
||||
](void) -> std::tuple<Tensor, Tensor, Tensor> {
|
||||
// sort dim_indices to binary search into it
|
||||
if (search_in_dim_indices) {
|
||||
// dim_indices is already sorted if self is coalesced and dim == 0
|
||||
if (self.is_coalesced() && dim == 0) {
|
||||
return std::make_tuple(dim_indices, at::arange(nnz, dim_indices.options()), nneg_index);
|
||||
}
|
||||
else {
|
||||
Tensor sorted_dim_indices, sorted_dim_indices_idx;
|
||||
std::tie(sorted_dim_indices, sorted_dim_indices_idx) = dim_indices.sort();
|
||||
return std::make_tuple(sorted_dim_indices, sorted_dim_indices_idx, nneg_index);
|
||||
}
|
||||
}
|
||||
// sort nneg_index to binary search into it
|
||||
else {
|
||||
Tensor sorted_nneg_index, sorted_nneg_index_idx;
|
||||
std::tie(sorted_nneg_index, sorted_nneg_index_idx) = nneg_index.sort();
|
||||
return std::make_tuple(sorted_nneg_index, sorted_nneg_index_idx, dim_indices);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto src_grain_size = at::internal::GRAIN_SIZE;
|
||||
const auto src_len = src.numel();
|
||||
const auto n_threads_src = std::max<int64_t>(
|
||||
// 1 <= n_threads_src <= std::min(ceil(src.numel() / src_grain_size), max_threads)
|
||||
1, std::min<int64_t>((src_len + src_grain_size - 1) / src_grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto chunk_size_src = (src_len + n_threads_src - 1) / n_threads_src;
|
||||
|
||||
const std::vector<int64_t> src_n_threads_shape = {
|
||||
n_threads_src, (src_len + n_threads_src - 1) / n_threads_src
|
||||
};
|
||||
|
||||
// src_int_idx and sorted_int_idx store "i" and "j" indices indicating
|
||||
// intersections such that src[src_int_idx[k]] == sorted[sorted_int_idx[k]].
|
||||
// These intersections are found with binary search and in parallel.
|
||||
auto src_int_idx = at::empty(src_n_threads_shape, src.options());
|
||||
auto sorted_int_idx = at::empty_like(src_int_idx);
|
||||
// For each element "i" from src, int_counts define how many
|
||||
// elements there are in sorted, i.e. "j" indices, corresponding
|
||||
// to "i", i.e.:
|
||||
// |{j : src_int_idx[i] == sorted_int_idx[j]}| for each i in src_int_idx.
|
||||
auto int_counts = at::zeros_like(src_int_idx);
|
||||
|
||||
// fill in src_int_idx, sorted_int_idx, int_counts
|
||||
{
|
||||
const auto sorted_len = sorted.numel();
|
||||
const auto* ptr_sorted = sorted.data_ptr<int64_t>();
|
||||
const auto* ptr_sorted_start = ptr_sorted;
|
||||
const auto* ptr_sorted_end = ptr_sorted + sorted_len;
|
||||
|
||||
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
|
||||
const auto start = tid * chunk_size_src;
|
||||
const auto end = std::min(start + chunk_size_src, src_len);
|
||||
auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
|
||||
auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
|
||||
auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
|
||||
const auto* ptr_src = src.data_ptr<int64_t>() + start;
|
||||
|
||||
for (const auto i : c10::irange(start, end)) {
|
||||
const auto src_val = *ptr_src++;
|
||||
const auto src_val_lb = std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val);
|
||||
// We cannot just use *src_val_lb != src_val because when
|
||||
// src_val_lb == ptr_sorted_end, dereferencing past-the-end value
|
||||
// is not well-defined.
|
||||
if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) {
|
||||
++ptr_tid_src_int_idx;
|
||||
++ptr_tid_sorted_int_idx;
|
||||
++ptr_tid_int_counts;
|
||||
continue;
|
||||
}
|
||||
const auto src_val_ub = std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val);
|
||||
|
||||
const int64_t count = src_val_ub - src_val_lb;
|
||||
const int64_t j = src_val_lb - ptr_sorted_start;
|
||||
|
||||
*ptr_tid_src_int_idx++ = i;
|
||||
*ptr_tid_sorted_int_idx++ = j;
|
||||
*ptr_tid_int_counts++ = count;
|
||||
}
|
||||
});
|
||||
}
|
||||
if (idx < 0) {
|
||||
idx += size;
|
||||
|
||||
const auto compressed_int_counts = int_counts.sum(-1);
|
||||
const auto res_len = compressed_int_counts.sum().item<int64_t>();
|
||||
|
||||
// Short-circuit if empty intersection
|
||||
if (!res_len) {
|
||||
auto empty_idx = at::empty({0}, src.options());
|
||||
return std::make_tuple(empty_idx, empty_idx);
|
||||
}
|
||||
for (const auto j : c10::irange(nnz)) {
|
||||
int64_t jdx = cpu_dim_indices_ptr[j];
|
||||
if (idx == jdx) {
|
||||
new_nnz++;
|
||||
iindices.push_back(i);
|
||||
zindices.push_back(j);
|
||||
|
||||
// Now that we know "i", "j" and the counts, we "unflatten"
|
||||
// them into two arrays of intersection indices such that
|
||||
// selected_src = repeat_interleave(src_int_idx, int_counts),
|
||||
// and selected_sorted is obtained as follows:
|
||||
// offsets = int_counts.cumsum(0).sub_(int_counts)
|
||||
// for ii, (j, c) in enumerate(zip(sorted_int_idx, int_counts)):
|
||||
// out_slice = slice(offsets[ii], offsets[ii] + c)
|
||||
// src_slice = slice(j, j + c)
|
||||
// selected_sorted[out_slice] = sorted_int_idx[src_slice]
|
||||
auto selected_sorted = at::empty({res_len}, sorted.options());
|
||||
auto selected_src = at::empty({res_len}, src.options());
|
||||
|
||||
// fill in selected_sorted, selected_src
|
||||
{
|
||||
auto* ptr_selected_sorted = selected_sorted.data_ptr<int64_t>();
|
||||
auto* ptr_selected_src = selected_src.data_ptr<int64_t>();
|
||||
|
||||
const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts);
|
||||
const auto* ptr_sorted_idx = sorted_idx.data_ptr<int64_t>();
|
||||
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
|
||||
const auto start = tid * chunk_size_src;
|
||||
const auto end = std::min(start + chunk_size_src, src_len);
|
||||
const auto tid_offset = thread_offsets.data_ptr<int64_t>()[tid];
|
||||
const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
|
||||
const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
|
||||
const auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
|
||||
auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
|
||||
auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
|
||||
|
||||
for (C10_UNUSED const auto _ : c10::irange(start, end)) {
|
||||
const auto count = *ptr_tid_int_counts++;
|
||||
const auto i = *ptr_tid_src_int_idx++;
|
||||
const auto j = *ptr_tid_sorted_int_idx++;
|
||||
if (!count) continue;
|
||||
|
||||
std::fill_n(ptr_tid_selected_src, count, i);
|
||||
std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted);
|
||||
|
||||
ptr_tid_selected_sorted += count;
|
||||
ptr_tid_selected_src += count;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return search_in_dim_indices
|
||||
? std::make_tuple(selected_sorted, selected_src)
|
||||
: std::make_tuple(selected_src, selected_sorted);
|
||||
};
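A hedged Python sketch of the sort-plus-binary-search strategy implemented above: sort one side, `searchsorted` the other side into it, and expand each hit by its multiplicity (the C++ additionally chunks this work across threads):

```python
import torch

def intersect_by_search(dim_indices, index):
    sorted_index, sorted_order = torch.sort(index)
    lb = torch.searchsorted(sorted_index, dim_indices)              # first match
    ub = torch.searchsorted(sorted_index, dim_indices, right=True)  # one past last match
    counts = ub - lb
    hit = counts > 0
    # positions into indices[dim]/values, one copy per matching index entry
    selected_dim_indices = torch.repeat_interleave(
        torch.arange(dim_indices.numel())[hit], counts[hit])
    # positions into the original (unsorted) `index`, i.e. output coordinates
    res_dim_indices = torch.cat(
        [sorted_order[l:l + c] for l, c in zip(lb[hit].tolist(), counts[hit].tolist())]
    ) if hit.any() else torch.empty(0, dtype=torch.long)
    return selected_dim_indices, res_dim_indices
```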
|
||||
|
||||
// Converts a 1d sorted idx to a 1d compressed idx,
|
||||
// aka crow in the CSR format. Useful to get a count table in
|
||||
// a parallelized and no-sync manner.
|
||||
// TODO: this function is equivalent to _convert_indices_from_coo_to_csr.
|
||||
// The mentioned function is not public yet.
|
||||
const auto sorted_idx_to_cidx = [](
|
||||
const Tensor& idx,
|
||||
int64_t len,
|
||||
bool run_in_parallel = true) -> Tensor {
|
||||
auto cidx = at::empty({len + 1}, idx.options());
|
||||
|
||||
const auto* ptr_idx = idx.data_ptr<int64_t>();
|
||||
auto* ptr_cidx = cidx.data_ptr<int64_t>();
|
||||
|
||||
const auto idx_len = idx.numel();
|
||||
|
||||
std::fill_n(ptr_cidx, ptr_idx[0] + 1, 0);
|
||||
std::fill_n(ptr_cidx + ptr_idx[idx_len - 1] + 1, len - ptr_idx[idx_len - 1], idx_len);
|
||||
|
||||
const auto grain_size = run_in_parallel ? at::internal::GRAIN_SIZE : idx_len;
|
||||
at::parallel_for(0, idx_len, grain_size, [&](int64_t start, int64_t end) {
|
||||
auto* ptr_curr_cidx = ptr_cidx + ptr_idx[start] + 1;
|
||||
for (int64_t i = start; i < std::min(end, idx_len - 1); ++i) {
|
||||
const auto diff = ptr_idx[i + 1] - ptr_idx[i];
|
||||
std::fill_n(ptr_curr_cidx, diff, i + 1);
|
||||
ptr_curr_cidx += diff;
|
||||
}
|
||||
});
|
||||
|
||||
return cidx;
|
||||
};
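A hedged Python equivalent of the helper above (compressed offsets from a sorted index vector, i.e. a CSR-style crow), here via `bincount` and `cumsum`:

```python
import torch

def sorted_idx_to_cidx(idx, length):
    # cidx[b + 1] - cidx[b] == number of occurrences of bin b in idx.
    counts = torch.bincount(idx, minlength=length)
    cidx = torch.zeros(length + 1, dtype=torch.long)
    cidx[1:] = torch.cumsum(counts, dim=0)
    return cidx

assert sorted_idx_to_cidx(torch.tensor([0, 0, 1, 3, 3, 3]), 5).tolist() == [0, 2, 3, 3, 6, 6]
```

Unlike the C++ fill, `bincount` does not require `idx` to be sorted, but it produces the same count table.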
|
||||
|
||||
// If nnz is (much) larger than size, then both indices[dim] and index get sorted
|
||||
// with a count sort (faster, and no huge nnz-sized chunk memory allocations).
|
||||
// The element-wise product between the count tables gives us all the intersections.
|
||||
const auto get_selected_indices_large_nnz_small_size = [&]() -> std::tuple<Tensor, Tensor> {
|
||||
const auto get_counts = [&sorted_idx_to_cidx](
|
||||
// Writes into counts (must be preallocated and zero)
|
||||
// and allows the use of external buffers.
|
||||
Tensor& counts,
|
||||
const Tensor& t,
|
||||
int64_t bins,
|
||||
bool is_sorted = false,
|
||||
bool run_in_parallel = true) -> void {
|
||||
if (is_sorted) {
|
||||
const auto cidx = sorted_idx_to_cidx(t, bins, run_in_parallel);
|
||||
at::sub_out(counts, cidx.slice(0, 1, bins + 1), cidx.slice(0, 0, bins));
|
||||
}
|
||||
else {
|
||||
auto* ptr_counts = counts.data_ptr<int64_t>();
|
||||
const auto* ptr_vals = t.data_ptr<int64_t>();
|
||||
for (C10_UNUSED const auto _ : c10::irange(t.numel())) {
|
||||
++ptr_counts[*ptr_vals++];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const auto counts_per_thread = [&get_counts, size](
|
||||
const Tensor& idx,
|
||||
bool is_sorted = false,
|
||||
int64_t grain_size = at::internal::GRAIN_SIZE
|
||||
) -> Tensor {
|
||||
const auto idx_len = idx.numel();
|
||||
// 1 <= n_threads <= min(ceil(len / grain_size), max_threads)
|
||||
const auto n_threads = std::max<int64_t>(
|
||||
1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto chunk_size = (idx_len + n_threads - 1) / n_threads;
|
||||
const auto run_in_parallel = (n_threads == 1);
|
||||
|
||||
auto counts_per_thread = at::zeros({n_threads, size}, idx.options());
|
||||
at::parallel_for(0, n_threads, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
|
||||
const auto start = tid * chunk_size;
|
||||
const auto end = std::min(start + chunk_size, idx_len);
|
||||
const auto tid_idx = idx.slice(0, start, end);
|
||||
auto tid_counts = counts_per_thread.select(0, tid);
|
||||
get_counts(tid_counts, tid_idx, /*bins=*/size,
|
||||
/*is_sorted=*/is_sorted, /*run_in_parallel=*/run_in_parallel);
|
||||
});
|
||||
|
||||
return counts_per_thread;
|
||||
};
|
||||
|
||||
auto dim_indices_counts_per_thread = counts_per_thread(
|
||||
dim_indices,
|
||||
/*is_sorted=*/self.is_coalesced() && dim == 0
|
||||
/*grain_size = at::internal::GRAIN_SIZE*/
|
||||
);
|
||||
auto dim_indices_offset_counts_per_thread = dim_indices_counts_per_thread.cumsum(0);
|
||||
|
||||
auto index_counts_per_thread = counts_per_thread(
|
||||
nneg_index,
|
||||
/*is_sorted=*/false
|
||||
/*grain_size = at::internal::GRAIN_SIZE*/
|
||||
);
|
||||
auto index_offset_counts_per_thread = index_counts_per_thread.cumsum(0);
|
||||
|
||||
const auto index_counts = index_offset_counts_per_thread.select(0, -1);
|
||||
const auto dim_indices_counts = dim_indices_offset_counts_per_thread.select(0, -1);
|
||||
const auto intersection_counts = index_counts.mul(dim_indices_counts);
|
||||
const auto res_len = intersection_counts.sum().item<int64_t>();
|
||||
// Short-circuit if empty intersection
|
||||
if (!res_len) {
|
||||
auto empty_idx = at::empty({0}, index.options());
|
||||
return std::make_tuple(empty_idx, empty_idx);
|
||||
}
|
||||
const auto intersection_offsets = intersection_counts.cumsum(0);
|
||||
|
||||
const auto search_in_dim_indices = [&]() -> bool {
|
||||
const auto grain_size = at::internal::GRAIN_SIZE;
|
||||
const auto n_threads_index = std::max<int64_t>(
|
||||
1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto n_threads_dim_indices = std::max<int64_t>(
|
||||
1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
|
||||
const auto index_max_copy_work_per_thread =
|
||||
index_counts_per_thread.mul(dim_indices_counts).sum(-1).max().item<int64_t>();
|
||||
const auto dim_indices_max_copy_work_per_thread
|
||||
= dim_indices_counts_per_thread.mul(index_counts).sum(-1).max().item<int64_t>();
|
||||
|
||||
const auto index_max_work_per_thread = index_max_copy_work_per_thread * index_len / n_threads_index;
|
||||
const auto dim_indices_max_work_per_thread = dim_indices_max_copy_work_per_thread * nnz / n_threads_dim_indices;
|
||||
return index_max_work_per_thread <= dim_indices_max_work_per_thread
|
||||
? true
|
||||
: false;
|
||||
}();
|
||||
|
||||
Tensor idx, idx_counts_per_thread, idx_offset_counts_per_thread;
|
||||
Tensor src, src_counts_per_thread, src_offset_counts_per_thread;
|
||||
std::tie(
|
||||
idx, idx_counts_per_thread, idx_offset_counts_per_thread,
|
||||
src, src_counts_per_thread, src_offset_counts_per_thread
|
||||
) = [&]() {
|
||||
return search_in_dim_indices
|
||||
? std::make_tuple(
|
||||
nneg_index, index_counts_per_thread, index_offset_counts_per_thread,
|
||||
dim_indices, dim_indices_counts_per_thread, dim_indices_offset_counts_per_thread
|
||||
)
|
||||
: std::make_tuple(
|
||||
dim_indices, dim_indices_counts_per_thread, dim_indices_counts_per_thread.cumsum(0),
|
||||
nneg_index, index_counts_per_thread, index_counts_per_thread.cumsum(0)
|
||||
);
|
||||
}();
|
||||
|
||||
const auto idx_counts = idx_offset_counts_per_thread.select(0, -1);
|
||||
const auto src_counts = src_offset_counts_per_thread.select(0, -1);
|
||||
|
||||
Tensor src_idx, src_idx_offsets;
|
||||
std::tie(src_idx, src_idx_offsets) = [&](
|
||||
int64_t grain_size = at::internal::GRAIN_SIZE
|
||||
) -> std::tuple<Tensor, Tensor> {
|
||||
const auto src_intersection_counts = src_counts.mul(idx_counts > 0);
|
||||
const auto src_intersection_offsets = src_intersection_counts.cumsum(0);
|
||||
const auto src_idx_len = src_intersection_offsets.data_ptr<int64_t>()[size - 1];
|
||||
auto src_idx = at::empty({src_idx_len}, src.options());
|
||||
|
||||
const auto* ptr_src = src.data_ptr<int64_t>();
|
||||
const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
|
||||
const auto* ptr_src_intersection_counts = src_intersection_counts.data_ptr<int64_t>();
|
||||
const auto* ptr_src_intersection_offsets = src_intersection_offsets.data_ptr<int64_t>();
|
||||
auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
|
||||
|
||||
const auto src_len = src.numel();
|
||||
const auto n_threads_src = std::max<int64_t>(
|
||||
1, std::min<int64_t>((src_len + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto chunk_size = (src_len + n_threads_src - 1) / n_threads_src;
|
||||
at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
|
||||
const auto start = tid * chunk_size;
|
||||
const auto end = std::min(start + chunk_size, src_len);
|
||||
auto* ptr_src_tid = ptr_src + start;
|
||||
const auto* ptr_src_counts_per_thread
|
||||
= src_counts_per_thread.select(0, tid).data_ptr<int64_t>();
|
||||
const auto* ptr_src_offset_counts_per_thread
|
||||
= src_offset_counts_per_thread.select(0, tid).data_ptr<int64_t>();
|
||||
auto tid_counts = at::zeros({size}, src.options());
|
||||
auto* ptr_tid_counts = tid_counts.data_ptr<int64_t>();
|
||||
|
||||
for (const auto i : c10::irange(start, end)) {
|
||||
const auto idx_val = *ptr_src_tid++;
|
||||
// skip idx value if not in the intersection
|
||||
if (!ptr_intersection_counts[idx_val]) continue;
|
||||
const auto idx_val_offset
|
||||
= ptr_src_intersection_offsets[idx_val]
|
||||
- ptr_src_intersection_counts[idx_val];
|
||||
const auto idx_val_tid_offset
|
||||
= ptr_src_offset_counts_per_thread[idx_val]
|
||||
- ptr_src_counts_per_thread[idx_val];
|
||||
auto& idx_val_local_tid_count = ptr_tid_counts[idx_val];
|
||||
ptr_src_idx[idx_val_offset + idx_val_tid_offset + idx_val_local_tid_count] = i;
|
||||
++idx_val_local_tid_count;
|
||||
}
|
||||
});
|
||||
|
||||
const auto src_idx_offsets = src_intersection_offsets.sub_(src_intersection_counts);
|
||||
|
||||
return std::make_tuple(src_idx, src_idx_offsets);
|
||||
}();
|
||||
|
||||
Tensor idx_selected, src_selected;
|
||||
std::tie(idx_selected, src_selected) = [&](
|
||||
int64_t grain_size = at::internal::GRAIN_SIZE
|
||||
) -> std::tuple<Tensor, Tensor> {
|
||||
const auto thread_offset = [&]() {
|
||||
// we do not need idx_counts_per_thread anymore,
|
||||
// so it is safe to do in-place intersection.
|
||||
auto counts_per_thread = idx_counts_per_thread.mul_(src_counts).sum(-1);
|
||||
return counts_per_thread.cumsum(0).sub_(counts_per_thread);
|
||||
}();
|
||||
const auto* ptr_thread_offset = thread_offset.data_ptr<int64_t>();
|
||||
|
||||
auto idx_selected = at::empty({res_len}, idx.options());
|
||||
auto src_selected = at::empty({res_len}, src.options());
|
||||
|
||||
const auto* ptr_idx = idx.data_ptr<int64_t>();
|
||||
const auto* ptr_src_counts = src_counts.data_ptr<int64_t>();
|
||||
const auto* ptr_intersection_counts = intersection_counts.data_ptr<int64_t>();
|
||||
const auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
|
||||
const auto* ptr_src_idx_offsets = src_idx_offsets.data_ptr<int64_t>();
|
||||
auto* ptr_idx_selected = idx_selected.data_ptr<int64_t>();
|
||||
auto* ptr_src_selected = src_selected.data_ptr<int64_t>();
|
||||
|
||||
const auto idx_len = idx.numel();
|
||||
const auto n_threads_idx = std::max<int64_t>(
|
||||
1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
|
||||
);
|
||||
const auto chunk_size = (idx_len + n_threads_idx - 1) / n_threads_idx;
|
||||
at::parallel_for(0, n_threads_idx, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
|
||||
const auto start = tid * chunk_size;
|
||||
const auto end = std::min(start + chunk_size, idx_len);
|
||||
const auto tid_offset = ptr_thread_offset[tid];
|
||||
const auto* ptr_idx_tid = ptr_idx + start;
|
||||
auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset;
|
||||
auto* ptr_src_selected_tid = ptr_src_selected + tid_offset;
|
||||
|
||||
for (const auto i : c10::irange(start, end)) {
|
||||
const auto idx_val = *ptr_idx_tid++;
|
||||
// skip if idx_val is not in the intersection
|
||||
if (!ptr_intersection_counts[idx_val]) continue;
|
||||
const auto count = ptr_src_counts[idx_val];
|
||||
const auto j = ptr_src_idx_offsets[idx_val];
|
||||
std::fill_n(ptr_idx_selected_tid, count, i);
|
||||
std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid);
|
||||
ptr_idx_selected_tid += count;
|
||||
ptr_src_selected_tid += count;
|
||||
}
|
||||
});
|
||||
|
||||
return std::make_tuple(idx_selected, src_selected);
|
||||
}();
|
||||
|
||||
return search_in_dim_indices
|
||||
? std::make_tuple(src_selected, idx_selected)
|
||||
: std::make_tuple(idx_selected, src_selected);
|
||||
};
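A hedged Python sketch of the count-sort strategy above (used when nnz is much larger than size): per-value counts of both sides are multiplied to get intersection counts, and matching positions are grouped value by value (the C++ performs this scatter in parallel without synchronization):

```python
import torch

def intersect_by_counting(dim_indices, index, size):
    dim_counts = torch.bincount(dim_indices, minlength=size)
    idx_counts = torch.bincount(index, minlength=size)
    pair_counts = dim_counts * idx_counts          # pairs contributed by each value
    # group the positions of each value on both sides
    dim_order = torch.argsort(dim_indices)
    idx_order = torch.argsort(index)
    dim_starts = (torch.cumsum(dim_counts, 0) - dim_counts).tolist()
    idx_starts = (torch.cumsum(idx_counts, 0) - idx_counts).tolist()
    selected_dim, res_dim = [], []
    for v in torch.nonzero(pair_counts).flatten().tolist():
        pos_dim = dim_order[dim_starts[v]:dim_starts[v] + int(dim_counts[v])]
        pos_idx = idx_order[idx_starts[v]:idx_starts[v] + int(idx_counts[v])]
        # every occurrence of v in dim_indices pairs with every occurrence in index
        selected_dim.append(pos_dim.repeat_interleave(pos_idx.numel()))
        res_dim.append(pos_idx.repeat(pos_dim.numel()))
    empty = torch.empty(0, dtype=torch.long)
    return (torch.cat(selected_dim) if selected_dim else empty,
            torch.cat(res_dim) if res_dim else empty)
```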
|
||||
|
||||
const auto make_output = [&](
|
||||
const Tensor& selected_dim_indices,
|
||||
const Tensor& res_dim_indices) -> Tensor {
|
||||
auto res_indices = index_select(indices, 1, selected_dim_indices);
|
||||
res_indices[dim] = res_dim_indices;
|
||||
const auto res_values = index_select(values, 0, selected_dim_indices);
|
||||
|
||||
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
|
||||
};
|
||||
|
||||
// Brute-force solution for small values of nnz and index_len
|
||||
const auto get_result_small_nnz_small_index = [&]()
|
||||
-> Tensor {
|
||||
const auto dim_indices_in_inner_loop = nnz >= index_len;
|
||||
Tensor outer, inner;
|
||||
std::tie(outer, inner) = [&]() -> std::tuple<Tensor, Tensor> {
|
||||
if (dim_indices_in_inner_loop) {
|
||||
return std::make_tuple(nneg_index, dim_indices);
|
||||
}
|
||||
else {
|
||||
return std::make_tuple(dim_indices, nneg_index);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto* ptr_outer = outer.data_ptr<int64_t>();
|
||||
const auto* ptr_inner = inner.data_ptr<int64_t>();
|
||||
// NOTE: if very critical, replace std::vector with
|
||||
// a data structure that operates on stack up to some limit.
|
||||
auto outer_selected_idx = std::vector<int64_t>();
|
||||
auto inner_selected_idx = std::vector<int64_t>();
|
||||
int64_t res_len = 0;
|
||||
for (const auto i : c10::irange(outer.numel())) {
|
||||
for (const auto j : c10::irange(inner.numel())) {
|
||||
if (ptr_outer[i] == ptr_inner[j]) {
|
||||
++res_len;
|
||||
outer_selected_idx.push_back(i);
|
||||
inner_selected_idx.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto outer_selected_idx_tensor = at::from_blob(
|
||||
outer_selected_idx.data(), {res_len}, at::kLong
|
||||
);
|
||||
const auto inner_selected_idx_tensor = at::from_blob(
|
||||
inner_selected_idx.data(), {res_len}, at::kLong
|
||||
);
|
||||
|
||||
return dim_indices_in_inner_loop
|
||||
? make_output(inner_selected_idx_tensor, outer_selected_idx_tensor)
|
||||
: make_output(outer_selected_idx_tensor, inner_selected_idx_tensor);
|
||||
};
|
||||
|
||||
constexpr int64_t BRUTE_FORCE_SIZE_LIMIT = 2 << 14; // 16384
|
||||
// NOTE: such a condition to avoid overflows in (nnz * index_len)
|
||||
if (nnz <= BRUTE_FORCE_SIZE_LIMIT && index_len <= BRUTE_FORCE_SIZE_LIMIT
|
||||
&& (nnz * index_len) <= BRUTE_FORCE_SIZE_LIMIT) {
|
||||
return get_result_small_nnz_small_index();
|
||||
}
|
||||
auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
|
||||
auto new_indices = indices.index_select(1, zIndices);
|
||||
new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
|
||||
auto new_values = values.index_select(0, zIndices);
|
||||
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());
|
||||
else {
|
||||
Tensor selected_dim_indices;
|
||||
Tensor res_dim_indices;
|
||||
|
||||
} else {
|
||||
// A more precise decision could be of the form:
|
||||
// `nnz < C(nnz, size) * size`, but it requires heavy benchmarking.
|
||||
// We choose `nnz < size`, which measures theoretical complexity
|
||||
// and does not rely on runtime performance.
|
||||
// TODO: perform this analysis and find better C(nnz, size).
|
||||
if (nnz <= size) {
|
||||
std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_small_nnz_large_size();
|
||||
}
|
||||
else {
|
||||
std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_large_nnz_small_size();
|
||||
}
|
||||
|
||||
auto vsize = values.sizes().vec();
|
||||
vsize[dim + 1 - sparse_dim] = index.size(0);
|
||||
auto new_values = at::empty(vsize, values.options());
|
||||
for (const auto k : c10::irange(nnz)) {
|
||||
new_values[k] = values[k].index_select(dim - sparse_dim, index);
|
||||
return make_output(selected_dim_indices, res_dim_indices);
|
||||
}
|
||||
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());
|
||||
|
||||
}
|
||||
// If indexing into dense dimensions
|
||||
else {
|
||||
// It is sufficient to just perform `index_select` on values
|
||||
// if `dim` refers to dense dimensions.
|
||||
const auto res_values = index_select(values, dim - sparse_dim + 1, index);
|
||||
|
||||
return _sparse_coo_tensor_with_dims_and_tensors(
|
||||
sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
|
||||
}
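When `dim` lands in the dense dimensions, only `values` is touched (offset by one for the leading nnz dimension); a small hedged check of that equivalence:

```python
import torch

# One sparse dim + one dense dim: each non-zero "row" carries a length-4 vector.
indices = torch.tensor([[0, 2]])
values = torch.arange(8.).reshape(2, 4)
t = torch.sparse_coo_tensor(indices, values, size=(3, 4))

idx = torch.tensor([3, 1])
dim = 1  # a dense dimension, since sparse_dim == 1
res = t.index_select(dim, idx)

# Equivalent to selecting inside values at dim - sparse_dim + 1.
expected_values = values.index_select(dim - t.sparse_dim() + 1, idx)
expected = torch.sparse_coo_tensor(indices, expected_values, size=(3, 2))
assert torch.equal(res.to_dense(), expected.to_dense())
```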
|
||||
}
|
||||
|
||||
Tensor index_select_sparse_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
|
||||
auto res = index_select_sparse_cpu(self.to(at::kCPU), dim, index.to(at::kCPU));
|
||||
return res.to(self.device());
|
||||
}
|
||||
|
||||
Tensor slice(
|
||||
|
||||
@ -7056,8 +7056,8 @@
|
||||
QuantizedCPU: index_select_quantized_cpu_
|
||||
CUDA: index_select_cuda
|
||||
QuantizedCUDA: index_select_quantized_cuda
|
||||
SparseCPU: index_select_sparse
|
||||
SparseCUDA: index_select_sparse
|
||||
SparseCPU: index_select_sparse_cpu
|
||||
SparseCUDA: index_select_sparse_cuda
|
||||
|
||||
- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
|
||||
@ -157,7 +157,7 @@ class TestSparse(TestCase):
|
||||
self.assertEqual(i, x._indices())
|
||||
self.assertEqual(v, x._values())
|
||||
self.assertEqual(x.ndimension(), len(with_size))
|
||||
self.assertEqual(x.coalesce()._nnz(), nnz)
|
||||
self.assertEqual(x.coalesce()._nnz(), nnz if x.is_coalesced() else nnz // 2)
|
||||
self.assertEqual(list(x.size()), with_size)
|
||||
|
||||
# Test .indices() and .values()
|
||||
@ -999,6 +999,105 @@ class TestSparse(TestCase):
|
||||
test_shape(len(sizes) // 2, 10, sizes, d, index)
|
||||
test_shape(len(sizes), 10, sizes, d, index)
|
||||
|
||||
def _test_index_select_exhaustive_index(self, sizes, dims, device, dtype, coalesced):
|
||||
t = make_tensor(sizes, dtype=dtype, device=device)
|
||||
t_sparse = t.to_sparse().coalesce() if coalesced else t.to_sparse()
|
||||
t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
|
||||
t_small = t_small_sparse.to_dense()
|
||||
for d in dims:
|
||||
# NOTE: indices are negative
|
||||
idx_dim_d_range = list(range(-sizes[d], 0))
|
||||
for idx_len in range(sizes[d], sizes[d] + 1):
|
||||
# creates all possible valid indices into dim d of length idx_len
|
||||
for idx in itertools.product(*itertools.repeat(idx_dim_d_range, idx_len)):
|
||||
t_idx = torch.tensor(idx, dtype=torch.long, device=device)
|
||||
|
||||
# NOTE: index_select for dense does not support negative indices,
|
||||
# hence + sizes[d]. See https://github.com/pytorch/pytorch/issues/76347
|
||||
|
||||
# tests the nnz > sizes[d] branch
|
||||
dense_result = t.index_select(d, t_idx + sizes[d])
|
||||
sparse_result = t_sparse.index_select(d, t_idx)
|
||||
self.assertEqual(dense_result, sparse_result)
|
||||
|
||||
# tests the nnz <= sizes[d] branch
|
||||
small_dense_result = t_small.index_select(d, t_idx + sizes[d])
|
||||
small_sparse_result = t_small_sparse.index_select(d, t_idx)
|
||||
self.assertEqual(small_dense_result, small_sparse_result)
|
||||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_index_select_exhaustive_index_small(self, device, dtype, coalesced):
|
||||
# will trigger brute-force algo
|
||||
self._test_index_select_exhaustive_index((3, 3, 4), range(3), device, dtype, coalesced)
|
||||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_index_select_exhaustive_index_large(self, device, dtype, coalesced):
|
||||
# will trigger more sophisticated algos
|
||||
self._test_index_select_exhaustive_index((100, 50, 3, 3), (2, 3), device, dtype, coalesced)
|
||||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_index_select_empty_and_non_contiguous_index(self, device, dtype, coalesced):
|
||||
# empty index
|
||||
idx_empty = torch.tensor([], dtype=torch.long, device=device)
|
||||
t = make_tensor((5, 5), dtype=dtype, device=device)
|
||||
res_dense = t.index_select(0, idx_empty)
|
||||
res_sparse = t.to_sparse().index_select(0, idx_empty)
|
||||
self.assertEqual(res_dense, res_sparse)
|
||||
|
||||
# non-contiguous index
|
||||
idx = torch.randint(low=0, high=5, size=(10, 2), device=device)[:, 0]
|
||||
|
||||
def run_test(sizes):
|
||||
# case nnz > size[d]
|
||||
t = make_tensor(sizes, dtype=dtype, device=device)
|
||||
res_dense = t.index_select(0, idx)
|
||||
res_sparse = t.to_sparse().index_select(0, idx)
|
||||
self.assertEqual(res_dense, res_sparse)
|
||||
|
||||
# case nnz <= size[d]
|
||||
t_small_sparse, _, _ = self._gen_sparse(len(sizes), 2, sizes, dtype, device, coalesced)
|
||||
res_sparse = t_small_sparse.index_select(0, idx)
|
||||
res_dense = t_small_sparse.to_dense().index_select(0, idx)
|
||||
self.assertEqual(res_dense, res_sparse)
|
||||
|
||||
# brute-force
|
||||
run_test((10, 10))
|
||||
# more sophisticated algos
|
||||
run_test((10, 100, 100))
|
||||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_index_select_parallelization(self, device, dtype, coalesced):
|
||||
"""
|
||||
Test with sizes that will trigger parallelization (i.e. with sizes
|
||||
that are >= at::internal::GRAIN_SIZE)
|
||||
"""
|
||||
def run_test(nnz, size):
|
||||
t_sparse, _, _ = self._gen_sparse(1, nnz, (size,), dtype, device, coalesced)
|
||||
t_dense = t_sparse.to_dense()
|
||||
|
||||
# idx_small to (sort) and (binary) search into t_sparse
|
||||
idx_small = torch.randint(size, (nnz // 2,), device=device)
|
||||
# idx_large to (sort) and (binary) search into idx_large
|
||||
# NOTE: when coalesced=True, the (binary) search will be
|
||||
# done over t_sparse anyway, as it is already sorted.
|
||||
idx_large = torch.randint(size, (nnz * 2,), device=device)
|
||||
for idx in (idx_small, idx_large):
|
||||
res_dense = t_dense.index_select(0, idx)
|
||||
res_sparse = t_sparse.index_select(0, idx)
|
||||
self.assertEqual(res_dense, res_sparse)
|
||||
|
||||
# NOTE: GRAIN_SIZE = 32768
|
||||
# case nnz <= size[d]
|
||||
tlen = 70000 # > 2 * GRAIN_SIZE
|
||||
run_test(tlen, tlen)
|
||||
|
||||
# case nnz > size[d]
|
||||
run_test(tlen, tlen // 2)
|
||||
|
||||
@onlyCPU
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
|
||||
@ -2087,8 +2087,9 @@ class TestCase(expecttest.TestCase):
|
||||
i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
|
||||
i = i.to(torch.long)
|
||||
if is_uncoalesced:
|
||||
v = torch.cat([v, torch.randn_like(v)], 0)
|
||||
i = torch.cat([i, i], 1)
|
||||
i1 = i[:, :(nnz // 2), ...]
|
||||
i2 = i[:, :((nnz + 1) // 2), ...]
|
||||
i = torch.cat([i1, i2], 1)
|
||||
x = torch.sparse_coo_tensor(i, v, torch.Size(size), dtype=dtype, device=device)
|
||||
|
||||
if not is_uncoalesced:
|
||||
|
||||