Step 1: Secretly add return_counts to unique, and refactor unique_dim for performance (#18648)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18648
ghimport-source-id: 1cf4a8fe91492621e02217f38cae5d7e0699fb05

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18661 Step 7: remove _unique
* #18655 Step 6: Rename _unique2 to unique and add int? dim
* #18654 Step 5: remove _unique_dim in favor of unique_dim
* #18651 Step 4: add support for unique with dim=None
* #18650 Step 3: Add support for return_counts to torch.unique for dim not None
* #18649 Step 2: Rename _unique_dim2_temporary_will_remove_soon to unique_dim
* **#18648 Step 1: Secretly add return_counts to unique, and refactor unique_dim for performance**

`unique` is fragile: I previously tried to change it in #18391 and #17097, and although both passed OSS tests, they were eventually reverted due to internal failures. My refactoring of unique in #18459 was based on #18391, so after #18391 was reverted I could no longer continue with #18459. To keep working on #18459, #18391, and #17097 without worrying about internal failures, I am proposing the following steps for improving `unique` and `unique_dim`. soumith: please take this; there is no need to put #18391 back.

The motivation is to move forward as far as possible without causing any internal failures, so I have divided the work into steps sorted from low to high probability of internal failure. (I don't know what the internal failure was, so I have to guess.) Let's merge this PR stack one step at a time until we encounter an internal failure.

Step 1: Create two new ATen operators, `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`, and keep `_unique` and `_unique_dim` unchanged. The temporary operators share the same backend as `_unique` and `_unique_dim`; the only difference is that the temporary ones support `return_counts` while `_unique` and `_unique_dim` do not. Step 1 is mostly #18391 + #18459, with the cuda8 errors fixed. At this point there is no user-visible API change, so no docs are updated. `torch.unique` does not support `return_counts` yet; `return_counts` is tested through the newly added temporary operators. Since this step only adds two new ATen operators, there shouldn't be any internal failure. A minimal usage sketch is shown below.
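For concreteness, here is a minimal sketch of the two temporary operators as declared in this PR. The input tensor is only an illustrative example; each call returns an `(output, inverse_indices, counts)` tuple.

```python
import torch

# Illustrative input only; any tensor with repeated values works.
a = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3])

# Flattened unique with counts: returns (output, inverse_indices, counts).
out, inv, counts = torch._unique2_temporary_will_remove_soon(
    a, sorted=True, return_inverse=True, return_counts=True)

# Dim-wise unique with counts, here over the rows of a 2-D view.
b = a.view(4, 2)
out_d, inv_d, counts_d = torch._unique_dim2_temporary_will_remove_soon(
    b, dim=0, sorted=True, return_inverse=True, return_counts=True)
```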

Step 2: Rename `_unique_dim2_temporary_will_remove_soon` to `unique_dim`. This should not cause internal failures either, because no existing operator is changed. The only thing to take care of is deleting `unique_dim` from the Python side, because we don't want users to call it directly. At this point, C++ users have `return_counts` support through `unique_dim`.

Step 3: Update the docs of `torch.unique` and use `unique_dim` inside `torch.unique` to support `return_counts` when `dim` is given. The docs should state that `torch.unique` with `dim=None` does not support `return_counts` yet. This might cause an internal failure.

Step 4: Rename `_unique2_temporary_will_remove_soon` to `_unique2` and use `_unique2` inside `torch.unique` to support `return_counts`. Update the docs to say that `torch.unique` with `dim=None` now supports `return_counts`. This might cause an internal failure. A sketch of the intended public API after Steps 3 and 4 follows.
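For reference, here is a rough sketch of what the public `torch.unique` API is intended to look like once Steps 3 and 4 land. This is a plan, not something enabled by this PR, so treat the exact call patterns as tentative.

```python
import torch

x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3])

# Planned after Step 4: return_counts on the flattened (dim=None) path.
values, counts = torch.unique(x, sorted=True, return_counts=True)

# Planned after Step 3: return_counts together with an explicit dim.
y = x.view(4, 2)
values_d, inverse_d, counts_d = torch.unique(
    y, dim=0, sorted=True, return_inverse=True, return_counts=True)
```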

Step 5: Remove `_unique_dim`. This might cause an internal failure.

Step 6: Rename `_unique2` to `unique` and add an optional `dim` argument so that it matches the signature of Python's `torch.unique`. Inside `torch.unique`, use `unique` and get rid of `unique_dim`. Unbind `unique_dim` from Python entirely at codegen time. This is likely to cause an internal failure.

Step 7: Remove `_unique`. This is very likely to cause an internal failure.

This PR
======

This PR implements step 1. It creates two new ATen operators, `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`, implements `return_counts` inside them, and refactors the implementation for performance.

Please review, ngimel VitalyFedyunin. The changes are mostly copied from #18391 and #18459, so the review should be easy.

Below is a benchmark on a tensor of shape `torch.Size([15320, 2])`:
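The benchmark snippets reference a CUDA tensor `a` whose construction is not shown in the PR; a plausible setup (an assumption, not taken from the original) is:

```python
import torch

# Assumption: the exact contents of `a` are not given; any CUDA tensor of
# shape [15320, 2] with repeated rows exercises the same code paths.
a = torch.randint(0, 100, (15320, 2), device='cuda')
```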

Before
---------

```python
print(torch.__version__)
%timeit a.unique(dim=0, sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(dim=0, sorted=True, return_inverse=True); torch.cuda.synchronize()
```

```
1.0.1
192 µs ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
548 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

```python
print(torch.__version__)
%timeit a.unique(sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(sorted=True, return_inverse=True); torch.cuda.synchronize()
```

```
1.0.1
226 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
302 µs ± 7.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

After
-------

```python
print(torch.__version__)
%timeit a.unique(dim=0, sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(dim=0, sorted=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted=True, return_inverse=False, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
1.1.0a0+83ab8ac
190 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
237 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
219 µs ± 2.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
263 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

```python
print(torch.__version__)
%timeit a.unique(sorted=True, return_inverse=False); torch.cuda.synchronize()
%timeit a.unique(sorted=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted=True, return_inverse=False, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```

```
1.1.0a0+83ab8ac
232 µs ± 2.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
301 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
264 µs ± 7.67 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
339 µs ± 9.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

Differential Revision: D14730905

fbshipit-source-id: 10026b4b98628a8565cc28a13317d29adf1225cc
Author: Vitaly Fedyunin, 2019-04-03 15:26:34 -07:00 (committed by Facebook Github Bot)
Parent: 7ae0263e1b
Commit: 773ce4fbd0
4 changed files with 410 additions and 186 deletions

Changed file 1 of 4: ATen CPU `unique` implementation

@@ -14,10 +14,11 @@ namespace native{
namespace {
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_cpu_template(
std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
const Tensor& self,
const bool sorted,
const bool return_inverse) {
const bool return_inverse,
const bool return_counts) {
const Tensor& input = self.contiguous();
const scalar_t* input_data = input.data<scalar_t>();
std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
@@ -33,7 +34,8 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
}
Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
if (return_inverse) {
Tensor counts = at::empty({0}, self.options().dtype(kLong));
if (return_inverse || return_counts) {
inverse_indices.resize_(input.sizes());
int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
std::unordered_map<scalar_t, int64_t> inverse_map;
@@ -44,21 +46,29 @@ std::tuple<Tensor, Tensor> _unique_cpu_template(
for (int i = 0; i < input.numel(); ++i) {
inverse_indices_data[i] = inverse_map[input_data[i]];
}
if (return_counts) {
counts.resize_(output.sizes());
counts.fill_(0);
for (int i = 0; i < input.numel(); ++i) {
counts[inverse_map[input_data[i]]] += 1;
}
}
}
return std::make_tuple(output, inverse_indices);
return std::make_tuple(output, inverse_indices, counts);
}
template<class ForwardIt>
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
if (first == last) {
return last;
}
// save to calculate distance to iterators
ForwardIt begin = first;
// set first inverse index
// set first inverse index and count
inverse_indices_vec[indices[0]] = 0;
counts[0] += 1;
ForwardIt result = first;
while (++first != last) {
@@ -68,16 +78,18 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
int64_t idx_result = std::distance(begin, result);
int64_t idx_first = std::distance(begin, first);
inverse_indices_vec[indices[idx_first]] = idx_result;
counts[idx_result] += 1;
}
return ++result;
}
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {
const bool return_inverse,
const bool return_counts) {
// reshape tensor as [dim, -1]
Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
@@ -109,10 +121,12 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
}
Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
auto last = _unique_dim_cpu_impl(
input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
input_unbind.erase(last, input_unbind.end());
counts = at::narrow(counts, 0, 0, input_unbind.size());
// reshape back
auto output = at::stack(input_unbind, 0);
@@ -121,14 +135,24 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
output = output.view(new_sizes);
output = output.transpose(0, dim);
return std::make_tuple(output, inverse_indices);
return std::make_tuple(output, inverse_indices, counts);
}
} // namespace
std::tuple<Tensor, Tensor>
_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
Tensor output, inverse;
std::tie(output, inverse, std::ignore) = _unique_cpu_template<scalar_t>(self, sorted, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique2_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
});
}
@@ -136,7 +160,17 @@ std::tuple<Tensor, Tensor>
_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
Tensor output, inverse;
std::tie(output, inverse, std::ignore) = _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique_dim2_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
});
}

Changed file 2 of 4: ATen CUDA `unique` implementation

@@ -5,6 +5,7 @@
#include <thrust/execution_policy.h>
#include <tuple>
#include <iterator>
#include <thrust/unique.h>
#include <thrust/sort.h>
#include <thrust/scan.h>
@@ -15,148 +16,213 @@ namespace native{
namespace {
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_cuda_template(
const Tensor& self,
const bool return_inverse) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
template <
typename policy_t, typename scalar_t,
typename equal_t, typename not_equal_t
>
std::tuple<Tensor, Tensor, int64_t> compute_unique(
const policy_t &policy,
scalar_t *data,
int64_t num_inp,
const Tensor &sorted_indices,
const bool return_inverse,
const bool return_counts,
TensorOptions options,
equal_t equal,
not_equal_t not_equal
) {
const Tensor& input = self.contiguous();
int64_t num_inp = input.numel();
const scalar_t* input_data = input.data<scalar_t>();
//sort & unique
Tensor output = input.clone();
output = output.view(-1);
scalar_t* output_data = output.data<scalar_t>();
Tensor inverse_indices;
if (!return_inverse) {
inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
thrust::sort(policy, output_data, output_data + num_inp);
} else {
Tensor sorted_indices = at::arange(0, num_inp, self.type().toScalarType(kLong));
int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
Tensor inv_loc = at::empty({num_inp}, self.type().toScalarType(kLong));
inverse_indices = at::empty({num_inp}, self.type().toScalarType(kLong));
int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();
thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
inv_loc[0] = 0;
thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
inverse_indices.resize_(input.sizes());
}
int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
output.resize_(num_out);
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor>(output, inverse_indices);
// inverse indices
Tensor inverse_indices;
if (!return_inverse) {
inverse_indices = at::empty({0}, options);
} else {
AT_CHECK(sorted_indices.defined(),
"return_inverse is set to true, but sorted_indices is undefined. Send a bug report!");
const int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
Tensor inv_loc = at::empty({num_inp}, options);
inverse_indices = at::empty({num_inp}, options);
int64_t* inv_loc_ptr = inv_loc.data<int64_t>();
int64_t* inverse_indices_ptr = inverse_indices.data<int64_t>();
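// adjacent_difference marks, for each position in the sorted data, whether it
// differs from its predecessor; inclusive_scan turns those marks into a
// per-element group index; scatter writes each group index back to the
// element's original position via sorted_indices, producing inverse_indices.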
thrust::adjacent_difference(policy, data, data + num_inp, inv_loc_ptr, not_equal);
inv_loc[0] = 0;
thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
}
template <typename scalar_t>
std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
input_flat = input_flat.contiguous().view({input_flat.size(0), -1});
scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
int64_t* indices_ptr = indices.data<int64_t>();
int64_t numel = input_flat.size(1);
// sort indices using data
thrust::sort(policy, indices_ptr, indices_ptr + indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_flat_ptr[i + a * numel];
scalar_t rhs = input_flat_ptr[i + b * numel];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
});
Tensor input_sorted = input_flat.index_select(0, indices);
// get unique tensors
scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < numel; ++i) {
scalar_t lhs = input_sorted_ptr[i + a * numel];
scalar_t rhs = input_sorted_ptr[i + b * numel];
if (lhs != rhs) {
return false;
}
}
return true;
});
input_sorted_indices.resize_(last - input_sorted_indices_ptr);
Tensor output = input_sorted.index_select(0, input_sorted_indices);
// reshape back
auto new_sizes = std::vector<int64_t>(orig_sizes);
new_sizes[0] = -1;
output = output.view(new_sizes);
output = output.transpose(0, dim);
// calculate inverse indices
Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
if (return_inverse) {
int64_t size = self.size(dim);
inverse_indices.resize_(size);
Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
mask[0] = 1;
for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
if (!at::equal(input_sorted[i], input_sorted[i+1])) {
mask[i+1] = 1;
} else {
mask[i+1] = 0;
}
}
Tensor imask = at::cumsum(mask, 0) - 1;
for (int i = 0; i < indices.size(0); ++i) {
inverse_indices[indices[i]] = imask[i];
}
}
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor>(output, inverse_indices);
// unique and count
Tensor counts = at::empty({0}, options);
int64_t num_out;
if (!return_counts) {
num_out = thrust::unique(policy, data, data + num_inp, equal) - data;
} else {
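// unique_by_key keeps, for each run of equal elements, the index of its first
// occurrence in `range`; after capping range[num_out] at num_inp, the adjacent
// differences of those indices are the run lengths, i.e. the counts.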
Tensor range = at::arange(0, num_inp + 1, options);
int64_t *range_ptr = range.data<int64_t>();
num_out = thrust::unique_by_key(policy, data, data + num_inp, range_ptr, equal).first - data;
range[num_out] = num_inp;
counts.resize_(num_out);
int64_t* counts_ptr = counts.data<int64_t>();
thrust::adjacent_difference(policy, range_ptr + 1, range_ptr + num_out + 1, counts_ptr);
}
THCudaCheck(cudaGetLastError());
return std::tuple<Tensor, Tensor, int64_t>(inverse_indices, counts, num_out);
}
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_cuda_template(
const Tensor& self,
const bool return_inverse,
const bool return_counts
) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
auto options = self.options().dtype(kLong);
Tensor output = self.clone().reshape(-1);
int64_t num_inp = output.numel();
scalar_t* output_data = output.data<scalar_t>();
Tensor sorted_indices;
if (!return_inverse) {
thrust::sort(policy, output_data, output_data + num_inp);
} else {
sorted_indices = at::arange(0, num_inp, options);
int64_t *sorted_indices_ptr = sorted_indices.data<int64_t>();
thrust::sort_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr);
}
Tensor inverse_indices, counts;
int64_t num_out;
std::tie(inverse_indices, counts, num_out) = compute_unique(
policy, output_data, num_inp, sorted_indices,
return_inverse, return_counts, options,
thrust::equal_to<scalar_t>(),
thrust::not_equal_to<scalar_t>()
);
output.resize_(num_out);
if (return_inverse) {
inverse_indices.resize_(self.sizes());
}
return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
}
template <typename scalar_t>
std::tuple<Tensor, Tensor, Tensor> unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
const bool return_inverse,
const bool return_counts
) {
/**
* The idea for implementing this is basically the same as unique.
* For unique_dim, we are taking the unique with respect to a index
* tensor, but during the processes, we override the compare and equal
* operator by checking the data underlying it instead. After the
algorithm, we use index_select to map the resulting indices
* to the result on the actual data.
*/
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
auto policy = thrust::cuda::par(allocator).on(stream);
int64_t num_inp = self.size(dim);
auto options = self.options().dtype(kLong);
Tensor input_flat = self.transpose(dim, 0).contiguous().view({num_inp, -1});
int64_t n = input_flat.size(1);
scalar_t *input_flat_ptr = input_flat.data<scalar_t>();
Tensor indices = at::arange(0, num_inp, options);
int64_t *indices_data = indices.data<int64_t>();
thrust::sort(policy, indices_data, indices_data + num_inp,
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < n; ++i) {
scalar_t lhs = input_flat_ptr[i + a * n];
scalar_t rhs = input_flat_ptr[i + b * n];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
return false;
}
}
return false;
}
);
Tensor inverse_indices, counts;
int64_t num_out;
std::tie(inverse_indices, counts, num_out) = compute_unique(
policy, indices_data, num_inp, indices,
return_inverse, return_counts, options,
[=] __device__ (int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < n; ++i) {
scalar_t lhs = input_flat_ptr[i + a * n];
scalar_t rhs = input_flat_ptr[i + b * n];
if (lhs != rhs) {
return false;
}
}
return true;
},
[=] __device__ (int64_t a, int64_t b) -> int64_t {
for (int64_t i = 0; i < n; ++i) {
scalar_t lhs = input_flat_ptr[i + a * n];
scalar_t rhs = input_flat_ptr[i + b * n];
if (lhs != rhs) {
return 1;
}
}
return 0;
}
);
indices.resize_(num_out);
return std::tuple<Tensor, Tensor, Tensor>(self.index_select(dim, indices), inverse_indices, counts);
}
} // namespace
std::tuple<Tensor, Tensor>
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
return _unique_cuda_template<scalar_t>(self, return_inverse);
Tensor output, inverse;
std::tie(output, inverse, std::ignore) = unique_cuda_template<scalar_t>(self, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique2_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
return unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
});
}
std::tuple<Tensor, Tensor>
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
Tensor output, inverse;
std::tie(output, inverse, std::ignore) = unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, false);
return std::make_tuple(output, inverse);
});
}
std::tuple<Tensor, Tensor, Tensor>
_unique_dim2_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
return unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
});
}

Changed file 3 of 4: ATen operator declarations (`native_functions.yaml`)

@@ -2357,6 +2357,24 @@
CPU: _unique_dim_cpu
CUDA: _unique_dim_cuda
# _unique and _unique_dim are fragile and modifying them easily cause internal break
# below two operators are a temporary hack for adding return_counts support
# Please don't rely on these two operators, they will be removed soon
- func: _unique2_temporary_will_remove_soon(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
CPU: _unique2_cpu
CUDA: _unique2_cuda
- func: _unique_dim2_temporary_will_remove_soon(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
CPU: _unique_dim2_cpu
CUDA: _unique_dim2_cuda
- func: _unsafe_view(Tensor self, int[] size) -> Tensor
matches_jit_signature: True

Changed file 4 of 4: Python tests for `unique` / `unique_dim`

@@ -10506,57 +10506,87 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
torch.set_flush_denormal(False)
def test_unique(self):
x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
def run_test(device):
x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], device=device)
expected_unique = torch.tensor([1, 2, 3, 5, 8], device=device)
expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)
x_unique = torch.unique(x)
self.assertEqual(
expected_unique.tolist(), sorted(x_unique.tolist()))
x_unique = torch.unique(x)
self.assertEqual(
expected_unique.tolist(), sorted(x_unique.tolist()))
x_unique, x_inverse = x.unique(return_inverse=True)
self.assertEqual(
expected_unique.tolist(), sorted(x_unique.tolist()))
self.assertEqual(expected_inverse.numel(), x_inverse.numel())
x_unique, x_inverse = x.unique(return_inverse=True)
self.assertEqual(
expected_unique.tolist(), sorted(x_unique.tolist()))
self.assertEqual(expected_inverse.numel(), x_inverse.numel())
x_unique = x.unique(sorted=True)
self.assertEqual(expected_unique, x_unique)
x_unique = x.unique(sorted=True)
self.assertEqual(expected_unique, x_unique)
x_unique, x_inverse = torch.unique(
x, sorted=True, return_inverse=True)
self.assertEqual(expected_unique, x_unique)
self.assertEqual(expected_inverse, x_inverse)
x_unique, _, x_counts = torch._unique2_temporary_will_remove_soon(x, sorted=True, return_counts=True)
self.assertEqual(expected_counts, x_counts)
# Tests per-element unique on a higher rank tensor.
y = x.view(2, 2, 2)
y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
self.assertEqual(expected_unique, y_unique)
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
x_unique, x_inverse = torch.unique(
x, sorted=True, return_inverse=True)
self.assertEqual(expected_unique, x_unique)
self.assertEqual(expected_inverse, x_inverse)
# Tests unique on other types.
int_unique, int_inverse = torch.unique(
torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
self.assertEqual(torch.IntTensor([1, 2]), int_unique)
self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
x_unique, x_inverse, x_counts = torch._unique2_temporary_will_remove_soon(
x, sorted=True, return_inverse=True, return_counts=True)
self.assertEqual(expected_unique, x_unique)
self.assertEqual(expected_inverse, x_inverse)
self.assertEqual(expected_counts, x_counts)
double_unique, double_inverse = torch.unique(
torch.DoubleTensor([2., 1.5, 2.1, 2.]),
sorted=True,
return_inverse=True,
)
self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
# Tests per-element unique on a higher rank tensor.
y = x.view(2, 2, 2)
y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
self.assertEqual(expected_unique, y_unique)
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
byte_unique, byte_inverse = torch.unique(
torch.ByteTensor([133, 7, 7, 7, 42, 128]),
sorted=True,
return_inverse=True,
)
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
y_unique, y_inverse, y_counts = torch._unique2_temporary_will_remove_soon(
y, sorted=True, return_inverse=True, return_counts=True)
self.assertEqual(expected_unique, y_unique)
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
self.assertEqual(expected_counts, y_counts)
# Tests unique on other types.
int_unique, int_inverse, int_counts = torch._unique2_temporary_will_remove_soon(
torch.tensor([2, 1, 2], dtype=torch.int, device=device),
sorted=True,
return_inverse=True,
return_counts=True
)
self.assertEqual(torch.tensor([1, 2], dtype=torch.int, device=device), int_unique)
self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse)
self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts)
double_unique, double_inverse, double_counts = torch._unique2_temporary_will_remove_soon(
torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device),
sorted=True,
return_inverse=True,
return_counts=True
)
self.assertEqual(torch.tensor([1.5, 2., 2.1], dtype=torch.double, device=device), double_unique)
self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse)
self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts)
byte_unique, byte_inverse, byte_counts = torch._unique2_temporary_will_remove_soon(
torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device),
sorted=True,
return_inverse=True,
return_counts=True
)
self.assertEqual(torch.tensor([7, 42, 128, 133], dtype=torch.uint8, device=device), byte_unique)
self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse)
self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts)
run_test(torch.device('cpu'))
if torch.cuda.is_available():
run_test(torch.device('cuda'))
def test_unique_dim(self):
def run_test(dtype=torch.float):
def run_test(dtype=torch.float, device=torch.device('cpu')):
x = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
@@ -10564,19 +10594,27 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
[0., 1.]]],
dtype=dtype,
device=device)
expected_unique_dim0 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
[0., 1.]]],
dtype=dtype,
device=device)
expected_inverse_dim0 = torch.tensor([0, 0])
expected_counts_dim0 = torch.tensor([2])
expected_unique_dim1 = torch.tensor([[[0., 1.],
[1., 1.],
[2., 1.]],
[[0., 1.],
[1., 1.],
[2., 1.]]], dtype=dtype)
[2., 1.]]],
dtype=dtype,
device=device)
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
expected_counts_dim1 = torch.tensor([2, 1, 1])
expected_unique_dim2 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
@@ -10584,37 +10622,105 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[[1., 1.],
[0., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
[0., 1.]]],
dtype=dtype,
device=device)
expected_inverse_dim2 = torch.tensor([0, 1])
expected_counts_dim2 = torch.tensor([1, 1])
# dim0
x_unique = torch.unique(x, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
x_unique, x_inverse = torch.unique(
x,
return_inverse=True,
dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_inverse_dim0, x_inverse)
x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=False,
return_counts=True,
dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_counts_dim0, x_counts)
x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=True,
return_counts=True,
dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_inverse_dim0, x_inverse)
self.assertEqual(expected_counts_dim0, x_counts)
# dim1
x_unique = torch.unique(x, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
x_unique, x_inverse = torch.unique(
x,
return_inverse=True,
dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_inverse_dim1, x_inverse)
x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=False,
return_counts=True,
dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_counts_dim1, x_counts)
x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=True,
return_counts=True,
dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_inverse_dim1, x_inverse)
self.assertEqual(expected_counts_dim1, x_counts)
# dim2
x_unique = torch.unique(x, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
x_unique, x_inverse = torch.unique(
x,
return_inverse=True,
dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_inverse_dim2, x_inverse)
x_unique, _, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=False,
return_counts=True,
dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_counts_dim2, x_counts)
x_unique, x_inverse, x_counts = torch._unique_dim2_temporary_will_remove_soon(
x,
return_inverse=True,
return_counts=True,
dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_inverse_dim2, x_inverse)
self.assertEqual(expected_counts_dim2, x_counts)
run_test(torch.float)
run_test(torch.double)
run_test(torch.long)
run_test(torch.uint8)
if torch.cuda.is_available():
run_test(torch.float, torch.device('cuda'))
run_test(torch.double, torch.device('cuda'))
run_test(torch.long, torch.device('cuda'))
run_test(torch.uint8, torch.device('cuda'))
@staticmethod
def _test_bincount(self, device):