mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add the TORCH_CHECK(grad_output) of upsample_nearest::backward() Fixes #117011 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117100 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
f89725fb41
commit
9e3580f793
@ -175,6 +175,7 @@ static void upsample_nearest1d_backward_out_cuda_template(
|
||||
dim3 gdim{ceil_div(n, bdim.x)};
|
||||
// safe check for int32 indexing; implicitly restrict launch config for kernel
|
||||
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
|
||||
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {
|
||||
|
@ -410,6 +410,7 @@ static void upsample_nearest2d_backward_out_cuda_template(
|
||||
dim3 gdim{ceil_div(n, bdim.x)};
|
||||
// safe check for int32 indexing; implicitly restrict launch config for kernel
|
||||
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
|
||||
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest2d_backward_out_frame", [&] {
|
||||
|
@ -255,6 +255,7 @@ static void upsample_nearest3d_backward_out_cuda_template(
|
||||
dim3 gdim{ceil_div(n, bdim.x)};
|
||||
// safe check for int32 indexing; implicitly restrict launch config for kernel
|
||||
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
|
||||
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {
|
||||
|
Reference in New Issue
Block a user