mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixed histc return type for CUDA (#20369)
Summary: Fixing reported [issue](https://github.com/pytorch/pytorch/issues/20208). Pull Request resolved: https://github.com/pytorch/pytorch/pull/20369 Reviewed By: zou3519 Differential Revision: D15300959 Pulled By: izdeby fbshipit-source-id: 219692f99a66ea433112dfc226132eb6867122cf
This commit is contained in:
committed by
Facebook Github Bot
parent
d0c742134d
commit
71260b98e2
@ -311,7 +311,7 @@ Tensor _histc_cuda_template(
|
||||
if (nbins <= 0) {
|
||||
AT_ERROR("bins must be > 0");
|
||||
}
|
||||
Tensor output = native::zeros({nbins}, device(DeviceType::CUDA).dtype(kLong));
|
||||
Tensor output = native::zeros({nbins}, device(DeviceType::CUDA).dtype(self.scalar_type()));
|
||||
input_t minvalue = min;
|
||||
input_t maxvalue = max;
|
||||
if (min == max) {
|
||||
@ -322,7 +322,8 @@ Tensor _histc_cuda_template(
|
||||
minvalue = minvalue - 1;
|
||||
maxvalue = maxvalue + 1;
|
||||
}
|
||||
auto ret = cuda::CUDA_tensor_histogram<int64_t, input_t, false>(
|
||||
|
||||
auto ret = cuda::CUDA_tensor_histogram<input_t, input_t, false>(
|
||||
output, self, Tensor(), nbins, minvalue, maxvalue);
|
||||
return output;
|
||||
}
|
||||
|
@ -2727,9 +2727,6 @@ class TestCuda(TestCase):
|
||||
self.assertEqual(t.cpu().bincount(), t.bincount())
|
||||
self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
|
||||
|
||||
def test_histc_cuda(self):
|
||||
_TestTorchMixin._test_histc(self, device='cuda')
|
||||
|
||||
def test_tiny_half_norm_(self):
|
||||
a = torch.arange(25).cuda().float()
|
||||
a /= 100000000
|
||||
|
@ -2513,58 +2513,59 @@ class _TestTorchMixin(object):
|
||||
self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out))
|
||||
self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out))
|
||||
|
||||
@staticmethod
|
||||
def _test_histc(self, device):
|
||||
# negative nbins throws
|
||||
with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
|
||||
torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
|
||||
def test_histc(self):
|
||||
for device in torch.testing.get_all_device_types():
|
||||
# negative nbins throws
|
||||
with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
|
||||
torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
|
||||
|
||||
# without nbins
|
||||
actual = torch.histc(
|
||||
torch.tensor([2, 5], dtype=torch.float, device=device))
|
||||
expected = torch.zeros(100, dtype=torch.float, device=device)
|
||||
expected.data[0] = 1
|
||||
expected.data[99] = 1
|
||||
self.assertEqual(expected, actual)
|
||||
# tensor with the same element
|
||||
actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# no element falls between [min, max]
|
||||
actual = torch.histc(
|
||||
torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# element falls below min + integral bin size and
|
||||
actual = torch.histc(
|
||||
torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
|
||||
bins=5, min=1, max=5)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# non-integral bin size
|
||||
actual = torch.histc(
|
||||
torch.tensor([1, 2, 1], dtype=torch.float, device=device),
|
||||
bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# double input
|
||||
actual = torch.histc(
|
||||
torch.tensor([1, 2, 1], dtype=torch.double, device=device),
|
||||
bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
|
||||
actual)
|
||||
# mixed input
|
||||
actual = torch.histc(
|
||||
torch.tensor([1., 2, 1], dtype=torch.float, device=device),
|
||||
bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# without nbins
|
||||
actual = torch.histc(
|
||||
torch.tensor([2, 5], dtype=torch.float, device=device))
|
||||
expected = torch.zeros(100, dtype=torch.float, device=device)
|
||||
expected.data[0] = 1
|
||||
expected.data[99] = 1
|
||||
self.assertEqual(expected, actual)
|
||||
# tensor with the same element
|
||||
actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# no element falls between [min, max]
|
||||
actual = torch.histc(
|
||||
torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# element falls below min + integral bin size and
|
||||
actual = torch.histc(
|
||||
torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
|
||||
bins=5, min=1, max=5)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# non-integral bin size
|
||||
actual = torch.histc(
|
||||
torch.tensor([1, 2, 1], dtype=torch.float, device=device),
|
||||
bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
# double input
|
||||
actual = torch.histc(
|
||||
torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
|
||||
actual)
|
||||
self.assertEqual(actual.dtype, torch.double)
|
||||
# mixed input
|
||||
actual = torch.histc(
|
||||
torch.tensor([1., 2, 1], dtype=torch.float, device=device),
|
||||
bins=4, min=0, max=3)
|
||||
self.assertEqual(
|
||||
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
|
||||
actual)
|
||||
self.assertEqual(actual.dtype, torch.float)
|
||||
|
||||
# test against numpy.histogram()
|
||||
def test_against_np(tensor, bins=100, min=0, max=0):
|
||||
@ -2595,9 +2596,6 @@ class _TestTorchMixin(object):
|
||||
expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
|
||||
test_against_np(expanded)
|
||||
|
||||
def test_histc_cpu(self):
|
||||
self._test_histc(self, 'cpu')
|
||||
|
||||
def test_ones(self):
|
||||
res1 = torch.ones(100, 100)
|
||||
res2 = torch.Tensor()
|
||||
|
Reference in New Issue
Block a user