Add at::one_hot (#15208)

Summary: Closes: https://github.com/pytorch/pytorch/issues/15060

Differential Revision: D13528014

Pulled By: ezyang

fbshipit-source-id: 5a18689a4c5638d92f9390c91517f741e5396293
Author: Gao, Xiang
Date: 2018-12-20 14:09:09 -08:00
Committed by: Facebook Github Bot
parent 2a64a78e7b
commit a47749cb28
7 changed files with 169 additions and 24 deletions

aten/src/ATen/native/Onehot.cpp (new file)

@@ -0,0 +1,35 @@
#include <ATen/ATen.h>

namespace at { namespace native {

Tensor one_hot(const Tensor &self, int64_t num_classes) {
  AT_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
  auto shape = self.sizes().vec();

  // An empty tensor can be converted to a one-hot representation,
  // but shape inference is not possible.
  if (self.numel() == 0) {
    if (num_classes <= 0) {
      AT_ERROR("Cannot infer total number of classes from empty tensor.");
    } else {
      shape.push_back(num_classes);
      return at::empty(shape, self.options());
    }
  }

  // Non-empty tensor.
  AT_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
  if (num_classes == -1) {
    num_classes = self.max().item().toLong() + 1;
  } else {
    AT_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
  }

  shape.push_back(num_classes);
  Tensor ret = at::zeros(shape, self.options());
  ret.scatter_(-1, self.unsqueeze(-1), 1);
  return ret;
}

} // namespace native
} // namespace at
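For readers more comfortable in Python, here is a minimal sketch of the same logic (an illustration of the kernel above, not part of the commit; `one_hot_reference` is a hypothetical name):

import torch

def one_hot_reference(indices: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    # Mirror the C++ kernel: require an index (int64) tensor.
    assert indices.dtype == torch.long, "one_hot is only applicable to index tensor."
    shape = list(indices.shape)
    # Empty input: the class count cannot be inferred, so it must be given.
    if indices.numel() == 0:
        if num_classes <= 0:
            raise RuntimeError("Cannot infer total number of classes from empty tensor.")
        return torch.empty(shape + [num_classes], dtype=torch.long)
    assert indices.min().item() >= 0, "Class values must be non-negative."
    if num_classes == -1:
        # Infer the class count as one greater than the largest index.
        num_classes = indices.max().item() + 1
    else:
        assert num_classes > indices.max().item(), "Class values must be smaller than num_classes."
    # Allocate zeros, then scatter 1s along a new trailing dimension.
    out = torch.zeros(shape + [num_classes], dtype=torch.long)
    return out.scatter_(-1, indices.unsqueeze(-1), 1)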

aten/src/ATen/native/native_functions.yaml

@@ -1721,6 +1721,10 @@
  variants: method
  device_guard: false

- func: one_hot(IndexTensor self, int64_t num_classes=-1) -> Tensor
  python_module: nn
  variants: function

- func: flip(Tensor self, IntList dims) -> Tensor
  variants: function, method
  dispatch:
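A note on the declaration above: `IndexTensor` constrains `self` to an int64 tensor (enforced by the AT_CHECK in the kernel), and `python_module: nn` routes the generated binding to `torch._C._nn` rather than to `torch.*`, which is why `torch/nn/functional.py` (last hunk below) re-exports it. A quick sketch of the resulting Python surface, assuming a build with this commit:

import torch
import torch.nn.functional as F

F.one_hot(torch.tensor([0, 2, 1]))        # OK: int64 input, 3 classes inferred from the max value
# F.one_hot(torch.tensor([0., 2., 1.]))   # would raise RuntimeError: the kernel's dtype check rejects non-Long input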

docs/source/nn.rst

@@ -794,6 +794,11 @@ Utilities
.. autofunction:: torch.nn.utils.remove_spectral_norm

:hidden:`one_hot`
~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.functional.one_hot

.. currentmodule:: torch.nn.utils.rnn

test/test_nn.py

@@ -2725,6 +2725,64 @@ class TestNN(NNTestCase):
        self.assertRaisesRegex(RuntimeError, expected_err_msg,
                               lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))

    @staticmethod
    def _test_one_hot(self, use_cuda=False):
        device = torch.device('cuda' if use_cuda else 'cpu')
        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
        expected = torch.tensor([[0, 0, 0, 1, 0],
                                 [0, 0, 0, 0, 1],
                                 [0, 1, 0, 0, 0],
                                 [1, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
        expected = torch.tensor([[0, 0, 0, 1, 0, 0],
                                 [0, 0, 0, 0, 1, 0],
                                 [0, 1, 0, 0, 0, 0],
                                 [1, 0, 0, 0, 0, 0]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
        expected = torch.tensor([[[0, 0, 0, 1, 0],
                                  [0, 0, 0, 0, 1]],
                                 [[0, 1, 0, 0, 0],
                                  [1, 0, 0, 0, 0]]], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
        expected = torch.tensor([0, 0, 0, 0, 1], device=device)
        self.assertEqual(t, expected)

        t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
        expected = torch.empty([4, 0, 100])
        self.assertEqual(t, expected)

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))

        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)

    def test_one_hot(self):
        self._test_one_hot(self)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_one_hot_cuda(self):
        self._test_one_hot(self, use_cuda=True)

    def test_pad_scalar_error(self):
        inputs = torch.tensor(0., requires_grad=True)
        self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))

test/test_torch.py

@@ -425,7 +425,7 @@ class _TestTorchMixin(object):
         def compare_reference(input, dtype):
             input = torch.tensor(input, dtype=dtype)
             res1 = torchfn(input.clone())
-            res2 = input.clone().apply_(lambda x: mathfn(x))
+            res2 = input.clone().apply_(mathfn)
             torch.testing.assert_allclose(res1, res2)

         # compare against the reference math function
@@ -1034,21 +1034,21 @@ class _TestTorchMixin(object):
     def test_reduction_empty(self):
         fns_to_test = [
             # name, function, identity
-            ('max', lambda *args, **kwargs: torch.max(*args, **kwargs), None),
+            ('max', torch.max, None),
             ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
-            ('argmax', lambda *args, **kwargs: torch.argmax(*args, **kwargs), None),
-            ('min', lambda *args, **kwargs: torch.min(*args, **kwargs), None),
-            ('argmin', lambda *args, **kwargs: torch.argmin(*args, **kwargs), None),
-            ('mode', lambda *args, **kwargs: torch.mode(*args, **kwargs), None),
-            ('median', lambda *args, **kwargs: torch.median(*args, **kwargs), None),
+            ('argmax', torch.argmax, None),
+            ('min', torch.min, None),
+            ('argmin', torch.argmin, None),
+            ('mode', torch.mode, None),
+            ('median', torch.median, None),
-            ('prod', lambda *args, **kwargs: torch.prod(*args, **kwargs), 1),
-            ('sum', lambda *args, **kwargs: torch.sum(*args, **kwargs), 0),
-            ('norm', lambda *args, **kwargs: torch.norm(*args, p=2, **kwargs), 0),
-            ('mean', lambda *args, **kwargs: torch.mean(*args, **kwargs), nan),
-            ('var', lambda *args, **kwargs: torch.var(*args, **kwargs), nan),
-            ('std', lambda *args, **kwargs: torch.std(*args, **kwargs), nan),
-            ('logsumexp', lambda *args, **kwargs: torch.logsumexp(*args, **kwargs), -inf),
+            ('prod', torch.prod, 1),
+            ('sum', torch.sum, 0),
+            ('norm', torch.norm, 0),
+            ('mean', torch.mean, nan),
+            ('var', torch.var, nan),
+            ('std', torch.std, nan),
+            ('logsumexp', torch.logsumexp, -inf),
         ]

         devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
@@ -5268,7 +5268,7 @@ class _TestTorchMixin(object):
             for i in range(o3.size(1)):
                 for j in range(k.size(1)):
                     o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
-        self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k), reference)
+        self._test_conv_corr_eq(torch.xcorr3, reference)

     @unittest.skip("Not implemented yet")
     def test_xcorr3_xcorr2_eq_full(self):
@@ -5284,7 +5284,7 @@ class _TestTorchMixin(object):
             for i in range(o3.size(1)):
                 for j in range(k.size(1)):
                     o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
-        self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k), reference)
+        self._test_conv_corr_eq(torch.conv3, reference)

     @unittest.skip("Not implemented yet")
     def test_fconv3_fconv2_eq(self):

torch/distributions/one_hot_categorical.py

@@ -76,15 +76,9 @@ class OneHotCategorical(Distribution):
     def sample(self, sample_shape=torch.Size()):
         sample_shape = torch.Size(sample_shape)
         probs = self._categorical.probs
+        num_events = self._categorical._num_events
         indices = self._categorical.sample(sample_shape)
-        if torch._C._get_tracing_state():
-            # [JIT WORKAROUND] lack of support for .scatter_()
-            eye = torch.eye(self.event_shape[-1], dtype=self._param.dtype, device=self._param.device)
-            return eye[indices]
-        one_hot = probs.new_zeros(self._extended_shape(sample_shape))
-        if indices.dim() < one_hot.dim():
-            indices = indices.unsqueeze(-1)
-        return one_hot.scatter_(-1, indices, 1.)
+        return torch.nn.functional.one_hot(indices, num_events).to(probs)

     def log_prob(self, value):
         if self._validate_args:
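The rewrite above collapses both the eager path and the JIT tracing workaround into a single call: F.one_hot returns a LongTensor, and `.to(probs)` casts it back to the dtype and device of the probability tensor. A small usage sketch of the new behavior, assuming a build with this commit:

import torch
from torch.distributions import OneHotCategorical

d = OneHotCategorical(probs=torch.tensor([0.2, 0.3, 0.5]))
s = d.sample(torch.Size([4]))  # shape (4, 3); each row is one-hot
print(s.dtype)                 # torch.float32, cast back from int64 by .to(probs)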

torch/nn/functional.py

@@ -2784,6 +2784,55 @@ Example::
""")


one_hot = _add_docstr(torch._C._nn.one_hot, r"""
one_hot(tensor, num_classes=-1) -> LongTensor

Takes a LongTensor with index values of shape ``(*)`` and returns a tensor
of shape ``(*, num_classes)`` that has zeros everywhere except where the
index of the last dimension matches the corresponding value of the input
tensor, in which case it will be 1.

See also `One-hot on Wikipedia`_ .

.. _One-hot on Wikipedia:
    https://en.wikipedia.org/wiki/One-hot

Arguments:
    tensor (LongTensor): class values of any shape.
    num_classes (int): Total number of classes. If set to -1, the number
        of classes will be inferred as one greater than the largest class
        value in the input tensor.

Returns:
    Tensor: LongTensor that has one more dimension than the input, with 1 at
        the index of the last dimension indicated by the input, and 0
        everywhere else.

Examples::
    >>> F.one_hot(torch.arange(0, 5) % 3)
    tensor([[1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]])
    >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
    tensor([[1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 1, 0, 0, 0]])
    >>> F.one_hot(torch.arange(0, 6).view(3, 2) % 3)
    tensor([[[1, 0, 0],
             [0, 1, 0]],
            [[0, 0, 1],
             [1, 0, 0]],
            [[0, 1, 0],
             [0, 0, 1]]])
""")


@weak_script
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
                        reduce=None, reduction="mean"):