[CUDA][cuDNN] Fix handling of CPU-side input and target length tensors in CTCLoss (#152745)
https://github.com/pytorch/pytorch/pull/128271 migrated CTCLoss to the cuDNN V8 API, which expects the input and target length tensors to be on `CUDA` rather than `CPU`, but did not add logic to handle the edge case where those tensors are on `CPU`. See also #152421.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152745
Approved by: https://github.com/Skylion007
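The edge case in user-code terms, as a minimal sketch reconstructed from the new regression test added below (shapes, dtypes, and the `F.ctc_loss` call mirror that test; everything else is illustrative and not taken from the PR's source changes):

    import torch
    import torch.nn.functional as F

    # Dimensions mirror the new test: input length, batch size, classes, max target length.
    T, N, C, S = 100, 50, 80, 10

    # Log-probabilities on the GPU...
    log_probs = torch.randn(T, N, C, device="cuda").log_softmax(2)
    # ...but the 1-D int32 length and target tensors deliberately left on the CPU,
    # the combination the cuDNN V8 CTCLoss path previously did not account for.
    input_lengths = torch.full((N,), T, dtype=torch.int32)
    target_lengths = torch.randint(1, S, (N,), dtype=torch.int32)
    targets = torch.randint(0, C, (int(target_lengths.sum()),), dtype=torch.int32)

    # With the fix, this call succeeds even though the length tensors stay on the CPU.
    loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="sum")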
Committed by: PyTorch MergeBot
Parent: 773a91c775
Commit: cecfc7dc53
@@ -11523,7 +11523,7 @@ class TestNNDeviceType(NNTestCase):
     @onlyCUDA
     @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
-    def test_ctc_loss_cudnn_tensor(self, device):
+    def test_ctc_loss_cudnn_tensor_cuda(self):
         batch_size = 16
         input_length = 30
         num_labels = 101
@@ -11549,6 +11549,36 @@ class TestNNDeviceType(NNTestCase):
         grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
         self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
 
+    @onlyCUDA
+    @skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
+    def test_ctc_loss_cudnn_tensor_cpu_length_cuda(self):
+        # batch size
+        N = 50
+        # audio length
+        T = 100
+        # text dimension
+        C = 80
+        # max text length
+        S = 10
+
+        prob_device = torch.device("cuda")
+        other_device = torch.device("cpu")
+        other_dtype = torch.int32
+
+        log_probs = torch.randn(T, N, C).log_softmax(2).to(prob_device)
+
+        input_lengths = torch.full((N,), T, dtype=other_dtype).to(other_device)
+        target_lengths = torch.randint(low=1, high=S, size=(N,), dtype=other_dtype).to(other_device)
+        targets = torch.randint(low=0, high=C, size=(sum(target_lengths),), dtype=other_dtype).to(other_device)
+
+        ctc_loss = torch.nn.functional.ctc_loss(
+            log_probs=log_probs,
+            targets=targets,
+            input_lengths=input_lengths,
+            target_lengths=target_lengths,
+            reduction="sum",
+        )
+
     @expectedFailureMPS
     def test_ctc_loss_error(self, device):
         log_probs = torch.rand(0, 0, 4, device=device)
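A possible follow-up, not part of this PR: the new test above only verifies that the call completes with CPU length tensors. The hedged sketch below shows how it could additionally cross-check the cuDNN result against the native CPU implementation, reusing the names defined in test_ctc_loss_cudnn_tensor_cpu_length_cuda; the tolerances are illustrative guesses, not values from the PR.

    # Hypothetical extension of the test body above; `log_probs`, `targets`,
    # `input_lengths`, `target_lengths`, and `ctc_loss` are the names it defines.
    ctc_loss_cpu = torch.nn.functional.ctc_loss(
        log_probs=log_probs.cpu(),
        targets=targets,
        input_lengths=input_lengths,
        target_lengths=target_lengths,
        reduction="sum",
    )
    # Compare the cuDNN result against the native CPU reference (tolerances are a guess).
    self.assertEqual(ctc_loss.cpu(), ctc_loss_cpu, atol=1e-4, rtol=1e-4)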