Revert D14605905: [pytorch][PR] Add return_counts to torch.unique

Differential Revision:
D14605905

Original commit changeset: 555f5a12a8e2

fbshipit-source-id: c7874f5987893e956c022180a37763d88bba38db
Soumith Chintala authored 2019-03-26 17:14:26 -07:00 (committed by Facebook Github Bot)
parent bdd098c694
commit 66628f78b7
9 changed files with 68 additions and 214 deletions

View File

@@ -14,11 +14,10 @@ namespace native{
namespace {
template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
+std::tuple<Tensor, Tensor> _unique_cpu_template(
const Tensor& self,
const bool sorted,
-const bool return_inverse,
-const bool return_counts) {
+const bool return_inverse) {
const Tensor& input = self.contiguous();
const scalar_t* input_data = input.data<scalar_t>();
std::unordered_set<scalar_t> set(input_data, input_data + input.numel());
@@ -34,8 +33,7 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
}
Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-Tensor counts = at::empty({0}, self.options().dtype(kLong));
-if (return_inverse || return_counts) {
+if (return_inverse) {
inverse_indices.resize_(input.sizes());
int64_t* inverse_indices_data = inverse_indices.data<int64_t>();
std::unordered_map<scalar_t, int64_t> inverse_map;
@@ -46,29 +44,21 @@ std::tuple<Tensor, Tensor, Tensor> _unique_cpu_template(
for (int i = 0; i < input.numel(); ++i) {
inverse_indices_data[i] = inverse_map[input_data[i]];
}
-if (return_counts) {
-counts.resize_(output.sizes());
-counts.fill_(0);
-for (int i = 0; i < input.numel(); ++i) {
-counts[inverse_map[input_data[i]]] += 1;
-}
-}
}
-return std::make_tuple(output, inverse_indices, counts);
+return std::make_tuple(output, inverse_indices);
}
template<class ForwardIt>
ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
-std::vector<int64_t>& indices, Tensor inverse_indices_vec, Tensor counts) {
+std::vector<int64_t>& indices, Tensor inverse_indices_vec) {
if (first == last) {
return last;
}
// save to calculate distance to iterators
ForwardIt begin = first;
-// set first inverse index and count
+// set first inverse index
inverse_indices_vec[indices[0]] = 0;
-counts[0] += 1;
ForwardIt result = first;
while (++first != last) {
@@ -78,18 +68,16 @@ ForwardIt _unique_dim_cpu_impl(ForwardIt first, ForwardIt last,
int64_t idx_result = std::distance(begin, result);
int64_t idx_first = std::distance(begin, first);
inverse_indices_vec[indices[idx_first]] = idx_result;
-counts[idx_result] += 1;
}
return ++result;
}
template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
+std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
const Tensor& self,
const int64_t dim,
-const bool return_inverse,
-const bool return_counts) {
+const bool return_inverse) {
// reshape tensor as [dim, -1]
Tensor input_flat = self.transpose(dim, 0);
auto orig_sizes = input_flat.sizes().vec();
@@ -121,12 +109,10 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
}
Tensor inverse_indices = at::empty(indices.size(), self.options().dtype(kLong));
-Tensor counts = at::zeros(indices.size(), self.options().dtype(kLong));
std::vector<Tensor> input_unbind = at::unbind(input_sorted, 0);
auto last = _unique_dim_cpu_impl(
-input_unbind.begin(), input_unbind.end(), indices, inverse_indices, counts);
+input_unbind.begin(), input_unbind.end(), indices, inverse_indices);
input_unbind.erase(last, input_unbind.end());
-counts = at::narrow(counts, 0, 0, input_unbind.size());
// reshape back
auto output = at::stack(input_unbind, 0);
@@ -135,23 +121,22 @@ std::tuple<Tensor, Tensor, Tensor> _unique_dim_cpu_template(
output = output.view(new_sizes);
output = output.transpose(0, dim);
-return std::make_tuple(output, inverse_indices, counts);
+return std::make_tuple(output, inverse_indices);
}
} // namespace
-std::tuple<Tensor, Tensor, Tensor>
-_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
-return _unique_cpu_template<scalar_t>(self, sorted, return_inverse, return_counts);
+std::tuple<Tensor, Tensor>
+_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
+return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cpu", [&] {
+return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
});
}
-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
// The current implementation using `dim` always sorts due to unhashable tensors
-return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse, return_counts);
+return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
});
}
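For reference, a minimal Python sketch (illustrative only, not the ATen source) of the algorithm this CPU path implements after the revert: a hash set collects the unique values, and a value-to-index map produces the inverse indices.

def unique_cpu_sketch(values, sorted=True, return_inverse=False):
    # unordered_set pass: collect distinct values
    uniques = list(set(values))
    if sorted:
        uniques.sort()
    if not return_inverse:
        return uniques, []
    # unordered_map pass: value -> position in the unique output
    inverse_map = {v: i for i, v in enumerate(uniques)}
    inverse = [inverse_map[v] for v in values]  # same length as input
    return uniques, inverse

# unique_cpu_sketch([1, 2, 3, 2], return_inverse=True) -> ([1, 2, 3], [0, 1, 2, 1])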

View File

@@ -16,10 +16,9 @@ namespace native{
namespace {
template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_cuda_template(
+std::tuple<Tensor, Tensor> _unique_cuda_template(
const Tensor& self,
-const bool return_inverse,
-const bool return_counts) {
+const bool return_inverse) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -29,7 +28,7 @@ template <typename scalar_t>
int64_t num_inp = input.numel();
const scalar_t* input_data = input.data<scalar_t>();
-//sort
+//sort & unique
Tensor output = input.clone();
output = output.view(-1);
scalar_t* output_data = output.data<scalar_t>();
@@ -48,36 +47,21 @@ template <typename scalar_t>
thrust::adjacent_difference(policy, output_data, output_data + num_inp, inv_loc_ptr, [=] __device__ (scalar_t a, scalar_t b) -> int64_t { if (a != b) {return 1;} else { return 0; }});
inv_loc[0] = 0;
thrust::inclusive_scan(policy, inv_loc_ptr, inv_loc_ptr + num_inp, inv_loc_ptr);
-thrust::scatter(policy, inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
+thrust::scatter(policy,inv_loc_ptr, inv_loc_ptr + num_inp, sorted_indices_ptr, inverse_indices_ptr);
inverse_indices.resize_(input.sizes());
}
-// unique
-Tensor counts = at::empty({0}, self.options().dtype(kLong));
-if (!return_counts) {
-int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
-output.resize_(num_out);
-} else {
-Tensor sorted_indices = at::arange(0, num_inp + 1, self.type().toScalarType(kLong));
-int64_t* sorted_indices_ptr = sorted_indices.data<int64_t>();
-int64_t num_out = thrust::unique_by_key(policy, output_data, output_data + num_inp, sorted_indices_ptr).first - output_data;
-sorted_indices[num_out] = num_inp;
-output.resize_(num_out);
-counts.resize_(num_out);
-int64_t* counts_ptr = counts.data<int64_t>();
-thrust::adjacent_difference(policy, sorted_indices_ptr + 1, sorted_indices_ptr + num_out + 1, counts_ptr);
-}
+int64_t num_out = thrust::unique(policy, output_data, output_data + num_inp) - output_data;
+output.resize_(num_out);
THCudaCheck(cudaGetLastError());
-return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
template <typename scalar_t>
-std::tuple<Tensor, Tensor, Tensor> _unique_dim_cuda_template(
+std::tuple<Tensor, Tensor> _unique_dim_cuda_template(
const Tensor& self,
const int64_t dim,
-const bool return_inverse,
-const bool return_counts) {
+const bool return_inverse) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
@@ -89,7 +73,7 @@ template <typename scalar_t>
scalar_t* input_flat_ptr = input_flat.data<scalar_t>();
-Tensor indices = at::arange(0, input_flat.size(0), self.options().dtype(kLong));
+Tensor indices = at::arange(0, input_flat.size(0), self.type().toScalarType(kLong));
int64_t* indices_ptr = indices.data<int64_t>();
int64_t numel = input_flat.size(1);
@@ -112,7 +96,7 @@ template <typename scalar_t>
// get unique tensors
scalar_t* input_sorted_ptr = input_sorted.data<scalar_t>();
-Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.options().dtype(kLong));
+Tensor input_sorted_indices = at::arange(0, input_sorted.size(0), self.type().toScalarType(kLong));
int64_t* input_sorted_indices_ptr = input_sorted_indices.data<int64_t>();
auto last = thrust::unique(policy, input_sorted_indices_ptr, input_sorted_indices_ptr + input_sorted_indices.numel(),
[=] __device__ (int64_t a, int64_t b) -> bool {
@@ -134,13 +118,12 @@ template <typename scalar_t>
output = output.view(new_sizes);
output = output.transpose(0, dim);
-// calculate inverse indices and counts
-Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
-Tensor counts = at::zeros(output.size(dim), self.options().dtype(kLong));
-if (return_inverse || return_counts) {
+// calculate inverse indices
+Tensor inverse_indices = at::empty({0}, self.type().toScalarType(kLong));
+if (return_inverse) {
int64_t size = self.size(dim);
inverse_indices.resize_(size);
-Tensor mask = at::empty(input_sorted.size(0), self.options().dtype(kLong));
+Tensor mask = at::empty(input_sorted.size(0), self.type().toScalarType(kLong));
mask[0] = 1;
for (int i = 0; i < input_sorted.size(0) - 1; ++i) {
if (!at::equal(input_sorted[i], input_sorted[i+1])) {
@@ -153,29 +136,27 @@ template <typename scalar_t>
Tensor imask = at::cumsum(mask, 0) - 1;
for (int i = 0; i < indices.size(0); ++i) {
inverse_indices[indices[i]] = imask[i];
-counts[inverse_indices[indices[i]]] += 1;
}
}
THCudaCheck(cudaGetLastError());
-return std::tuple<Tensor, Tensor, Tensor>(output, inverse_indices, counts);
+return std::tuple<Tensor, Tensor>(output, inverse_indices);
}
} // namespace
-std::tuple<Tensor, Tensor, Tensor>
-_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, const bool return_counts) {
-return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique", [&] {
+std::tuple<Tensor, Tensor>
+_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
+return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_cuda", [&] {
// The current CUDA implementation of unique always sort due to the
// lack of hashtable implementation in thrust
-return _unique_cuda_template<scalar_t>(self, return_inverse, return_counts);
+return _unique_cuda_template<scalar_t>(self, return_inverse);
});
}
-std::tuple<Tensor, Tensor, Tensor>
-_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "unique_dim", [&] {
-return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse, return_counts);
+return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
});
}
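For reference, a minimal sketch (assuming a 1-D tensor; illustrative only, not the CUDA source) of the Thrust pipeline above, re-expressed with PyTorch tensor ops: sort the values, mark group boundaries (adjacent_difference), prefix-sum the marks (inclusive_scan), then scatter through the sort permutation so each input position gets the index of its unique value.

import torch

def unique_cuda_sketch(x):
    sorted_vals, perm = x.reshape(-1).sort()               # thrust::sort_by_key analogue
    boundary = torch.zeros_like(sorted_vals, dtype=torch.long)
    boundary[1:] = (sorted_vals[1:] != sorted_vals[:-1]).long()  # adjacent_difference
    inv_loc = boundary.cumsum(0)                           # inclusive_scan: group index per sorted slot
    inverse = torch.empty_like(inv_loc)
    inverse[perm] = inv_loc                                # thrust::scatter analogue
    keep = torch.ones_like(sorted_vals, dtype=torch.bool)
    keep[1:] = sorted_vals[1:] != sorted_vals[:-1]
    output = sorted_vals[keep]                             # thrust::unique analogue
    return output, inverse.view_as(x)

# unique_cuda_sketch(torch.tensor([2, 1, 2])) -> (tensor([1, 2]), tensor([1, 0, 1]))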

View File

@@ -2339,14 +2339,14 @@
matches_jit_signature: True
variants: method
-- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
CPU: _unique_cpu
CUDA: _unique_cuda
-- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+- func: _unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:

View File

@@ -301,7 +301,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
indices = nz.clone();
} else {
Tensor i = nz.narrow(0, 0, sparse_dim);
-std::tie(indices, std::ignore, std::ignore) = _unique_dim(i, 1);
+std::tie(indices, std::ignore) = _unique_dim(i, 1);
indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
}
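For context, the public torch.unique with dim=1 performs the same column-wise deduplication as the internal _unique_dim(i, 1) call above; a small sketch with hypothetical index values:

import torch

i = torch.tensor([[0, 0, 1],
                  [2, 2, 0]])      # hypothetical 2 x nnz index matrix, one duplicated column
indices = torch.unique(i, dim=1)   # column-wise dedup, like _unique_dim(i, 1)
# indices -> tensor([[0, 1],
#                    [2, 0]])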

View File

@@ -10362,7 +10362,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
expected_unique = torch.LongTensor([1, 2, 3, 5, 8])
expected_inverse = torch.LongTensor([0, 1, 2, 1, 4, 3, 1, 2])
-expected_counts = torch.LongTensor([1, 3, 2, 1, 1])
x_unique = torch.unique(x)
self.assertEqual(
@@ -10376,62 +10375,38 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
x_unique = x.unique(sorted=True)
self.assertEqual(expected_unique, x_unique)
-x_unique, x_counts = x.unique(sorted=True, return_counts=True)
-self.assertEqual(expected_counts, x_counts)
x_unique, x_inverse = torch.unique(
x, sorted=True, return_inverse=True)
self.assertEqual(expected_unique, x_unique)
self.assertEqual(expected_inverse, x_inverse)
-x_unique, x_inverse, x_counts = torch.unique(
-x, sorted=True, return_inverse=True, return_counts=True)
-self.assertEqual(expected_unique, x_unique)
-self.assertEqual(expected_inverse, x_inverse)
-self.assertEqual(expected_counts, x_counts)
# Tests per-element unique on a higher rank tensor.
y = x.view(2, 2, 2)
y_unique, y_inverse = y.unique(sorted=True, return_inverse=True)
self.assertEqual(expected_unique, y_unique)
self.assertEqual(expected_inverse.view(y.size()), y_inverse)
-y_unique, y_inverse, y_counts = y.unique(
-sorted=True, return_inverse=True, return_counts=True)
-self.assertEqual(expected_unique, y_unique)
-self.assertEqual(expected_inverse.view(y.size()), y_inverse)
-self.assertEqual(expected_counts, y_counts)
# Tests unique on other types.
-int_unique, int_inverse, int_counts = torch.unique(
-torch.IntTensor([2, 1, 2]),
-sorted=True,
-return_inverse=True,
-return_counts=True
-)
+int_unique, int_inverse = torch.unique(
+torch.IntTensor([2, 1, 2]), sorted=True, return_inverse=True)
self.assertEqual(torch.IntTensor([1, 2]), int_unique)
self.assertEqual(torch.LongTensor([1, 0, 1]), int_inverse)
-self.assertEqual(torch.LongTensor([1, 2]), int_counts)
-double_unique, double_inverse, double_counts = torch.unique(
+double_unique, double_inverse = torch.unique(
torch.DoubleTensor([2., 1.5, 2.1, 2.]),
sorted=True,
return_inverse=True,
-return_counts=True
)
self.assertEqual(torch.DoubleTensor([1.5, 2., 2.1]), double_unique)
self.assertEqual(torch.LongTensor([1, 0, 2, 1]), double_inverse)
-self.assertEqual(torch.LongTensor([1, 2, 1]), double_counts)
-byte_unique, byte_inverse, byte_counts = torch.unique(
+byte_unique, byte_inverse = torch.unique(
torch.ByteTensor([133, 7, 7, 7, 42, 128]),
sorted=True,
return_inverse=True,
-return_counts=True
)
self.assertEqual(torch.ByteTensor([7, 42, 128, 133]), byte_unique)
self.assertEqual(torch.LongTensor([3, 0, 0, 0, 1, 2]), byte_inverse)
-self.assertEqual(torch.LongTensor([3, 1, 1, 1]), byte_counts)
def test_unique_dim(self):
def run_test(dtype=torch.float):
@@ -10448,7 +10423,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim0 = torch.tensor([0, 0])
-expected_counts_dim0 = torch.tensor([2])
expected_unique_dim1 = torch.tensor([[[0., 1.],
[1., 1.],
[2., 1.]],
@@ -10456,7 +10430,6 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[1., 1.],
[2., 1.]]], dtype=dtype)
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
-expected_counts_dim1 = torch.tensor([2, 1, 1])
expected_unique_dim2 = torch.tensor([[[1., 1.],
[0., 1.],
[2., 1.],
@@ -10466,95 +10439,31 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
[2., 1.],
[0., 1.]]], dtype=dtype)
expected_inverse_dim2 = torch.tensor([0, 1])
-expected_counts_dim2 = torch.tensor([1, 1])
# dim0
x_unique = torch.unique(x, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
-x_unique, x_inverse = torch.unique(
-x,
-return_inverse=True,
-return_counts=False,
-dim=0)
+x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=0)
self.assertEqual(expected_unique_dim0, x_unique)
self.assertEqual(expected_inverse_dim0, x_inverse)
-x_unique, x_counts = torch.unique(
-x,
-return_inverse=False,
-return_counts=True,
-dim=0)
-self.assertEqual(expected_unique_dim0, x_unique)
-self.assertEqual(expected_counts_dim0, x_counts)
-x_unique, x_inverse, x_counts = torch.unique(
-x,
-return_inverse=True,
-return_counts=True,
-dim=0)
-self.assertEqual(expected_unique_dim0, x_unique)
-self.assertEqual(expected_inverse_dim0, x_inverse)
-self.assertEqual(expected_counts_dim0, x_counts)
# dim1
x_unique = torch.unique(x, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
-x_unique, x_inverse = torch.unique(
-x,
-return_inverse=True,
-return_counts=False,
-dim=1)
+x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=1)
self.assertEqual(expected_unique_dim1, x_unique)
self.assertEqual(expected_inverse_dim1, x_inverse)
-x_unique, x_counts = torch.unique(
-x,
-return_inverse=False,
-return_counts=True,
-dim=1)
-self.assertEqual(expected_unique_dim1, x_unique)
-self.assertEqual(expected_counts_dim1, x_counts)
-x_unique, x_inverse, x_counts = torch.unique(
-x,
-return_inverse=True,
-return_counts=True,
-dim=1)
-self.assertEqual(expected_unique_dim1, x_unique)
-self.assertEqual(expected_inverse_dim1, x_inverse)
-self.assertEqual(expected_counts_dim1, x_counts)
# dim2
x_unique = torch.unique(x, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
-x_unique, x_inverse = torch.unique(
-x,
-return_inverse=True,
-return_counts=False,
-dim=2)
+x_unique, x_inverse = torch.unique(x, return_inverse=True, dim=2)
self.assertEqual(expected_unique_dim2, x_unique)
self.assertEqual(expected_inverse_dim2, x_inverse)
-x_unique, x_counts = torch.unique(
-x,
-return_inverse=False,
-return_counts=True,
-dim=2)
-self.assertEqual(expected_unique_dim2, x_unique)
-self.assertEqual(expected_counts_dim2, x_counts)
-x_unique, x_inverse, x_counts = torch.unique(
-x,
-return_inverse=True,
-return_counts=True,
-dim=2)
-self.assertEqual(expected_unique_dim2, x_unique)
-self.assertEqual(expected_inverse_dim2, x_inverse)
-self.assertEqual(expected_counts_dim2, x_counts)
run_test(torch.float)
run_test(torch.double)
run_test(torch.long)
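Note that the counts checked by the deleted assertions can still be recovered from the inverse indices; a small sketch using torch.bincount with the same values as the test above:

import torch

x = torch.LongTensor([1, 2, 3, 2, 8, 5, 2, 3])
x_unique, x_inverse = torch.unique(x, sorted=True, return_inverse=True)
counts = torch.bincount(x_inverse, minlength=x_unique.numel())
# counts -> tensor([1, 3, 2, 1, 1]), matching the removed expected_counts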

View File

@@ -871,7 +871,7 @@
- name: uniform_(Tensor self, double from, double to, Generator generator)
self: zeros_like(grad)
-- name: _unique(Tensor self, bool sorted, bool return_inverse, bool return_counts)
+- name: _unique(Tensor self, bool sorted, bool return_inverse)
self: not_implemented("_unique")
- name: _unsafe_view(Tensor self, IntArrayRef size)

View File

@@ -374,8 +374,8 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
-r"""Returns the unique elements of the input tensor.
+def unique(input, sorted=True, return_inverse=False, dim=None):
+r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
Arguments:
input (Tensor): the input tensor
@@ -383,26 +383,18 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
before returning as output.
return_inverse (bool): Whether to also return the indices for where
elements in the original input ended up in the returned unique list.
-return_counts (bool): Whether to also return the counts for each unique
-element.
dim (int): the dimension to apply unique. If ``None``, the unique of the
flattened input is returned. default: ``None``
Returns:
-(Tensor, Tensor (optional) Tensor (optional)):
-A tensor or a tuple of tensors containing
+(Tensor, Tensor (optional)): A tensor or a tuple of tensors containing
- **output** (*Tensor*): the output list of unique scalar elements.
- **inverse_indices** (*Tensor*): (optional) if
-:attr:`return_inverse` is True, there will be an additional
-returned tensor (same shape as input) representing the indices
+:attr:`return_inverse` is True, there will be a
+2nd returned tensor (same shape as input) representing the indices
for where elements in the original input map to in the output;
otherwise, this function will only return a single tensor.
-- **counts** (*Tensor*): (optional) if
-:attr:`return_counts` is True, there will be an additional
-returned tensor (same shape as output or output.size(dim),
-if dim was specified) representing the number of occurences
-for each unique value or tensor.
Example::
@@ -427,26 +419,20 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No
"""
if dim is not None:
-output, inverse_indices, counts = torch._unique_dim(
+output, inverse_indices = torch._unique_dim(
input,
dim,
sorted=sorted,
-return_inverse=return_inverse,
-return_counts=return_counts
+return_inverse=return_inverse
)
else:
-output, inverse_indices, counts = torch._unique(
+output, inverse_indices = torch._unique(
input,
sorted=sorted,
return_inverse=return_inverse,
-return_counts=return_counts
)
-if return_inverse and return_counts:
-return output, inverse_indices, counts
-elif return_inverse:
+if return_inverse:
return output, inverse_indices
-elif return_counts:
-return output, counts
else:
return output
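After the revert, the public API supports sorted, return_inverse, and dim only; a minimal usage sketch:

import torch

output = torch.unique(torch.tensor([1, 3, 2, 3]))
# output -> tensor([1, 2, 3])
output, inverse_indices = torch.unique(torch.tensor([1, 3, 2, 3]), return_inverse=True)
# inverse_indices -> tensor([0, 2, 1, 2])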

View File

@@ -1205,11 +1205,10 @@ def conv_tbc(g, input, weight, bias, pad):
return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
-@parse_args('v', 'i', 'i', 'i')
-def _unique(g, input, sorted, return_inverse, return_counts):
+@parse_args('v', 'i', 'i')
+def _unique(g, input, sorted, return_inverse):
return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
-return_inverse_i=return_inverse, return_counts_i=return_counts,
-outputs=3)
+return_inverse_i=return_inverse, outputs=2)
# Metaprogram symbolics for each ATen native specialized cast operator.

View File

@@ -315,32 +315,26 @@ class Tensor(torch._C._TensorBase):
else:
return super(Tensor, self).split_with_sizes(split_size, dim)
-def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def unique(self, sorted=True, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the tensor as a 1-D tensor.
See :func:`torch.unique`
"""
if dim is not None:
-output, inverse_indices, counts = torch._unique_dim(
+output, inverse_indices = torch._unique_dim(
self,
sorted=sorted,
return_inverse=return_inverse,
-return_counts=return_counts,
dim=dim
)
else:
-output, inverse_indices, counts = torch._unique(
+output, inverse_indices = torch._unique(
self,
sorted=sorted,
-return_inverse=return_inverse,
-return_counts=return_counts
+return_inverse=return_inverse
)
-if return_inverse and return_counts:
-return output, inverse_indices, counts
-elif return_inverse:
+if return_inverse:
return output, inverse_indices
-elif return_counts:
-return output, counts
else:
return output