Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[64-bit] Int64 casting for UpSampleNearest3D (#144865)
Fixes #144855. Follows the approach in https://github.com/pytorch/pytorch/pull/141923, using int64 index types to lift the INT_MAX element limit.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144865
Approved by: https://github.com/eqy
Committed by: PyTorch MergeBot
Parent: 1c9014a135
Commit: 082fab0fc7
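The core of the change is widening the per-thread linear index from int to int64_t. A minimal standalone sketch of the pattern (a hypothetical copy kernel, not the PyTorch source): blockIdx.x, blockDim.x, and threadIdx.x are 32-bit unsigned values, so their product and sum are computed in 32-bit arithmetic and wrap once the flat index no longer fits; casting the first operand to int64_t promotes the whole expression to 64-bit.

#include <cstdint>

__global__ void copy_kernel(const float* in, float* out, int64_t numel) {
  // int idx = blockIdx.x * blockDim.x + threadIdx.x;  // 32-bit math: wraps for huge grids
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    out[idx] = in[idx];  // 64-bit index can address tensors beyond 2^31 - 1 elements
  }
}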
@@ -55,12 +55,12 @@ __global__ void upsample_nearest3d_out_frame(
     float height_scale,
     float width_scale) {
-  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
     return;

-  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
-  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
+  int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
+  int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;

   int c = (dst_idx / (dst_c_stride)) % dim_c;
@@ -72,7 +72,7 @@ __global__ void upsample_nearest3d_out_frame(
   int dst_x = dst_idx % dst_dim_w;
   int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);

-  int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
+  int64_t src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
       src_y * src_dim_w + src_x;
   for (int b = 0; b < dim_b; b++) {
     output[dst_idx] = input[src_idx];
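For context, the forward kernel recovers multi-dimensional coordinates from the flat index with div/mod, then rebuilds a flat source index from the source strides. A hedged host-side sketch of that decomposition (assumed signature and names, not the PyTorch source); only the flat indices and channel strides need 64 bits, since each individual dimension still fits in an int:

#include <cstdint>

struct Coord { int c, z, y, x; };

Coord decompose(int64_t dst_idx, int dim_c, int dst_d, int dst_h, int dst_w) {
  const int64_t c_stride = static_cast<int64_t>(dst_d) * dst_h * dst_w;  // force 64-bit product
  Coord out;
  out.c = static_cast<int>((dst_idx / c_stride) % dim_c);
  out.z = static_cast<int>((dst_idx / (static_cast<int64_t>(dst_h) * dst_w)) % dst_d);
  out.y = static_cast<int>((dst_idx / dst_w) % dst_h);
  out.x = static_cast<int>(dst_idx % dst_w);
  return out;
}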
@@ -100,12 +100,12 @@ __global__ void upsample_nearest3d_backward_out_frame(
     float height_scale,
     float width_scale) {

-  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t dst_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
     return;

-  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
-  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;
+  int64_t dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
+  int64_t src_c_stride = src_dim_d * src_dim_h * src_dim_w;

   int c = (dst_idx / (dst_c_stride)) % dim_c;
@@ -132,7 +132,7 @@ __global__ void upsample_nearest3d_backward_out_frame(
   for (int z = src_z; z < src_z_up; z++) {
     for (int y = src_y; y < src_y_up; y++) {
       for (int x = src_x; x < src_x_up; x++) {
-        int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
+        int64_t src_idx = b * dim_c * src_c_stride + c * src_c_stride +
             z * src_dim_h * src_dim_w + y * src_dim_w + x;
         grad += grad_o[src_idx];
       }
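Note that the backward kernel gathers rather than scatters: each grad_input element loops over the grad_output window that nearest-neighbor maps onto it and accumulates locally, so no atomics are needed. A minimal 1-D analogue of that accumulation (an illustrative helper, not the PyTorch source), with 64-bit bounds so the window offsets stay safe on huge tensors:

#include <cstdint>

float gather_grad_1d(const float* grad_out, int64_t lo, int64_t hi) {
  float grad = 0.0f;
  for (int64_t x = lo; x < hi; ++x) {  // [lo, hi) is the window mapping to one input element
    grad += grad_out[x];
  }
  return grad;
}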
@@ -180,9 +180,9 @@ static void upsample_nearest3d_out_cuda_template(
   dim3 bdim{std::min<unsigned int>(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
   dim3 gdim{ceil_div(n, bdim.x)};
-  // safe check for int32 indexing; implicitly restrict launch config for kernel
-  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());
+  // safe check for int64 indexing; implicitly restrict launch config for kernel
+  TORCH_CHECK(output.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", output.sizes());

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte,input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
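The launch config rounds the element count up to whole blocks via ceil_div. For reference, the helper has the standard shape below (a sketch matching at::ceil_div's behavior, shown standalone rather than as the ATen header):

#include <cassert>

template <typename T>
constexpr T ceil_div(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;  // round up without floating point
}

int main() {
  assert(ceil_div(10u, 4u) == 3u);  // 10 elements at 4 threads/block -> 3 blocks
  return 0;
}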
@@ -254,11 +254,11 @@ static void upsample_nearest3d_backward_out_cuda_template(
   dim3 bdim{std::min<unsigned int>(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
   dim3 gdim{ceil_div(n, bdim.x)};
-  // safe check for int32 indexing; implicitly restrict launch config for kernel
-  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
-  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
-    "upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());
+  // safe check for int64 indexing; implicitly restrict launch config for kernel
+  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", grad_input.sizes());
+  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int64_t>::max(),
+    "upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", grad_output.sizes());

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {
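A side note on the relaxed checks: at::Tensor::numel() returns an int64_t, so comparing it against std::numeric_limits<int64_t>::max() can never fail. The new TORCH_CHECKs document the 64-bit indexing assumption rather than enforce the old 32-bit cap, as the trivial snippet below illustrates:

#include <cstdint>
#include <limits>

// Any int64_t value v trivially satisfies v <= max(); the check cannot fire.
bool always_true(int64_t v) {
  return v <= std::numeric_limits<int64_t>::max();
}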
@@ -174,6 +174,18 @@ class TestTorchDeviceType(TestCase):
         scalar = bytes_to_scalar(bytes_list, dtype, device)
         self.assertEqual(scalar.storage().untyped().tolist(), bytes_list)

+    # For testing int64 support in upsample_nearest3d
+    @onlyCUDA
+    @largeTensorTest('56GB', device='cuda')
+    @dtypes(torch.bfloat16)
+    @unittest.skipIf(IS_JETSON, "Large tensor tests are too large for Jetson.")
+    def test_int64_upsample3d(self, device, dtype):
+        x = torch.ones((1, 256, 16, 720, 1280), dtype=dtype, device=device)
+        try:
+            torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+        except Exception as e:
+            self.fail(f"Unexpected exception raised: {e}")
+
     @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64,
             torch.bool, torch.float32, torch.complex64, torch.float64,
             torch.complex128, torch.uint16, torch.uint32, torch.uint64)
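Why this test exercises the 64-bit path: with scale_factor=2 in all three spatial dimensions, the output shape is (1, 256, 32, 1440, 2560). A quick check of the element counts (plain arithmetic, not part of the test itself):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t int32_max = 2147483647;                    // 2^31 - 1
  const int64_t in_numel  = 1LL * 256 * 16 * 720 * 1280;   //  3,774,873,600
  const int64_t out_numel = in_numel * 2 * 2 * 2;          // 30,198,988,800
  assert(in_numel > int32_max);   // even the input overflows 32-bit indexing
  assert(out_numel > int32_max);  // the upsampled output is 8x larger still
  return 0;
}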