mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
[CUDA][cuDNN] Fix handling of CPU side input and target length tensors in CTCLoss (#152745)
https://github.com/pytorch/pytorch/pull/128271 migrated CTCLoss to the cuDNN V8 API, which expects the input and target length tensors to be on `CUDA` rather than `CPU`, but did not add the logic to account for the edge case of them being on `CPU`. See also #152421. Pull Request resolved: https://github.com/pytorch/pytorch/pull/152745 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
773a91c775
commit
cecfc7dc53
@ -151,6 +151,13 @@ bool _use_cudnn_ctc_loss_tensor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (target_lengths.device().type() != at::kCUDA ||
|
||||||
|
input_lengths.device().type() != at::kCUDA) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
false,
|
||||||
|
"CTCLoss cannot be graph captured with CPU length tensors. "
|
||||||
|
"Move CPU length tensors to GPU memory to enable graph capture.")
|
||||||
|
}
|
||||||
at::_assert_async(at::lt(input_lengths.max(), 256));
|
at::_assert_async(at::lt(input_lengths.max(), 256));
|
||||||
at::_assert_async(at::le(target_lengths, input_lengths).all());
|
at::_assert_async(at::le(target_lengths, input_lengths).all());
|
||||||
}
|
}
|
||||||
@ -253,9 +260,18 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
|
|||||||
bool deterministic,
|
bool deterministic,
|
||||||
bool zero_infinity) {
|
bool zero_infinity) {
|
||||||
Tensor targets_t_ = targets_t;
|
Tensor targets_t_ = targets_t;
|
||||||
|
Tensor input_lengths_ = input_lengths;
|
||||||
|
Tensor target_lengths_ = target_lengths;
|
||||||
if (targets_t.device().type() == at::kCPU) {
|
if (targets_t.device().type() == at::kCPU) {
|
||||||
targets_t_ = targets_t.to(Device(at::kCUDA));
|
targets_t_ = targets_t.to(Device(at::kCUDA));
|
||||||
}
|
}
|
||||||
|
if (input_lengths.device().type() == at::kCPU) {
|
||||||
|
input_lengths_ = input_lengths.to(Device(at::kCUDA));
|
||||||
|
}
|
||||||
|
if (input_lengths.device().type() == at::kCPU) {
|
||||||
|
target_lengths_ = target_lengths.to(Device(at::kCUDA));
|
||||||
|
}
|
||||||
|
|
||||||
const CheckedFrom c = "cudnn_ctc_loss";
|
const CheckedFrom c = "cudnn_ctc_loss";
|
||||||
const TensorArg log_probs{log_probs_t, "log_probs", 1};
|
const TensorArg log_probs{log_probs_t, "log_probs", 1};
|
||||||
const TensorArg targets{targets_t_, "targets", 2};
|
const TensorArg targets{targets_t_, "targets", 2};
|
||||||
@ -268,9 +284,9 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
|
|||||||
checkBackend(c, {*targets}, Backend::CUDA);
|
checkBackend(c, {*targets}, Backend::CUDA);
|
||||||
const auto batch_size = log_probs->size(1);
|
const auto batch_size = log_probs->size(1);
|
||||||
int64_t input_lengths_size =
|
int64_t input_lengths_size =
|
||||||
input_lengths.sizes().size() ? input_lengths.size(0) : 1;
|
input_lengths_.sizes().size() ? input_lengths_.size(0) : 1;
|
||||||
int64_t target_lengths_size =
|
int64_t target_lengths_size =
|
||||||
target_lengths.sizes().size() ? target_lengths.size(0) : 1;
|
target_lengths_.sizes().size() ? target_lengths_.size(0) : 1;
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
input_lengths_size == batch_size,
|
input_lengths_size == batch_size,
|
||||||
"input_lengths needs to have size to match batch_size");
|
"input_lengths needs to have size to match batch_size");
|
||||||
@ -319,8 +335,8 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
|
|||||||
log_probs_desc.desc(),
|
log_probs_desc.desc(),
|
||||||
log_probs_t.data_ptr(),
|
log_probs_t.data_ptr(),
|
||||||
targets_t_.data_ptr<int>(),
|
targets_t_.data_ptr<int>(),
|
||||||
target_lengths.data_ptr<int>(),
|
target_lengths_.data_ptr<int>(),
|
||||||
input_lengths.data_ptr<int>(),
|
input_lengths_.data_ptr<int>(),
|
||||||
costs.data_ptr(),
|
costs.data_ptr(),
|
||||||
grad_desc.desc(),
|
grad_desc.desc(),
|
||||||
grad.data_ptr(),
|
grad.data_ptr(),
|
||||||
|
|||||||
@ -11523,7 +11523,7 @@ class TestNNDeviceType(NNTestCase):
|
|||||||
|
|
||||||
@onlyCUDA
|
@onlyCUDA
|
||||||
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
|
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
|
||||||
def test_ctc_loss_cudnn_tensor(self, device):
|
def test_ctc_loss_cudnn_tensor_cuda(self):
|
||||||
batch_size = 16
|
batch_size = 16
|
||||||
input_length = 30
|
input_length = 30
|
||||||
num_labels = 101
|
num_labels = 101
|
||||||
@ -11549,6 +11549,36 @@ class TestNNDeviceType(NNTestCase):
|
|||||||
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
|
grad_cudnn, = torch.autograd.grad(loss_cudnn, log_probs, grad_out)
|
||||||
self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
|
self.assertEqual(grad_cudnn, grad_native, atol=1e-4, rtol=0)
|
||||||
|
|
||||||
|
@onlyCUDA
|
||||||
|
@skipCUDAIfRocm(msg="skipped Cudnn test on ROCm")
|
||||||
|
def test_ctc_loss_cudnn_tensor_cpu_length_cuda(self):
|
||||||
|
# batch size
|
||||||
|
N = 50
|
||||||
|
# audio length
|
||||||
|
T = 100
|
||||||
|
# text dimension
|
||||||
|
C = 80
|
||||||
|
# max text length
|
||||||
|
S = 10
|
||||||
|
|
||||||
|
prob_device = torch.device("cuda")
|
||||||
|
other_device = torch.device("cpu")
|
||||||
|
other_dtype = torch.int32
|
||||||
|
|
||||||
|
log_probs = torch.randn(T, N, C).log_softmax(2).to(prob_device)
|
||||||
|
|
||||||
|
input_lengths = torch.full((N,), T, dtype=other_dtype).to(other_device)
|
||||||
|
target_lengths = torch.randint(low=1, high=S, size=(N,), dtype=other_dtype).to(other_device)
|
||||||
|
targets = torch.randint(low=0, high=C, size=(sum(target_lengths),), dtype=other_dtype).to(other_device)
|
||||||
|
|
||||||
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
|
log_probs=log_probs,
|
||||||
|
targets=targets,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
target_lengths=target_lengths,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
|
||||||
@expectedFailureMPS
|
@expectedFailureMPS
|
||||||
def test_ctc_loss_error(self, device):
|
def test_ctc_loss_error(self, device):
|
||||||
log_probs = torch.rand(0, 0, 4, device=device)
|
log_probs = torch.rand(0, 0, 4, device=device)
|
||||||
|
|||||||
Reference in New Issue
Block a user