mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D14605905: [pytorch][PR] Add return_counts to torch.unique
Differential Revision: D14605905 Original commit changeset: 555f5a12a8e2 fbshipit-source-id: c7874f5987893e956c022180a37763d88bba38db
This commit is contained in:
committed by
Facebook Github Bot
parent
bdd098c694
commit
66628f78b7
@ -14,11 +14,10 @@ namespace native{
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
|
||||
std::tuple<Tensor, Tensor> _unique_cpu_template(
|
||||
const Tensor& self,
|
||||
const bool sorted,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
const bool return_inverse) {
|
||||
const Tensor& input = self.contiguous();
|
||||
const scalar_t* input_data = input.data<scalar_t>();
|
||||
std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
|
||||
@ -34,8 +33,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
|
||||
}
|
||||
|
||||
Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
|
||||
Tensor counts = at::empty({0}, self.options().dtype(kLong));
|
||||
if (return_inverse || return_counts) {
|
||||
if (return_inverse) {
|
||||
inverse_indices.resize_(input.sizes());
|
||||
int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
|
||||
std::unordered_map<scalar_t, int64_t> inverse_map;
|
||||
@ -46,29 +44,21 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
|
||||
for (int i = 0; i < input.numel(); ++i) {
|
||||
inverse_indices_data[i] = inverse_map[input_data[i]];
|
||||
}
|
||||
if (return_counts) {
|
||||
counts.resize_(output.sizes());
|
||||
counts.fill_(0);
|
||||
for (int i = 0; i < input.numel(); ++i) {
|
||||
counts[inverse_map[input_data[i]]] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_tuple(output, inverse_indices, counts);
|
||||
return std::make_tuple(output, inverse_indices);
|
||||
}
|
||||
|
||||
template<class ForwardIt>
|
||||
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
|
||||
std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
|
||||
std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
|
||||
if (first == last) {
|
||||
return last;
|
||||
}
|
||||
// save to calculate distance to iterators
|
||||
ForwardIt begin = first;
|
||||
|
||||
// set first inverse index and count
|
||||
// set first inverse index
|
||||
inverse_indices_vec[indices[0]] = 0;
|
||||
counts[0] += 1;
|
||||
|
||||
ForwardIt result = first;
|
||||
while (++first != last) {
|
||||
@ -78,18 +68,16 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
|
||||
int64_t idx_result = std::distance(begin, result);
|
||||
int64_t idx_first = std::distance(begin, first);
|
||||
inverse_indices_vec[indices[idx_first]] = idx_result;
|
||||
counts[idx_result] += 1;
|
||||
}
|
||||
|
||||
return ++result;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
|
||||
std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
|
||||
const Tensor& self,
|
||||
const int64_t dim,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
const bool return_inverse) {
|
||||
// reshape tensor as [dim, -1]
|
||||
Tensor input_flat = self.transpose(dim, 0);
|
||||
auto orig_sizes = input_flat.sizes().vec();
|
||||
@ -121,12 +109,10 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
|
||||
}
|
||||
|
||||
Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
|
||||
Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
|
||||
std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
|
||||
auto last = _unique_dim_cpu_impl(
|
||||
input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
|
||||
input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
|
||||
input_unbind.erase(last, input_unbind.end());
|
||||
counts = at::narrow(counts, 0, 0, input_unbind.size());
|
||||
|
||||
// reshape back
|
||||
auto output = at::stack(input_unbind, 0);
|
||||
@ -135,23 +121,22 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
|
||||
output = output.view(new_sizes);
|
||||
output = output.transpose(0, dim);
|
||||
|
||||
return std::make_tuple(output, inverse_indices, counts);
|
||||
return std::make_tuple(output, inverse_indices);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
|
||||
return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
|
||||
std::tuple<Tensor, Tensor>
|
||||
_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
|
||||
return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
|
||||
});
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
|
||||
std::tuple<Tensor, Tensor>
|
||||
_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
|
||||
// The current implementation using `dim` always sorts due to unhashable tensors
|
||||
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
|
||||
return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -16,10 +16,9 @@ namespace native{
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique_cuda_template(
|
||||
std::tuple<Tensor, Tensor> _unique_cuda_template(
|
||||
const Tensor& self,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
const bool return_inverse) {
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
|
||||
@ -29,7 +28,7 @@ template <typename scalar_t>
|
||||
int64_t num_inp = input.numel();
|
||||
const scalar_t* input_data = input.data<scalar_t>();
|
||||
|
||||
//sort
|
||||
//sort & unique
|
||||
Tensor output = input.clone();
|
||||
output = output.view(-1);
|
||||
scalar_t* output_data = output.data<scalar_t>();
|
||||
@ -48,36 +47,21 @@ template <typename scalar_t>
|
||||
thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
|
||||
inv_loc[0] = 0;
|
||||
thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
|
||||
thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
|
||||
thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
|
||||
inverse_indices.resize_(input.sizes());
|
||||
}
|
||||
|
||||
// unique
|
||||
Tensor counts = at::empty({0}, self.options().dtype(kLong));
|
||||
if (!return_counts) {
|
||||
int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
|
||||
output.resize_(num_out);
|
||||
} else {
|
||||
Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong));
|
||||
int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
|
||||
int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data;
|
||||
sorted_indices[num_out] = num_inp;
|
||||
output.resize_(num_out);
|
||||
counts.resize_(num_out);
|
||||
int64_t* counts_ptr = counts.data<int64_t>();
|
||||
thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr);
|
||||
}
|
||||
int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
|
||||
output.resize_(num_out);
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
|
||||
return std::tuple<Tensor, Tensor>(output, inverse_indices);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
std::tuple<Tensor, Tensor, Tensor> _unique_dim_cuda_template(
|
||||
std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
|
||||
const Tensor& self,
|
||||
const int64_t dim,
|
||||
const bool return_inverse,
|
||||
const bool return_counts) {
|
||||
const bool return_inverse) {
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
|
||||
@ -89,7 +73,7 @@ template <typename scalar_t>
|
||||
|
||||
scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
|
||||
|
||||
Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong));
|
||||
Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
|
||||
int64_t* indices_ptr = indices.data<int64_t>();
|
||||
int64_t numel = input_flat.size(1);
|
||||
|
||||
@ -112,7 +96,7 @@ template <typename scalar_t>
|
||||
|
||||
// get unique tensors
|
||||
scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
|
||||
Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong));
|
||||
Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
|
||||
int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
|
||||
auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
|
||||
[=] __device__ (int64_t a, int64_t b) -> bool {
|
||||
@ -134,13 +118,12 @@ template <typename scalar_t>
|
||||
output = output.view(new_sizes);
|
||||
output = output.transpose(0, dim);
|
||||
|
||||
// calculate inverse indices and counts
|
||||
Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
|
||||
Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong));
|
||||
if (return_inverse || return_counts) {
|
||||
// calculate inverse indices
|
||||
Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
|
||||
if (return_inverse) {
|
||||
int64_t size = self.size(dim);
|
||||
inverse_indices.resize_(size);
|
||||
Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong));
|
||||
Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
|
||||
mask[0] = 1;
|
||||
for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
|
||||
if (!at::equal(input_sorted[i], input_sorted[i+1])) {
|
||||
@ -153,29 +136,27 @@ template <typename scalar_t>
|
||||
Tensor imask = at::cumsum(mask, 0) - 1;
|
||||
for (int i = 0; i < indices.size(0); ++i) {
|
||||
inverse_indices[indices[i]] = imask[i];
|
||||
counts[inverse_indices[indices[i]]] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
|
||||
return std::tuple<Tensor, Tensor>(output, inverse_indices);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
|
||||
std::tuple<Tensor, Tensor>
|
||||
_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
|
||||
// The current CUDA implementation of unique always sort due to the
|
||||
// lack of hashtable implementation in thrust
|
||||
return _unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
|
||||
return _unique_cuda_template<scalar_t>(self, return_inverse);
|
||||
});
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor>
|
||||
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
|
||||
std::tuple<Tensor, Tensor>
|
||||
_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
|
||||
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
|
||||
return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
|
||||
return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -2339,14 +2339,14 @@
|
||||
matches_jit_signature: True
|
||||
variants: method
|
||||
|
||||
- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
|
||||
- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
|
||||
matches_jit_signature: True
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _unique_cpu
|
||||
CUDA: _unique_cuda
|
||||
|
||||
- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
|
||||
- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
|
||||
matches_jit_signature: True
|
||||
variants: function
|
||||
dispatch:
|
||||
|
@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
||||
indices = nz.clone();
|
||||
} else {
|
||||
Tensor i = nz.narrow(0, 0, sparse_dim);
|
||||
std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1);
|
||||
std::tie(indices, std::ignore) = _unique_dim(i, 1);
|
||||
indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
|
||||
}
|
||||
|
||||
|
@ -10362,7 +10362,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
|
||||
expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
|
||||
expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
|
||||
expected_counts = torch.LongTensor([1, 3, 2, 1, 1])
|
||||
|
||||
x_unique = torch.unique(x)
|
||||
self.assertEqual(
|
||||
@ -10376,62 +10375,38 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
x_unique = x.unique(sorted=True)
|
||||
self.assertEqual(expected_unique, x_unique)
|
||||
|
||||
x_unique, x_counts = x.unique(sorted=True, return_counts=True)
|
||||
self.assertEqual(expected_counts, x_counts)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x, sorted=True, return_inverse=True)
|
||||
self.assertEqual(expected_unique, x_unique)
|
||||
self.assertEqual(expected_inverse, x_inverse)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x, sorted=True, return_inverse=True, return_counts=True)
|
||||
self.assertEqual(expected_unique, x_unique)
|
||||
self.assertEqual(expected_inverse, x_inverse)
|
||||
self.assertEqual(expected_counts, x_counts)
|
||||
|
||||
# Tests per-element unique on a higher rank tensor.
|
||||
y = x.view(2, 2, 2)
|
||||
y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
|
||||
self.assertEqual(expected_unique, y_unique)
|
||||
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
|
||||
|
||||
y_unique, y_inverse, y_counts = y.unique(
|
||||
sorted=True, return_inverse=True, return_counts=True)
|
||||
self.assertEqual(expected_unique, y_unique)
|
||||
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
|
||||
self.assertEqual(expected_counts, y_counts)
|
||||
|
||||
# Tests unique on other types.
|
||||
int_unique, int_inverse, int_counts = torch.unique(
|
||||
torch.IntTensor([2, 1, 2]),
|
||||
sorted=True,
|
||||
return_inverse=True,
|
||||
return_counts=True
|
||||
)
|
||||
int_unique, int_inverse = torch.unique(
|
||||
torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
|
||||
self.assertEqual(torch.IntTensor([1, 2]), int_unique)
|
||||
self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
|
||||
self.assertEqual(torch.LongTensor([1, 2]), int_counts)
|
||||
|
||||
double_unique, double_inverse, double_counts = torch.unique(
|
||||
double_unique, double_inverse = torch.unique(
|
||||
torch.DoubleTensor([2., 1.5, 2.1, 2.]),
|
||||
sorted=True,
|
||||
return_inverse=True,
|
||||
return_counts=True
|
||||
)
|
||||
self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
|
||||
self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
|
||||
self.assertEqual(torch.LongTensor([1, 2, 1]), double_counts)
|
||||
|
||||
byte_unique, byte_inverse, byte_counts = torch.unique(
|
||||
byte_unique, byte_inverse = torch.unique(
|
||||
torch.ByteTensor([133, 7, 7, 7, 42, 128]),
|
||||
sorted=True,
|
||||
return_inverse=True,
|
||||
return_counts=True
|
||||
)
|
||||
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
|
||||
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
|
||||
self.assertEqual(torch.LongTensor([3, 1, 1, 1]), byte_counts)
|
||||
|
||||
def test_unique_dim(self):
|
||||
def run_test(dtype=torch.float):
|
||||
@ -10448,7 +10423,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]]], dtype=dtype)
|
||||
expected_inverse_dim0 = torch.tensor([0, 0])
|
||||
expected_counts_dim0 = torch.tensor([2])
|
||||
expected_unique_dim1 = torch.tensor([[[0., 1.],
|
||||
[1., 1.],
|
||||
[2., 1.]],
|
||||
@ -10456,7 +10430,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
[1., 1.],
|
||||
[2., 1.]]], dtype=dtype)
|
||||
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
|
||||
expected_counts_dim1 = torch.tensor([2, 1, 1])
|
||||
expected_unique_dim2 = torch.tensor([[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
@ -10466,95 +10439,31 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]]], dtype=dtype)
|
||||
expected_inverse_dim2 = torch.tensor([0, 1])
|
||||
expected_counts_dim2 = torch.tensor([1, 1])
|
||||
|
||||
# dim0
|
||||
x_unique = torch.unique(x, dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=False,
|
||||
dim=0)
|
||||
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_inverse_dim0, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_counts_dim0, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_inverse_dim0, x_inverse)
|
||||
self.assertEqual(expected_counts_dim0, x_counts)
|
||||
|
||||
# dim1
|
||||
x_unique = torch.unique(x, dim=1)
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=False,
|
||||
dim=1)
|
||||
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_counts_dim1, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1, x_inverse)
|
||||
self.assertEqual(expected_counts_dim1, x_counts)
|
||||
|
||||
# dim2
|
||||
x_unique = torch.unique(x, dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=False,
|
||||
dim=2)
|
||||
x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_inverse_dim2, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_counts_dim2, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_inverse_dim2, x_inverse)
|
||||
self.assertEqual(expected_counts_dim2, x_counts)
|
||||
|
||||
run_test(torch.float)
|
||||
run_test(torch.double)
|
||||
run_test(torch.long)
|
||||
|
@ -871,7 +871,7 @@
|
||||
- name: uniform_(Tensor self, double from, double to, Generator generator)
|
||||
self: zeros_like(grad)
|
||||
|
||||
- name: _unique(Tensor self, bool sorted, bool return_inverse, bool return_counts)
|
||||
- name: _unique(Tensor self, bool sorted, bool return_inverse)
|
||||
self: not_implemented("_unique")
|
||||
|
||||
- name: _unsafe_view(Tensor self, IntArrayRef size)
|
||||
|
@ -374,8 +374,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
|
||||
return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
||||
|
||||
|
||||
def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
||||
r"""Returns the unique elements of the input tensor.
|
||||
def unique(input, sorted=True, return_inverse=False, dim=None):
|
||||
r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
|
||||
|
||||
Arguments:
|
||||
input (Tensor): the input tensor
|
||||
@ -383,26 +383,18 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
|
||||
before returning as output.
|
||||
return_inverse (bool): Whether to also return the indices for where
|
||||
elements in the original input ended up in the returned unique list.
|
||||
return_counts (bool): Whether to also return the counts for each unique
|
||||
element.
|
||||
dim (int): the dimension to apply unique. If ``None``, the unique of the
|
||||
flattened input is returned. default: ``None``
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor (optional) Tensor (optional)):
|
||||
A tensor or a tuple of tensors containing
|
||||
(Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
|
||||
|
||||
- **output** (*Tensor*): the output list of unique scalar elements.
|
||||
- **inverse_indices** (*Tensor*): (optional) if
|
||||
:attr:`return_inverse` is True, there will be an additional
|
||||
returned tensor (same shape as input) representing the indices
|
||||
:attr:`return_inverse` is True, there will be a
|
||||
2nd returned tensor (same shape as input) representing the indices
|
||||
for where elements in the original input map to in the output;
|
||||
otherwise, this function will only return a single tensor.
|
||||
- **counts** (*Tensor*): (optional) if
|
||||
:attr:`return_counts` is True, there will be an additional
|
||||
returned tensor (same shape as output or output.size(dim),
|
||||
if dim was specified) representing the number of occurences
|
||||
for each unique value or tensor.
|
||||
|
||||
Example::
|
||||
|
||||
@ -427,26 +419,20 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
|
||||
|
||||
"""
|
||||
if dim is not None:
|
||||
output, inverse_indices, counts = torch._unique_dim(
|
||||
output, inverse_indices = torch._unique_dim(
|
||||
input,
|
||||
dim,
|
||||
sorted=sorted,
|
||||
return_inverse=return_inverse,
|
||||
return_counts=return_counts
|
||||
return_inverse=return_inverse
|
||||
)
|
||||
else:
|
||||
output, inverse_indices, counts = torch._unique(
|
||||
output, inverse_indices = torch._unique(
|
||||
input,
|
||||
sorted=sorted,
|
||||
return_inverse=return_inverse,
|
||||
return_counts=return_counts
|
||||
)
|
||||
if return_inverse and return_counts:
|
||||
return output, inverse_indices, counts
|
||||
elif return_inverse:
|
||||
if return_inverse:
|
||||
return output, inverse_indices
|
||||
elif return_counts:
|
||||
return output, counts
|
||||
else:
|
||||
return output
|
||||
|
||||
|
@ -1205,11 +1205,10 @@ def conv_tbc(g, input, weight, bias, pad):
|
||||
return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
|
||||
|
||||
|
||||
@parse_args('v', 'i', 'i', 'i')
|
||||
def _unique(g, input, sorted, return_inverse, return_counts):
|
||||
@parse_args('v', 'i', 'i')
|
||||
def _unique(g, input, sorted, return_inverse):
|
||||
return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
|
||||
return_inverse_i=return_inverse, return_counts_i=return_counts,
|
||||
outputs=3)
|
||||
return_inverse_i=return_inverse, outputs=2)
|
||||
|
||||
|
||||
# Metaprogram symbolics for each ATen native specialized cast operator.
|
||||
|
@ -315,32 +315,26 @@ class Tensor(torch._C._TensorBase):
|
||||
else:
|
||||
return super(Tensor, self).split_with_sizes(split_size, dim)
|
||||
|
||||
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
||||
def unique(self, sorted=True, return_inverse=False, dim=None):
|
||||
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
|
||||
|
||||
See :func:`torch.unique`
|
||||
"""
|
||||
if dim is not None:
|
||||
output, inverse_indices, counts = torch._unique_dim(
|
||||
output, inverse_indices = torch._unique_dim(
|
||||
self,
|
||||
sorted=sorted,
|
||||
return_inverse=return_inverse,
|
||||
return_counts=return_counts,
|
||||
dim=dim
|
||||
)
|
||||
else:
|
||||
output, inverse_indices, counts = torch._unique(
|
||||
output, inverse_indices = torch._unique(
|
||||
self,
|
||||
sorted=sorted,
|
||||
return_inverse=return_inverse,
|
||||
return_counts=return_counts
|
||||
return_inverse=return_inverse
|
||||
)
|
||||
if return_inverse and return_counts:
|
||||
return output, inverse_indices, counts
|
||||
elif return_inverse:
|
||||
if return_inverse:
|
||||
return output, inverse_indices
|
||||
elif return_counts:
|
||||
return output, counts
|
||||
else:
|
||||
return output
|
||||
|
||||
|
Reference in New Issue
Block a user