[CUDA] Large tensor maxpool crash fix (#165374)
Fixes #165297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165374
Approved by: https://github.com/eqy, https://github.com/malfet
committed by: PyTorch MergeBot
parent: eaeaa08e3a
commit: d73c283c3a
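Context for the fix (not stated in the commit message itself): the channels-last (NHWC) max-pool forward path indexed the input and output with 32-bit int offsets, which overflow once a tensor holds more than 2**31 - 1 addressable elements. The shape exercised by the regression test added in this commit is already past that limit, as this quick Python check shows:

# Element count for the shape used in the new regression test.
N, C, H, W = 70, 64, 512, 960
print(N * C * H * W)  # 2202009600
print(2**31 - 1)      # 2147483647 -> signed 32-bit offsets overflow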
@@ -38,12 +38,41 @@ __device__ inline int min(int a, int b) {
 #define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
 #endif

-static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
-  return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
+template <typename index_t>
+static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
+  const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
+  return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
 }

-static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
-  return min((size + pad) / stride + 1, pooled_size);
+template <typename index_t>
+static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
+  return std::min((size + pad) / stride + 1, pooled_size);
 }

+static inline bool can_use_int32_nhwc(
+    int64_t nbatch, int64_t channels,
+    int64_t height, int64_t width,
+    int64_t pooled_height, int64_t pooled_width,
+    int64_t in_stride_n, int64_t in_stride_c,
+    int64_t in_stride_h, int64_t in_stride_w)
+{
+  constexpr int64_t int_max = std::numeric_limits<int>::max();
+
+  int64_t max_intra_batch =
+      (height ? (height - 1) * in_stride_h : 0) +
+      (width ? (width - 1) * in_stride_w : 0) +
+      (channels ? (channels - 1) * in_stride_c : 0);
+
+  int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;
+
+  if (max_input_offset > int_max) return false;
+
+  int64_t out_batch_stride = pooled_height * pooled_width * channels;
+  if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;
+
+  if (height * width > int_max) return false;
+
+  return true;
+}
+
 // kernels borrowed from Caffe
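The new can_use_int32_nhwc helper above is what the host code further down in this diff uses to pick between the int32 and int64 kernel instantiations. Below is a Python transcription of its checks, for illustration only; the dense channels_last strides and the pooled sizes for the test's MaxPool2d(kernel_size=3, stride=2, padding=1) are assumptions of this sketch, not part of the diff.

INT32_MAX = 2**31 - 1

def can_use_int32_nhwc(nbatch, channels, height, width,
                       pooled_height, pooled_width,
                       in_stride_n, in_stride_c, in_stride_h, in_stride_w):
    # Largest input offset the kernel can form (same terms as the C++ helper).
    max_intra_batch = (
        ((height - 1) * in_stride_h if height else 0)
        + ((width - 1) * in_stride_w if width else 0)
        + ((channels - 1) * in_stride_c if channels else 0)
    )
    max_input_offset = ((nbatch - 1) * in_stride_n if nbatch else 0) + max_intra_batch
    if max_input_offset > INT32_MAX:
        return False
    # Largest output offset for a dense NHWC output.
    out_batch_stride = pooled_height * pooled_width * channels
    if ((nbatch - 1) * out_batch_stride if nbatch else 0) > INT32_MAX:
        return False
    # Intermediate height*width products must also fit.
    if height * width > INT32_MAX:
        return False
    return True

def dense_nhwc_strides(C, H, W):
    # channels_last strides of a contiguous NHWC tensor: (C*H*W, 1, C*W, C)
    return C * H * W, 1, C * W, C

# A modest input stays on the int32 path; the shape from issue #165297 does not.
print(can_use_int32_nhwc(8, 64, 128, 128, 64, 64, *dense_nhwc_strides(64, 128, 128)))     # True
print(can_use_int32_nhwc(70, 64, 512, 960, 256, 480, *dense_nhwc_strides(64, 512, 960)))  # False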
@@ -85,21 +114,25 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
-__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
-                                   const int64_t channels, const int64_t height,
-                                   const int64_t width, const int pooled_height, const int pooled_width,
+__global__ void max_pool_forward_nhwc(
+    const scalar_t* bottom_data,
+    const int nbatch,
+    const index_t channels, const index_t height, const index_t width,
+    const index_t pooled_height, const index_t pooled_width,
     const int kernel_h, const int kernel_w, const int stride_h,
     const int stride_w, const int pad_h, const int pad_w,
     const int dilation_h, const int dilation_w,
-    const int in_stride_n, const int in_stride_c,
-    const int in_stride_h, const int in_stride_w,
+    const index_t in_stride_n, const index_t in_stride_c,
+    const index_t in_stride_h, const index_t in_stride_w,
     const int kernel_stride_C, const int kernel_size_C,
     scalar_t* top_data, int64_t* top_mask) {
-  extern __shared__ int smem[];
-  int *out_mask_cached = smem;
-  scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
+  extern __shared__ unsigned char smem_raw[];
+  index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
+  scalar_t *out_cached = reinterpret_cast<scalar_t*>(
+      out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);

   // flattening cta for pre-computation & smem initialization;
   int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
@@ -118,26 +151,26 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
   int channel_id = blockIdx.x / nbatch;
   int channel_offset = threadIdx.x + channel_id * blockDim.x;

-  top_data = top_data + batch_id * pooled_height * pooled_width * channels;
-  top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
-  bottom_data = bottom_data + batch_id * in_stride_n;
+  top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
+  top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
+  bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;

-  out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
-  out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
+  out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
+  out_mask_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;

-  int oH = (pooled_height + gridDim.z-1) / gridDim.z;
-  int oW = (pooled_width + gridDim.y-1) / gridDim.y;
+  int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
+  int oW = (static_cast<int>(pooled_width) + gridDim.y - 1) / gridDim.y;
   int ostartH = threadIdx.z + blockIdx.z*oH;
-  int oendH = ::min(ostartH+oH, pooled_height);
+  int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
   int ostartW = threadIdx.y + blockIdx.y*oW;
-  int oendW = ::min(ostartW+oW, pooled_width);
+  int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));

   for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
-    int hstart = oh * stride_h - pad_h;
-    int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
+    index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
+    index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
     for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
-      int wstart = ow * stride_w - pad_w;
-      int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
+      index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
+      index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
       while(hstart < 0)
         hstart += dilation_h;
       while(wstart < 0)
@@ -185,11 +218,11 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
       // Else do it Non-Prefetch...
       else
 #endif
-      for (int ih = hstart; ih < hend; ih += dilation_h) {
-        for (int iw = wstart; iw < wend; iw += dilation_w) {
+      for (index_t ih = hstart; ih < hend; ih += dilation_h) {
+        for (index_t iw = wstart; iw < wend; iw += dilation_w) {
           int cached_index = threadIdx.x;
           const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
-          for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
+          for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
             scalar_t val = ptr_input[c * in_stride_c];
             if ((val > out_cached[cached_index]) || at::_isnan(val)) {
               out_cached[cached_index] = val;
@@ -200,15 +233,15 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
         }
       }

-      scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
-      int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;
+      scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
+      int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;

       int cached_index = threadIdx.x;
-      for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
+      for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
         ptr_output_data[c] = out_cached[cached_index];
-        ptr_output_mask[c] = out_mask_cached[cached_index];
+        ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
         out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
-        out_mask_cached[cached_index] = 0;
+        out_mask_cached[cached_index] = index_t(0);
         cached_index += blockDim.x;
       }
     }
@@ -462,6 +495,11 @@ const Tensor& indices) {
           maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
       const dim3 block(block_x, block_y, block_z);

+      bool use_int32 = can_use_int32_nhwc(
+          nbatch, nInputPlane, inputHeight, inputWidth,
+          outputHeight, outputWidth,
+          in_stride_n, in_stride_c, in_stride_h, in_stride_w);
+
       int kernel_stride_C = ceil_div(
           safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
       int kernel_size_C = ceil_div(
@@ -476,18 +514,41 @@ const Tensor& indices) {
           ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
       const dim3 grid(grid_x, grid_y, grid_z);

-      size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
-      AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
+      size_t shmem_size;
+      size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;

-      max_pool_forward_nhwc<scalar_t>
+      if (use_int32) {
+        shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
+        TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
+                    "shared memory too small");
+        max_pool_forward_nhwc<scalar_t, int32_t>
           <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
-            input_data, nbatch,
-            nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+            input_data, static_cast<int>(nbatch),
+            static_cast<int32_t>(nInputPlane),
+            static_cast<int32_t>(inputHeight),
+            static_cast<int32_t>(inputWidth),
+            static_cast<int32_t>(outputHeight),
+            static_cast<int32_t>(outputWidth),
             kH, kW, dH, dW, padH, padW, dilationH, dilationW,
-            in_stride_n, in_stride_c,
-            in_stride_h, in_stride_w,
+            static_cast<int32_t>(in_stride_n),
+            static_cast<int32_t>(in_stride_c),
+            static_cast<int32_t>(in_stride_h),
+            static_cast<int32_t>(in_stride_w),
             kernel_stride_C, kernel_size_C,
             output_data, indices_data);
+      } else {
+        shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
+        TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
+                    "shared memory too small");
+        max_pool_forward_nhwc<scalar_t, int64_t>
+          <<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
+            input_data, static_cast<int>(nbatch),
+            nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
+            kH, kW, dH, dW, padH, padW, dilationH, dilationW,
+            in_stride_n, in_stride_c, in_stride_h, in_stride_w,
+            kernel_stride_C, kernel_size_C,
+            output_data, indices_data);
+      }
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
     }
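A side effect of templating the shared-memory index cache on index_t is that the two launch branches above size shared memory differently: mask_elems * (sizeof(int32_t) + sizeof(scalar_t)) versus mask_elems * (sizeof(int64_t) + sizeof(scalar_t)), each guarded by its own TORCH_CHECK. A rough illustration with made-up launch parameters (the real block_* and kernel_size_C values are derived from device properties earlier in this function, not the numbers used here):

# Hypothetical numbers, for illustration only.
block_x, block_y, block_z = 32, 8, 1   # assumed 256-thread block
kernel_size_C = 4                      # assumed channels cached per thread
sizeof_scalar = 2                      # e.g. float16

mask_elems = kernel_size_C * block_x * block_y * block_z
shmem_int32 = mask_elems * (4 + sizeof_scalar)   # int32 mask path
shmem_int64 = mask_elems * (8 + sizeof_scalar)   # int64 fallback path
print(shmem_int32, shmem_int64)  # 6144 10240 bytes -> the int64 path needs more smem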
@@ -7496,6 +7496,19 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
                                     "fractional_max_pool2d requires output_ratio to either be a single Int or tuple of Ints."):
             res = arg_class(*arg_3)

+    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
+    @largeTensorTest("20GB", device="cuda")
+    def test_large_max_pool2d_ch_last(self):
+        # https://github.com/pytorch/pytorch/issues/165297
+        N, C, H, W = 70, 64, 512, 960  # dims to extend > int32
+        device = torch.device("cuda")
+        x_cuda = torch.randn(N, C, H, W, device=device, dtype=torch.float16)
+        x_cuda = x_cuda.to(memory_format=torch.channels_last)
+        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        y_cuda_ch_last = pool(x_cuda)
+        y_cuda_contig = pool(x_cuda.contiguous())
+        self.assertEqual(y_cuda_ch_last, y_cuda_contig)
+
     def test_max_pool1d_invalid_output_size(self):
         arg_1 = 3
         arg_2 = 255
@@ -8465,6 +8478,18 @@ class TestNNDeviceType(NNTestCase):
         # workaround for memory usage overhead of assertEqual
         self.assertTrue(torch.allclose(a.grad.cpu(), a_cpu.grad.half()))

+    @onlyCUDA
+    @largeTensorTest("20GB", device="cuda")
+    def test_large_max_pool2d_ch_last(self, device):
+        # https://github.com/pytorch/pytorch/issues/165297
+        N, C, H, W = 70, 64, 512, 960  # dims to extend > int32
+        x_cuda = torch.randn(N, C, H, W, device=device, dtype=torch.float16)
+        x_cuda = x_cuda.to(memory_format=torch.channels_last)
+        pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        y_cuda_ch_last = pool(x_cuda)
+        y_cuda_contig = pool(x_cuda.contiguous())
+        self.assertEqual(y_cuda_ch_last, y_cuda_contig)
+
     @onlyCUDA
     @largeTensorTest("48GB", "cpu")
     @largeTensorTest("48GB", "cuda")