Use int64_t for nll_loss with cuda inputs (#85395)

Related #85005
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85395
Approved by: https://github.com/t-vi, https://github.com/lezcano
This commit is contained in:
Masaki Kozuki
2022-09-29 17:02:04 +00:00
committed by PyTorch MergeBot
parent 5f26df0345
commit ef0baba23f
3 changed files with 94 additions and 30 deletions

View File

@ -17884,6 +17884,39 @@ class TestNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, msg):
F.nll_loss(x, t, weight=weight)
# Ref: https://github.com/pytorch/pytorch/issue/85005
@onlyCUDA
@largeTensorTest("45GB", "cpu")
@largeTensorTest("45GB", "cuda")
@parametrize_test("reduction", ("none", "mean", "sum"))
def test_nll_loss_large_tensor(self, device, reduction):
shape = [int(2 ** 16), int(2 ** 16) + 1]
input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)
out = F.nll_loss(input, labels, reduction=reduction)
with torch.no_grad():
input_cpu = input.cpu().float().requires_grad_()
labels_cpu = labels.cpu()
out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
# workaround to reduce memory usage vs. self.assertEqual, see #84944
rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
if reduction == "sum":
orig_rtol, orig_atol = rtol, atol
rtol, atol = 7 * rtol, 3 * atol
with torch.no_grad():
self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
if reduction == "sum":
rtol, atol = orig_rtol, orig_atol
if reduction != "none":
out.backward()
out_cpu.backward()
with torch.no_grad():
self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))
def _nll_loss_helper(self, input_size, reduction, expected, device):
input = torch.rand(input_size, requires_grad=True, device=device)
num_channels = input_size[1]
@ -18190,6 +18223,30 @@ class TestNNDeviceType(NNTestCase):
# i.e. we don't count the ignored_idx at all.
check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))
# Ref: https://github.com/pytorch/pytorch/issue/85005
@onlyCUDA
@largeTensorTest("45GB", "cpu")
@largeTensorTest("45GB", "cuda")
@parametrize_test("reduction", ("none", "mean", "sum"))
def test_cross_entropy_large_tensor(self, device, reduction):
logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
loss = F.cross_entropy(logits, labels, reduction=reduction)
if reduction != "none":
loss.backward()
with torch.no_grad():
logits_cpu = logits.cpu().detach().requires_grad_()
labels_cpu = labels.cpu().detach()
loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
if reduction != "none":
loss_cpu.backward()
# workaround to reduce memory usage vs. self.assertEqual, see #84944
rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
if reduction != "none":
self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))
def test_softshrink_negative(self, device):
input = torch.randn(5, device=device, requires_grad=True)