Fixed histc return type for CUDA (#20369)

Summary:
Fixing reported [issue](https://github.com/pytorch/pytorch/issues/20208).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20369

Reviewed By: zou3519

Differential Revision: D15300959

Pulled By: izdeby

fbshipit-source-id: 219692f99a66ea433112dfc226132eb6867122cf
This commit is contained in:
Iurii Zdebskyi
2019-05-20 08:01:47 -07:00
committed by Facebook Github Bot
parent d0c742134d
commit 71260b98e2
3 changed files with 55 additions and 59 deletions

View File

@ -311,7 +311,7 @@ Tensor _histc_cuda_template(
if (nbins <= 0) {
AT_ERROR("bins must be > 0");
}
Tensor output = native::zeros({nbins}, device(DeviceType::CUDA).dtype(kLong));
Tensor output = native::zeros({nbins}, device(DeviceType::CUDA).dtype(self.scalar_type()));
input_t minvalue = min;
input_t maxvalue = max;
if (min == max) {
@ -322,7 +322,8 @@ Tensor _histc_cuda_template(
minvalue = minvalue - 1;
maxvalue = maxvalue + 1;
}
auto ret = cuda::CUDA_tensor_histogram<int64_t, input_t, false>(
auto ret = cuda::CUDA_tensor_histogram<input_t, input_t, false>(
output, self, Tensor(), nbins, minvalue, maxvalue);
return output;
}

View File

@ -2727,9 +2727,6 @@ class TestCuda(TestCase):
self.assertEqual(t.cpu().bincount(), t.bincount())
self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w))
def test_histc_cuda(self):
_TestTorchMixin._test_histc(self, device='cuda')
def test_tiny_half_norm_(self):
a = torch.arange(25).cuda().float()
a /= 100000000

View File

@ -2513,58 +2513,59 @@ class _TestTorchMixin(object):
self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out))
self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out))
@staticmethod
def _test_histc(self, device):
# negative nbins throws
with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
def test_histc(self):
for device in torch.testing.get_all_device_types():
# negative nbins throws
with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'):
torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1)
# without nbins
actual = torch.histc(
torch.tensor([2, 5], dtype=torch.float, device=device))
expected = torch.zeros(100, dtype=torch.float, device=device)
expected.data[0] = 1
expected.data[99] = 1
self.assertEqual(expected, actual)
# tensor with the same element
actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
self.assertEqual(
torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
actual)
# no element falls between [min, max]
actual = torch.histc(
torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
self.assertEqual(
torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
actual)
# element falls below min + integral bin size and
actual = torch.histc(
torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
bins=5, min=1, max=5)
self.assertEqual(
torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
actual)
# non-integral bin size
actual = torch.histc(
torch.tensor([1, 2, 1], dtype=torch.float, device=device),
bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
actual)
# double input
actual = torch.histc(
torch.tensor([1, 2, 1], dtype=torch.double, device=device),
bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
actual)
# mixed input
actual = torch.histc(
torch.tensor([1., 2, 1], dtype=torch.float, device=device),
bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
actual)
# without nbins
actual = torch.histc(
torch.tensor([2, 5], dtype=torch.float, device=device))
expected = torch.zeros(100, dtype=torch.float, device=device)
expected.data[0] = 1
expected.data[99] = 1
self.assertEqual(expected, actual)
# tensor with the same element
actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5)
self.assertEqual(
torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device),
actual)
# no element falls between [min, max]
actual = torch.histc(
torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3)
self.assertEqual(
torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device),
actual)
# element falls below min + integral bin size and
actual = torch.histc(
torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device),
bins=5, min=1, max=5)
self.assertEqual(
torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device),
actual)
# non-integral bin size
actual = torch.histc(
torch.tensor([1, 2, 1], dtype=torch.float, device=device),
bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
actual)
# double input
actual = torch.histc(
torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device),
actual)
self.assertEqual(actual.dtype, torch.double)
# mixed input
actual = torch.histc(
torch.tensor([1., 2, 1], dtype=torch.float, device=device),
bins=4, min=0, max=3)
self.assertEqual(
torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device),
actual)
self.assertEqual(actual.dtype, torch.float)
# test against numpy.histogram()
def test_against_np(tensor, bins=100, min=0, max=0):
@ -2595,9 +2596,6 @@ class _TestTorchMixin(object):
expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2)
test_against_np(expanded)
def test_histc_cpu(self):
self._test_histc(self, 'cpu')
def test_ones(self):
res1 = torch.ones(100, 100)
res2 = torch.Tensor()