mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add check for ctc_loss targets param (#150981)
Fixes #150835 ## Test Result ```python # cuda >>> import torch >>> import torch.nn.functional as F >>> device = "cuda" # "cpu" is fine >>> num_classes = 4 >>> log_probs = torch.rand(0, 0, num_classes, device=device) >>> targets = torch.tensor([], device=device, dtype=torch.long) >>> input_lengths = torch.tensor([], device=device, dtype=torch.long) >>> target_lengths = torch.tensor([], device=device, dtype=torch.long) >>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none') Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss return torch.ctc_loss( ^^^^^^^^^^^^^^^ RuntimeError: log_probs tensor must not be empty # cpu >>> device = "cpu" >>> num_classes = 4 >>> log_probs = torch.rand(0, 0, num_classes, device=device) >>> targets = torch.tensor([], device=device, dtype=torch.long) >>> input_lengths = torch.tensor([], device=device, dtype=torch.long) >>> target_lengths = torch.tensor([], device=device, dtype=torch.long) >>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none') Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss return torch.ctc_loss( ^^^^^^^^^^^^^^^ RuntimeError: log_probs tensor must not be empty ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150981 Approved by: https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
bbc5fe8504
commit
01f226bfb8
@ -11532,6 +11532,15 @@ class TestNNDeviceType(NNTestCase):
|
||||
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
|
||||
self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
|
||||
|
||||
@expectedFailureMPS
|
||||
def test_ctc_loss_error(self, device):
|
||||
log_probs = torch.rand(0, 0, 4, device=device)
|
||||
targets = torch.tensor([], device=device, dtype=torch.long)
|
||||
input_lengths = torch.tensor([], device=device, dtype=torch.long)
|
||||
target_lengths = torch.tensor([], device=device, dtype=torch.long)
|
||||
with self.assertRaisesRegex(RuntimeError, "log_probs tensor must not be empty"):
|
||||
F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
|
||||
|
||||
@expectedFailureMPS # RuntimeError: LSTM with projections is not currently supported with MPS.
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.double)
|
||||
@dtypes(torch.float)
|
||||
|
Reference in New Issue
Block a user