Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup. The changes are from Yu-Yun <YuYun.Chang@amd.com>.

# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` operation on MI300X/MI325X

- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (computed by the `compute_output_range` device function).

# Breakdown of the code changes

- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function that calculates the range of output pixels that could have used the input pixel (`input_pos`) during the forward-pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function prunes a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation switched from "scatter-add" to "gather-sum" (see the sketch after this section)
  - Scatter-add
    - For each output pixel, the thread calculated four gradient contributions and called `fastAtomicAdd` four times to add these values to four different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence that input pixel's final gradient value.
    - The thread iterates through this region and, for each output pixel in the region, re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
    - Without any global memory access, this accumulation is extremely fast.
    - When the loops are done, the thread performs a single, direct (non-atomic) write of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
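To make the gather-sum structure concrete, below is a minimal, self-contained CUDA sketch of the idea, not the PR's actual code. It handles a single H×W plane in plain `float`, assumes the align_corners=false mapping `src = (dst + 0.5) * scale - 0.5`, and omits the boundary clamping of the upper interpolation tap; the helper signature and the simplified indexing are illustrative assumptions rather than the kernel's real interface.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Inverse of the forward mapping: a (conservative) range of output indices
// whose bilinear footprint can include input position `input_pos`.
// `ratio` is input_size / output_size, as used by the forward pass.
__device__ inline void compute_output_range(
    int input_pos, float ratio, int output_size, int& out_begin, int& out_end) {
  // The forward pass maps output index o to src = (o + 0.5) * ratio - 0.5 and
  // reads input positions floor(src) and floor(src) + 1, so output o can touch
  // input_pos only if src lies in (input_pos - 1, input_pos + 1).
  const float lo = (input_pos - 0.5f) / ratio - 0.5f;
  const float hi = (input_pos + 1.5f) / ratio - 0.5f;
  out_begin = max(0, static_cast<int>(ceilf(lo)));
  out_end   = min(output_size - 1, static_cast<int>(floorf(hi)));
}

// One thread per input pixel: gather every contribution from the output
// gradient, accumulate in a register, then do a single non-atomic store.
// Launch with one thread per input element, e.g. <<<ceil_div(H1*W1, 256), 256>>>.
__global__ void upsample_bilinear2d_backward_gather(
    const float* __restrict__ odata,  // grad_output, height2 x width2
    float* __restrict__ idata,        // grad_input,  height1 x width1
    int height1, int width1, int height2, int width2,
    float rheight, float rwidth) {    // ratios: input_size / output_size
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= height1 * width1) return;

  const int h1 = index / width1;
  const int w1 = index % width1;

  int h2_begin, h2_end, w2_begin, w2_end;
  compute_output_range(h1, rheight, height2, h2_begin, h2_end);
  compute_output_range(w1, rwidth,  width2,  w2_begin, w2_end);

  float grad_sum = 0.0f;  // per-thread register accumulator

  for (int h2 = h2_begin; h2 <= h2_end; ++h2) {
    // Re-derive the forward interpolation for this output row.
    const float h1r = fmaxf(0.0f, (h2 + 0.5f) * rheight - 0.5f);
    const int   h1lo = static_cast<int>(h1r);
    const float h1lambda = h1r - h1lo;
    // Weight this output row places on input row h1 (0 if outside its footprint).
    float hweight = 0.0f;
    if (h1lo == h1)          hweight = 1.0f - h1lambda;
    else if (h1lo + 1 == h1) hweight = h1lambda;
    if (hweight == 0.0f) continue;

    for (int w2 = w2_begin; w2 <= w2_end; ++w2) {
      const float w1r = fmaxf(0.0f, (w2 + 0.5f) * rwidth - 0.5f);
      const int   w1lo = static_cast<int>(w1r);
      const float w1lambda = w1r - w1lo;
      float wweight = 0.0f;
      if (w1lo == w1)          wweight = 1.0f - w1lambda;
      else if (w1lo + 1 == w1) wweight = w1lambda;
      if (wweight == 0.0f) continue;

      grad_sum += hweight * wweight * odata[h2 * width2 + w2];
    }
  }

  idata[index] = grad_sum;  // single coalesced write, no atomicAdd
}
```

The real kernel additionally iterates over the `nc` planes and is templated on `scalar_t`/`accscalar_t`, but the control flow described above is the same: compute a small output window, re-derive the weights, accumulate in a register, and store once.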
# Why performance gets boosted

- Analysis of the root cause of the performance drop
  - Ref. (internal only): https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
  - Many parallel threads called `atomicAdd` frequently, attempting to update the exact same memory location in the input gradient tensor at the same time.
  - The GPU's memory controller has to serialize these operations, effectively nullifying the benefit of parallel execution at those contention points.
  - The chiplet-based CDNA 3 architecture of MI300X/MI325X amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates the hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd` operations.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes (`idata[index] = static_cast<scalar_t>(grad_sum);`) can be perfectly coalesced by the GPU.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly: the data one thread needs is likely already in the L1 or L2 cache due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
  - Removing the atomic operations avoids expensive memory synchronization.

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR. Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily