Use int64_t for nll_loss with cuda inputs (#85395)
Related: #85005

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85395
Approved by: https://github.com/t-vi, https://github.com/lezcano
committed by: PyTorch MergeBot
parent: 5f26df0345
commit: ef0baba23f
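For context on why 64-bit indexing matters here: the tensors exercised by the new tests below hold more elements than a signed 32-bit index can address, so the CUDA nll_loss kernels overflow with int32 offsets. A minimal sketch of the arithmetic (the shape comes from the tests; the variable names are illustrative):

    # Sketch: element count of the test tensors vs. the 32-bit index limit.
    rows = 2 ** 16           # 65536
    cols = 2 ** 16 + 1       # 65537
    numel = rows * cols      # 4,295,032,832 elements
    int32_max = 2 ** 31 - 1  # 2,147,483,647

    assert numel > int32_max  # a 32-bit signed index would overflow here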
@@ -17884,6 +17884,39 @@ class TestNNDeviceType(NNTestCase):
        with self.assertRaisesRegex(RuntimeError, msg):
            F.nll_loss(x, t, weight=weight)

    # Ref: https://github.com/pytorch/pytorch/issue/85005
    @onlyCUDA
    @largeTensorTest("45GB", "cpu")
    @largeTensorTest("45GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_nll_loss_large_tensor(self, device, reduction):
        shape = [int(2 ** 16), int(2 ** 16) + 1]

        input = torch.randn(shape, device=device, dtype=torch.float32, requires_grad=True)
        labels = torch.randint(shape[0], (shape[0],), dtype=torch.long, device=device)

        out = F.nll_loss(input, labels, reduction=reduction)

        with torch.no_grad():
            input_cpu = input.cpu().float().requires_grad_()
            labels_cpu = labels.cpu()
        out_cpu = F.nll_loss(input_cpu, labels_cpu, reduction=reduction)
        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        if reduction == "sum":
            orig_rtol, orig_atol = rtol, atol
            rtol, atol = 7 * rtol, 3 * atol
        with torch.no_grad():
            self.assertTrue(torch.allclose(out.cpu(), out_cpu, rtol=rtol, atol=atol))
        if reduction == "sum":
            rtol, atol = orig_rtol, orig_atol

        if reduction != "none":
            out.backward()
            out_cpu.backward()
            with torch.no_grad():
                self.assertTrue(torch.allclose(input.grad.cpu(), input_cpu.grad, rtol=rtol, atol=atol))

    def _nll_loss_helper(self, input_size, reduction, expected, device):
        input = torch.rand(input_size, requires_grad=True, device=device)
        num_channels = input_size[1]
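As the inline comment notes, the test compares with torch.allclose instead of self.assertEqual, which materializes extra intermediates on tensors this large (see #84944). A sketch of that pattern as a standalone helper, assuming the private torch.testing._comparison.get_tolerances API used in the test above (the helper name is hypothetical):

    import torch

    def allclose_with_default_tolerances(actual, expected):
        # Hypothetical helper mirroring the test's pattern: look up the
        # default rtol/atol for the dtype, then compare with torch.allclose,
        # which is lighter on memory than self.assertEqual (#84944).
        rtol, atol = torch.testing._comparison.get_tolerances(
            expected.dtype, rtol=None, atol=None
        )
        return torch.allclose(actual, expected, rtol=rtol, atol=atol)

The looser 7 * rtol / 3 * atol bound for the "sum" case presumably absorbs the rounding error accumulated when summing on the order of 2 ** 16 per-sample losses.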
@@ -18190,6 +18223,30 @@ class TestNNDeviceType(NNTestCase):
        # i.e. we don't count the ignored_idx at all.
        check_equal(loss, (inp1, targ_positive_ignore_index), (inp2[1:], targ_positive_ignore_index[1:]))

    # Ref: https://github.com/pytorch/pytorch/issue/85005
    @onlyCUDA
    @largeTensorTest("45GB", "cpu")
    @largeTensorTest("45GB", "cuda")
    @parametrize_test("reduction", ("none", "mean", "sum"))
    def test_cross_entropy_large_tensor(self, device, reduction):
        logits = torch.randn(int(2 ** 16), int(2 ** 16) + 1, dtype=torch.float32, device='cuda', requires_grad=True)
        labels = torch.zeros(logits.size(0), dtype=torch.long, device='cuda')
        loss = F.cross_entropy(logits, labels, reduction=reduction)
        if reduction != "none":
            loss.backward()

        with torch.no_grad():
            logits_cpu = logits.cpu().detach().requires_grad_()
            labels_cpu = labels.cpu().detach()
        loss_cpu = F.cross_entropy(logits_cpu, labels_cpu, reduction=reduction)
        if reduction != "none":
            loss_cpu.backward()

        # workaround to reduce memory usage vs. self.assertEqual, see #84944
        rtol, atol = torch.testing._comparison.get_tolerances(torch.float32, rtol=None, atol=None)
        self.assertTrue(torch.allclose(loss.cpu(), loss_cpu, rtol=rtol, atol=atol))
        if reduction != "none":
            self.assertTrue(torch.allclose(logits.grad.cpu(), logits_cpu.grad, rtol=rtol, atol=atol))

    def test_softshrink_negative(self, device):
        input = torch.randn(5, device=device, requires_grad=True)
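Both new tests are parametrized over the three reduction modes of the loss functions. For readers unfamiliar with them, a small-scale sketch of what each mode returns (the tensor sizes here are illustrative, not the ones from the tests):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)          # 4 samples, 3 classes
    labels = torch.tensor([0, 2, 1, 0])

    per_sample = F.cross_entropy(logits, labels, reduction="none")  # shape (4,)
    mean_loss = F.cross_entropy(logits, labels, reduction="mean")   # 0-dim scalar
    sum_loss = F.cross_entropy(logits, labels, reduction="sum")     # 0-dim scalar

    assert torch.allclose(mean_loss, per_sample.mean())
    assert torch.allclose(sum_loss, per_sample.sum())

Note that only "mean" and "sum" produce a scalar that can be backpropagated directly, which is why the tests guard the backward passes with `if reduction != "none":`.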