mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add at::one_hot (#15208)
Summary: Closes: https://github.com/pytorch/pytorch/issues/15060 Differential Revision: D13528014 Pulled By: ezyang fbshipit-source-id: 5a18689a4c5638d92f9390c91517f741e5396293
This commit is contained in:
committed by
Facebook Github Bot
parent
2a64a78e7b
commit
a47749cb28
35
aten/src/ATen/native/Onehot.cpp
Normal file
35
aten/src/ATen/native/Onehot.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
AT_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
|
||||
auto shape = self.sizes().vec();
|
||||
|
||||
// empty tensor could be converted to one hot representation,
|
||||
// but shape inference is not possible.
|
||||
if (self.numel() == 0) {
|
||||
if (num_classes <= 0) {
|
||||
AT_ERROR("Can not infer total number of classes from empty tensor.");
|
||||
} else {
|
||||
shape.push_back(num_classes);
|
||||
return at::empty(shape, self.options());
|
||||
}
|
||||
}
|
||||
|
||||
// non-empty tensor
|
||||
AT_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
|
||||
if (num_classes == -1) {
|
||||
num_classes = self.max().item().toLong() + 1;
|
||||
} else {
|
||||
AT_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
|
||||
}
|
||||
|
||||
shape.push_back(num_classes);
|
||||
Tensor ret = at::zeros(shape, self.options());
|
||||
ret.scatter_(-1, self.unsqueeze(-1), 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -1721,6 +1721,10 @@
|
||||
variants: method
|
||||
device_guard: false
|
||||
|
||||
- func: one_hot(IndexTensor self, int64_t num_classes=-1) -> Tensor
|
||||
python_module: nn
|
||||
variants: function
|
||||
|
||||
- func: flip(Tensor self, IntList dims) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
|
@ -794,6 +794,11 @@ Utilities
|
||||
|
||||
.. autofunction:: torch.nn.utils.remove_spectral_norm
|
||||
|
||||
:hidden:`one_hot`
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autofunction:: torch.nn.utils.one_hot
|
||||
|
||||
|
||||
.. currentmodule:: torch.nn.utils.rnn
|
||||
|
||||
|
@ -2725,6 +2725,64 @@ class TestNN(NNTestCase):
|
||||
self.assertRaisesRegex(RuntimeError, expected_err_msg,
|
||||
lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect'))
|
||||
|
||||
@staticmethod
|
||||
def _test_one_hot(self, use_cuda=False):
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device))
|
||||
expected = torch.tensor([[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0]], device=device)
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1)
|
||||
expected = torch.tensor([[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0]], device=device)
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6)
|
||||
expected = torch.tensor([[0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0],
|
||||
[0, 1, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0]], device=device)
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device))
|
||||
expected = torch.tensor([[[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1]],
|
||||
[[0, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0]]], device=device)
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.tensor(4, device=device))
|
||||
expected = torch.tensor([0, 0, 0, 0, 1], device=device)
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100)
|
||||
expected = torch.empty([4, 0, 100])
|
||||
self.assertEqual(t, expected)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device))
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2)
|
||||
|
||||
def test_one_hot(self):
|
||||
self._test_one_hot(self)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
|
||||
def test_one_hot_cuda(self):
|
||||
self._test_one_hot(self, use_cuda=True)
|
||||
|
||||
def test_pad_scalar_error(self):
|
||||
inputs = torch.tensor(0., requires_grad=True)
|
||||
self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
|
||||
|
@ -425,7 +425,7 @@ class _TestTorchMixin(object):
|
||||
def compare_reference(input, dtype):
|
||||
input = torch.tensor(input, dtype=dtype)
|
||||
res1 = torchfn(input.clone())
|
||||
res2 = input.clone().apply_(lambda x: mathfn(x))
|
||||
res2 = input.clone().apply_(mathfn)
|
||||
torch.testing.assert_allclose(res1, res2)
|
||||
|
||||
# compare against the reference math function
|
||||
@ -1034,21 +1034,21 @@ class _TestTorchMixin(object):
|
||||
def test_reduction_empty(self):
|
||||
fns_to_test = [
|
||||
# name, function, identity
|
||||
('max', lambda *args, **kwargs: torch.max(*args, **kwargs), None),
|
||||
('max', torch.max, None),
|
||||
('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None),
|
||||
('argmax', lambda *args, **kwargs: torch.argmax(*args, **kwargs), None),
|
||||
('min', lambda *args, **kwargs: torch.min(*args, **kwargs), None),
|
||||
('argmin', lambda *args, **kwargs: torch.argmin(*args, **kwargs), None),
|
||||
('mode', lambda *args, **kwargs: torch.mode(*args, **kwargs), None),
|
||||
('median', lambda *args, **kwargs: torch.median(*args, **kwargs), None),
|
||||
('argmax', torch.argmax, None),
|
||||
('min', torch.min, None),
|
||||
('argmin', torch.argmin, None),
|
||||
('mode', torch.mode, None),
|
||||
('median', torch.median, None),
|
||||
|
||||
('prod', lambda *args, **kwargs: torch.prod(*args, **kwargs), 1),
|
||||
('sum', lambda *args, **kwargs: torch.sum(*args, **kwargs), 0),
|
||||
('norm', lambda *args, **kwargs: torch.norm(*args, p=2, **kwargs), 0),
|
||||
('mean', lambda *args, **kwargs: torch.mean(*args, **kwargs), nan),
|
||||
('var', lambda *args, **kwargs: torch.var(*args, **kwargs), nan),
|
||||
('std', lambda *args, **kwargs: torch.std(*args, **kwargs), nan),
|
||||
('logsumexp', lambda *args, **kwargs: torch.logsumexp(*args, **kwargs), -inf),
|
||||
('prod', torch.prod, 1),
|
||||
('sum', torch.sum, 0),
|
||||
('norm', torch.norm, 0),
|
||||
('mean', torch.mean, nan),
|
||||
('var', torch.var, nan),
|
||||
('std', torch.std, nan),
|
||||
('logsumexp', torch.logsumexp, -inf),
|
||||
]
|
||||
|
||||
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
|
||||
@ -5268,7 +5268,7 @@ class _TestTorchMixin(object):
|
||||
for i in range(o3.size(1)):
|
||||
for j in range(k.size(1)):
|
||||
o32[i].add(torch.xcorr2(x[i + j - 1], k[j]))
|
||||
self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k), reference)
|
||||
self._test_conv_corr_eq(torch.xcorr3, reference)
|
||||
|
||||
@unittest.skip("Not implemented yet")
|
||||
def test_xcorr3_xcorr2_eq_full(self):
|
||||
@ -5284,7 +5284,7 @@ class _TestTorchMixin(object):
|
||||
for i in range(o3.size(1)):
|
||||
for j in range(k.size(1)):
|
||||
o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1]))
|
||||
self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k), reference)
|
||||
self._test_conv_corr_eq(torch.conv3, reference)
|
||||
|
||||
@unittest.skip("Not implemented yet")
|
||||
def test_fconv3_fconv2_eq(self):
|
||||
|
@ -76,15 +76,9 @@ class OneHotCategorical(Distribution):
|
||||
def sample(self, sample_shape=torch.Size()):
|
||||
sample_shape = torch.Size(sample_shape)
|
||||
probs = self._categorical.probs
|
||||
num_events = self._categorical._num_events
|
||||
indices = self._categorical.sample(sample_shape)
|
||||
if torch._C._get_tracing_state():
|
||||
# [JIT WORKAROUND] lack of support for .scatter_()
|
||||
eye = torch.eye(self.event_shape[-1], dtype=self._param.dtype, device=self._param.device)
|
||||
return eye[indices]
|
||||
one_hot = probs.new_zeros(self._extended_shape(sample_shape))
|
||||
if indices.dim() < one_hot.dim():
|
||||
indices = indices.unsqueeze(-1)
|
||||
return one_hot.scatter_(-1, indices, 1.)
|
||||
return torch.nn.functional.one_hot(indices, num_events).to(probs)
|
||||
|
||||
def log_prob(self, value):
|
||||
if self._validate_args:
|
||||
|
@ -2784,6 +2784,55 @@ Example::
|
||||
""")
|
||||
|
||||
|
||||
one_hot = _add_docstr(torch._C._nn.one_hot, r"""
|
||||
one_hot(tensor, num_classes=0) -> LongTensor
|
||||
|
||||
Takes LongTensor with index values of shape ``(*)`` and returns a tensor
|
||||
of shape ``(*, num_classes)`` that have zeros everywhere except where the
|
||||
index of last dimension matches the corresponding value of the input tensor,
|
||||
in which case it will be 1.
|
||||
|
||||
See also `One-hot on Wikipedia`_ .
|
||||
|
||||
.. _One-hot on Wikipedia:
|
||||
https://en.wikipedia.org/wiki/One-hot
|
||||
|
||||
Arguments:
|
||||
tensor (LongTensor): class values of any shape.
|
||||
num_classes (int): Total number of classes. If set to -1, the number
|
||||
of classes will be inferred as one greater than the largest class
|
||||
value in the input tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: LongTensor that has one more dimension with 1 values at the
|
||||
index of last dimension indicated by the input, and 0 everywhere
|
||||
else.
|
||||
|
||||
Examples::
|
||||
>>> torch.one_hot(torch.arange(0, 5) % 3)
|
||||
tensor([[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 0, 0],
|
||||
[0, 1, 0]])
|
||||
>>> torch.one_hot(torch.arange(0, 5) % 3, num_classes=5)
|
||||
tensor([[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0]])
|
||||
>>> torch.one_hot(torch.arange(0, 6).view(3,2) % 3)
|
||||
tensor([[[1, 0, 0],
|
||||
[0, 1, 0]],
|
||||
|
||||
[[0, 0, 1],
|
||||
[1, 0, 0]],
|
||||
|
||||
[[0, 1, 0],
|
||||
[0, 0, 1]]])
|
||||
""")
|
||||
|
||||
|
||||
@weak_script
|
||||
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None,
|
||||
reduce=None, reduction="mean"):
|
||||
|
Reference in New Issue
Block a user