Add check for ctc_loss targets param (#150981)

Fixes #150835

## Test Result

```python
# cuda
>>> import torch
>>> import torch.nn.functional as F
>>> device = "cuda" # "cpu" is fine
>>> num_classes = 4
>>> log_probs = torch.rand(0, 0, num_classes, device=device)
>>> targets = torch.tensor([], device=device, dtype=torch.long)
>>> input_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> target_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss
    return torch.ctc_loss(
           ^^^^^^^^^^^^^^^
RuntimeError: log_probs tensor must not be empty

# cpu
>>> device = "cpu"
>>> num_classes = 4
>>> log_probs = torch.rand(0, 0, num_classes, device=device)
>>> targets = torch.tensor([], device=device, dtype=torch.long)
>>> input_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> target_lengths = torch.tensor([], device=device, dtype=torch.long)
>>> result = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3079, in ctc_loss
    return torch.ctc_loss(
           ^^^^^^^^^^^^^^^
RuntimeError: log_probs tensor must not be empty

```
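For contrast, here is a minimal sketch of a non-empty call that does not trip the new check. The shapes and values below are illustrative only (they are not from the original report); they follow the documented `(T, N, C)` layout for `log_probs`:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: T=50 time steps, N=16 batch elements, C=20 classes (blank = 0).
log_probs = torch.randn(50, 16, 20).log_softmax(2)
targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
input_lengths = torch.full((16,), 50, dtype=torch.long)
target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)

# With a non-empty log_probs tensor the call succeeds and,
# with reduction='none', returns one loss per batch element.
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
print(loss.shape)  # torch.Size([16])
```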
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150981
Approved by: https://github.com/eqy
Author: zeshengzong
Date: 2025-04-14 07:24:30 +00:00
Committed by: PyTorch MergeBot
Commit: 01f226bfb8 (parent: bbc5fe8504)
3 changed files with 11 additions and 0 deletions


```diff
@@ -126,6 +126,7 @@ std::tuple<Tensor, Tensor, size_t, std::vector<int64_t>> ctc_loss_allocate_outpu
 // the alphas from the user by only returning the loss.
 template<typename scalar_t, ScalarType target_scalar_type>
 std::tuple<Tensor, Tensor> ctc_loss_cpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+  TORCH_CHECK(log_probs.numel() > 0, "log_probs tensor must not be empty");
   // log_probs: input_len x batch_size x num_labels
   // targets [int64]: batch_size x target_length OR sum(target_lengths)
   constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
```


```diff
@@ -219,6 +219,7 @@ ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
 // backward. The dispatch function will only return the loss.
 template<typename scalar_t, ScalarType target_scalar_type>
 std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const Tensor& targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t BLANK) {
+  TORCH_CHECK(log_probs.numel() > 0, "log_probs tensor must not be empty");
   // log_probs: input_len x batch_size x num_labels
   // targets [int64]: batch_size x target_length OR sum(target_lengths)
   CheckedFrom c = "ctc_loss_gpu";
```


```diff
@@ -11532,6 +11532,15 @@ class TestNNDeviceType(NNTestCase):
         grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
 
+    @expectedFailureMPS
+    def test_ctc_loss_error(self, device):
+        log_probs = torch.rand(0, 0, 4, device=device)
+        targets = torch.tensor([], device=device, dtype=torch.long)
+        input_lengths = torch.tensor([], device=device, dtype=torch.long)
+        target_lengths = torch.tensor([], device=device, dtype=torch.long)
+        with self.assertRaisesRegex(RuntimeError, "log_probs tensor must not be empty"):
+            F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none')
+
     @expectedFailureMPS  # RuntimeError: LSTM with projections is not currently supported with MPS.
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @dtypes(torch.float)
```