[64-bit] Int64 casting for UpSampleNearest3D (#144865)

Fixes #144855

Follows approach in https://github.com/pytorch/pytorch/pull/141923 to use int64 types to increase INT_MAX limits
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144865
Approved by: https://github.com/eqy
This commit is contained in:
Jack Taylor
2025-01-29 19:30:07 +00:00
committed by PyTorch MergeBot
parent 1c9014a135
commit 082fab0fc7
2 changed files with 28 additions and 16 deletions

View File

@ -55,12 +55,12 @@ __global__ void upsample_nearest3d_out_frame(
float height_scale,
float width_scale) {
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
return;
int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
int c = (dst_idx / (dst_c_stride)) % dim_c;
@ -72,7 +72,7 @@ __global__ void upsample_nearest3d_out_frame(
int dst_x = dst_idx % dst_dim_w;
int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);
int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
int64_t src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
src_y * src_dim_w + src_x;
for (int b = 0; b < dim_b; b++) {
output[dst_idx] = input[src_idx];
@ -100,12 +100,12 @@ __global__ void upsample_nearest3d_backward_out_frame(
float height_scale,
float width_scale) {
int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
return;
int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;
int c = (dst_idx / (dst_c_stride)) % dim_c;
@ -132,7 +132,7 @@ __global__ void upsample_nearest3d_backward_out_frame(
for (int z = src_z; z < src_z_up; z++) {
for (int y = src_y; y < src_y_up; y++) {
for (int x = src_x; x < src_x_up; x++) {
int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
int64_t src_idx = b * dim_c * src_c_stride + c * src_c_stride +
z * src_dim_h * src_dim_w + y * src_dim_w + x;
grad += grad_o[src_idx];
}
@ -180,9 +180,9 @@ static void upsample_nearest3d_out_cuda_template(
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{ceil_div(n, bdim.x)};
// safe check for int32 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
"upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());
// safe check for int64 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(output.numel() <= std::numeric_limits<int64_t>::max(),
"upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", output.sizes());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte,input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
@ -254,11 +254,11 @@ static void upsample_nearest3d_backward_out_cuda_template(
dim3 bdim{std::min<unsigned int>(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
dim3 gdim{ceil_div(n, bdim.x)};
// safe check for int32 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
"upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
"upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());
// safe check for int64 indexing; implicitly restrict launch config for kernel
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(),
"upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", grad_input.sizes());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(),
"upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", grad_output.sizes());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {

View File

@ -174,6 +174,18 @@ class TestTorchDeviceType(TestCase):
scalar = bytes_to_scalar(bytes_list, dtype, device)
self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)
# For testing in64 support in upsample_nearest3d
@onlyCUDA
@largeTensorTest('56GB', device='cuda')
@dtypes(torch.bfloat16)
@unittest.skipIf(IS_JETSON, "Large tensor tests are too large for Jetson.")
def test_int64_upsample3d(self, device, dtype):
x = torch.ones((1, 256, 16, 720, 1280), dtype=dtype, device=device)
try:
torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
except Exception as e:
self.fail(f"Unexpected exception raised: {e}")
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
torch.bool, torch.float32, torch.complex64, torch.float64,
torch.complex128, torch.uint16, torch.uint32, torch.uint64)