mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup. The changes are from Yu-Yun <YuYun.Chang@amd.com>. # Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` opertion on MI300X/MI325X - The original "scatter-add" approach - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs. - The new "gather-sum" approach - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function). # Breakdown of the code changes - Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame` - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`). - Each thread processed one output pixel. - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`). - Each thread is responsible for calculating the final gradient for a single input pixel. - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`. - Added a device function for calculating the range of output pixels that could have possibly used that the input pixel (`input_pos`) during the forward pass interpolation - This is essentially the mathematical inverse of the forward pass. - This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor. - Gradient calculation approach switching from "scatter-add" to "gather-sum" - Scatter-add - For each output pixel, the thread calculated 4 gradient contributions and use `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor. - Gather-sum - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value. - The thread iterates through this region, and for each output pixel in the regionre, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel. - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`). - W/o any gloabl memory access, this accumulation is extremely fast. - When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`). # Why performance gets boosted - Analysis of the root cause of performance drop - Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward - First and foremost, elimination of the contention of atomic operations - Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time. - The GPU's memory controler has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points. - MI300X/MI325X chiplet-based CDNA 3 architeture amplified the issue. - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect. - The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`. - Improved memory access pattern and locality - Write coalescing - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs. - Read locality - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread. - Trade-off: computation for memory synchronization - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X. - Removal of atomic operations avoids expensive memory synchronization. --- Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR. Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572 Approved by: https://github.com/jeffdaily
This commit is contained in:
committed by
PyTorch MergeBot
parent
b42fe389b9
commit
53f9ae0e50
@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// Helper function to compute output pixel range that can contribute to input pixel
|
||||
template <typename accscalar_t>
|
||||
__device__ __forceinline__ void compute_output_range(
|
||||
int input_pos,
|
||||
accscalar_t scale,
|
||||
int output_size,
|
||||
bool align_corners,
|
||||
int& min_output,
|
||||
int& max_output) {
|
||||
accscalar_t lo, hi;
|
||||
if (align_corners) {
|
||||
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
|
||||
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
|
||||
} else {
|
||||
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
|
||||
}
|
||||
min_output = max(0, static_cast<int>(ceil(lo)));
|
||||
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
|
||||
}
|
||||
#endif
|
||||
|
||||
// Backward (adjoint) operation 1 <- 2 (accumulates)
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
C10_LAUNCH_BOUNDS_1(1024)
|
||||
@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
const bool align_corners,
|
||||
scalar_t* __restrict__ idata,
|
||||
const scalar_t* __restrict__ odata) {
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
|
||||
const size_t i_numel = nc * width1 * height1;
|
||||
#ifdef USE_ROCM
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
// Decode input pixel coordinates
|
||||
size_t index_temp = index;
|
||||
const int w1 = index_temp % width1;
|
||||
index_temp /= width1;
|
||||
const int h1 = index_temp % height1;
|
||||
const size_t nc_idx = index_temp / height1;
|
||||
|
||||
accscalar_t grad_sum = 0;
|
||||
|
||||
// Find range of output pixels that could interpolate from this input pixel
|
||||
int h2_min, h2_max, w2_min, w2_max;
|
||||
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
|
||||
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
|
||||
|
||||
// Iterate over potential output pixels
|
||||
for (int h2 = h2_min; h2 <= h2_max; h2++) {
|
||||
for (int w2 = w2_min; w2 <= w2_max; w2++) {
|
||||
// Compute source coordinates for this output pixel
|
||||
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rheight, h2, align_corners, /*cubic=*/false);
|
||||
const int h1_base = (int)h1r;
|
||||
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
|
||||
const accscalar_t h1lambda = h1r - h1_base;
|
||||
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
|
||||
|
||||
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
|
||||
rwidth, w2, align_corners, /*cubic=*/false);
|
||||
const int w1_base = (int)w1r;
|
||||
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
|
||||
const accscalar_t w1lambda = w1r - w1_base;
|
||||
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
|
||||
|
||||
// Check if our input pixel participates in this interpolation and accumulate all weights
|
||||
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
|
||||
// to the same pixel, so we need to accumulate weights from all matching positions
|
||||
accscalar_t weight = 0;
|
||||
|
||||
// Check all four interpolation positions and accumulate weights
|
||||
if (h1 == h1_base && w1 == w1_base) {
|
||||
weight += h0lambda * w0lambda; // top-left
|
||||
}
|
||||
if (h1 == h1_base && w1 == w1_base + w1p) {
|
||||
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base) {
|
||||
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
|
||||
}
|
||||
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
|
||||
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
|
||||
}
|
||||
|
||||
if (weight > 0) {
|
||||
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
|
||||
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write accumulated gradient (no atomics needed)
|
||||
idata[index] = static_cast<scalar_t>(grad_sum);
|
||||
}
|
||||
#else
|
||||
const size_t o_numel = nc * width2 * height2;
|
||||
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
|
||||
index += blockDim.x * gridDim.x) {
|
||||
size_t index_temp = index;
|
||||
@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
|
||||
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
|
||||
true);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
// threads are not covering the whole input tensor.
|
||||
grad_input.zero_();
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
const int num_threads = std::min(
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
constexpr bool use_input = true;
|
||||
#else
|
||||
constexpr bool use_input = false;
|
||||
#endif
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
|
||||
@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * output_height * output_width;
|
||||
|
||||
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
|
||||
input_height,
|
||||
@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
|
||||
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
|
||||
input_width, output_width, align_corners, scales_w);
|
||||
|
||||
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
|
||||
|
||||
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
|
||||
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
|
||||
num_threads,
|
||||
|
Reference in New Issue
Block a user