Fix #117011: add a TORCH_CHECK on grad_output.numel() to the upsample_nearest backward CUDA templates (#117100)

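Add a TORCH_CHECK that grad_output.numel() does not exceed std::numeric_limits<int32_t>::max() in the upsample_nearest1d, upsample_nearest2d, and upsample_nearest3d backward CUDA templates, next to the existing grad_input.numel() check that guards int32 indexing.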

Fixes #117011

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117100
Approved by: https://github.com/lezcano
This commit is contained in:
wangkang1
2024-01-11 18:06:17 +00:00
committed by PyTorch MergeBot
parent f89725fb41
commit 9e3580f793
3 changed files with 3 additions and 0 deletions

View File

@ -175,6 +175,7 @@ static void upsample_nearest1d_backward_out_cuda_template(
dim3 gdim{ceil_div(n, bdim.x)};
// safe check for int32 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {

View File

@ -410,6 +410,7 @@ static void upsample_nearest2d_backward_out_cuda_template(
dim3 gdim{ceil_div(n, bdim.x)};
// safe check for int32 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_out_frame", [&] {

View File

@ -255,6 +255,7 @@ static void upsample_nearest3d_backward_out_cuda_template(
dim3 gdim{ceil_div(n, bdim.x)};
// safe check for int32 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {