[CUDA][cuDNN] Fix handling of CPU side input and target length tensors in CTCLoss (#152745)

https://github.com/pytorch/pytorch/pull/128271 migrated to cuDNN V8 CTCLoss which expects input and target length tensors to be on `CUDA` rather than `CPU` without adding the logic to account for the edge case of them being on `CPU`

see also #152421

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152745
Approved by: https://github.com/Skylion007
This commit is contained in:
Eddie Yan
2025-05-07 22:01:18 +00:00
committed by PyTorch MergeBot
parent 773a91c775
commit cecfc7dc53
2 changed files with 51 additions and 5 deletions

View File

@ -151,6 +151,13 @@ bool _use_cudnn_ctc_loss_tensor(
}
}
} else {
if (target_lengths.device().type() != at::kCUDA ||
input_lengths.device().type() != at::kCUDA) {
TORCH_CHECK(
false,
"CTCLoss cannot be graph captured with CPU length tensors. "
"Move CPU length tensors to GPU memory to enable graph capture.")
}
at::_assert_async(at::lt(input_lengths.max(), 256));
at::_assert_async(at::le(target_lengths, input_lengths).all());
}
@ -253,9 +260,18 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
bool deterministic,
bool zero_infinity) {
Tensor targets_t_ = targets_t;
Tensor input_lengths_ = input_lengths;
Tensor target_lengths_ = target_lengths;
if (targets_t.device().type() == at::kCPU) {
targets_t_ = targets_t.to(Device(at::kCUDA));
}
if (input_lengths.device().type() == at::kCPU) {
input_lengths_ = input_lengths.to(Device(at::kCUDA));
}
if (input_lengths.device().type() == at::kCPU) {
target_lengths_ = target_lengths.to(Device(at::kCUDA));
}
const CheckedFrom c = "cudnn_ctc_loss";
const TensorArg log_probs{log_probs_t, "log_probs", 1};
const TensorArg targets{targets_t_, "targets", 2};
@ -268,9 +284,9 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
checkBackend(c, {*targets}, Backend::CUDA);
const auto batch_size = log_probs->size(1);
int64_t input_lengths_size =
input_lengths.sizes().size() ? input_lengths.size(0) : 1;
input_lengths_.sizes().size() ? input_lengths_.size(0) : 1;
int64_t target_lengths_size =
target_lengths.sizes().size() ? target_lengths.size(0) : 1;
target_lengths_.sizes().size() ? target_lengths_.size(0) : 1;
TORCH_CHECK(
input_lengths_size == batch_size,
"input_lengths needs to have size to match batch_size");
@ -319,8 +335,8 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
log_probs_desc.desc(),
log_probs_t.data_ptr(),
targets_t_.data_ptr<int>(),
target_lengths.data_ptr<int>(),
input_lengths.data_ptr<int>(),
target_lengths_.data_ptr<int>(),
input_lengths_.data_ptr<int>(),
costs.data_ptr(),
grad_desc.desc(),
grad.data_ptr(),

View File

@ -11523,7 +11523,7 @@ class TestNNDeviceType(NNTestCase):
@onlyCUDA
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
def test_ctc_loss_cudnn_tensor(self, device):
def test_ctc_loss_cudnn_tensor_cuda(self):
batch_size = 16
input_length = 30
num_labels = 101
@ -11549,6 +11549,36 @@ class TestNNDeviceType(NNTestCase):
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
@onlyCUDA
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
def test_ctc_loss_cudnn_tensor_cpu_length_cuda(self):
# batch size
N = 50
# audio length
T = 100
# text dimension
C = 80
# max text length
S = 10
prob_device = torch.device("cuda")
other_device = torch.device("cpu")
other_dtype = torch.int32
log_probs = torch.randn(T, N, C).log_softmax(2).to(prob_device)
input_lengths = torch.full((N,), T, dtype=other_dtype).to(other_device)
target_lengths = torch.randint(low=1, high=S, size=(N,), dtype=other_dtype).to(other_device)
targets = torch.randint(low=0, high=C, size=(sum(target_lengths),), dtype=other_dtype).to(other_device)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=log_probs,
targets=targets,
input_lengths=input_lengths,
target_lengths=target_lengths,
reduction="sum",
)
@expectedFailureMPS
def test_ctc_loss_error(self, device):
log_probs = torch.rand(0, 0, 4, device=device)