Compare commits

...

12 Commits


@ -10,8 +10,10 @@
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <ATen/ceil_div.h>
#include <c10/macros/Macros.h>
#include <cmath>
#include <type_traits>
namespace at::native {
@ -513,6 +515,330 @@ namespace {
}
}
#ifdef USE_ROCM
// Note [ROCm-specific GridSampler backward optimization]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The original backward kernel suffers from severe global atomic contention on (chiplet-based) AMD GPUs.
// This optimized implementation for ROCm uses four main strategies:
// 1. Shared memory (LDS) privatization: For `grad_input`, each thread block uses a 2D
// tile in shared memory as a private accumulator. Atomics are performed on this fast
// on-chip memory. After the block finishes a channel chunk, results are written back to global memory,
// drastically reducing expensive global atomic operations.
// 2. Channel Chunking: To handle inputs with many channels, we process channels in fixed-size
// chunks to keep LDS usage bounded and predictable.
// 3. LDS Bank Conflict Avoidance: Padding is added to the inner dimension of the LDS
// tile to prevent bank conflicts between threads in a warp.
// 4. Compile-time specialization: The kernel is templated on interpolation mode,
// padding mode, align_corners, and whether the input requires gradients,
// which were all function arguments originally. The host-side dispatcher
// launches a specialized kernel version, allowing the compiler to eliminate
// all mode-related branching inside the kernel.
//
// This optimization is currently enabled for the common Bilinear mode.
// Other modes fall back to the original kernel to ensure correctness.
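//
// As an illustrative sketch only (names are placeholders; the real kernel below also
// handles halos, channel chunks, and bounds checks), the privatization pattern is:
//   __shared__ float tile[SMEM_SIZE];                       // zeroed cooperatively by the block
//   atomicAdd(&tile[local_idx], contribution);              // cheap on-chip (LDS) atomic
//   __syncthreads();
//   atomicAdd(&grad_input[global_idx], tile[local_idx]);    // few global atomics at the end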
// Compile-time specialization to avoid runtime thread divergence.
// A separate kernel is compiled for each combination of template parameters, which is
// expected to be a performance win.
// In particular, templating on INPUT_REQ_GRAD lets the compiler completely remove all
// `grad_input`-related code when the gradient for `input` is not needed.
template <typename scalar_t, typename index_t, GridSamplerInterpolation INTERP_MODE,
GridSamplerPadding PADDING_MODE, bool ALIGN_CORNERS, bool INPUT_REQ_GRAD>
C10_LAUNCH_BOUNDS_1(256)
__global__ void grid_sampler_2d_backward_kernel_optimized(
TensorInfo<const scalar_t, index_t> grad_output,
TensorInfo<const scalar_t, index_t> input,
TensorInfo<const scalar_t, index_t> grid,
TensorInfo<scalar_t, index_t> grad_input,
TensorInfo<scalar_t, index_t> grad_grid,
const index_t grad_input_memory_span) {
// opmath_t is typically `float` even if `scalar_t` is `half` or `bfloat16`, so that
// all intermediate calculations (like accumulations in LDS) are done in full precision.
// It is `double` when `scalar_t` is `double`.
using opmath_t = at::opmath_type<scalar_t>;
// 2D tile size.
// Each thread block will process such a tile of the output grid.
constexpr int TILE_W = 16;
constexpr int TILE_H = 16;
// 1-pixel halo for Bilinear interpolation.
// Bilinear interpolation for an output pixel at (y, x) writes gradients to the four input
// pixels (y_in, x_in), (y_in, x_in+1), (y_in+1, x_in), and (y_in+1, x_in+1), so the halo
// lets the LDS tile accommodate contributions that land just outside the tile.
// Bicubic interpolation samples a 4x4 neighborhood and therefore needs a larger halo.
constexpr int HALO = []() {
if constexpr (INTERP_MODE == GridSamplerInterpolation::Bicubic) {
return 2;
} else { // Bilinear or Nearest
return 1;
}
}();
// Full dimensions of the LDS tile, including the halo.
constexpr int SMEM_H = TILE_H + 2 * HALO;
constexpr int SMEM_W = TILE_W + 2 * HALO;
// Pad the inner dimension by one element to avoid LDS bank conflicts.
constexpr int SMEM_W_PAD = SMEM_W + 1;
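// Rationale (assuming the usual AMD LDS layout of 32 banks, 4 bytes each): the padding
// makes the row stride odd (SMEM_W_PAD = 19 floats), which is coprime with the bank
// count, so same-column accesses from different rows spread across banks instead of
// repeating, unlike the even stride SMEM_W = 18.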
// Process channels in chunks of 4 to avoid using too much LDS:
// - CCHUNK=4: ~5.5 KB LDS per block (good for high occupancy)
// - CCHUNK=8: ~11 KB LDS (better for channel-heavy workloads)
// - CCHUNK=2: ~2.75 KB LDS (for very high occupancy)
// TODO: make the chunk size adaptive to the channel count, e.g.:
//   const index_t C = input.sizes[1];
//   const int CCHUNK = C < 64 ? 8 : 4;
// (the host-side shared memory size calculation in the dispatcher would need the same logic).
constexpr int CCHUNK = 4;
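// With the defaults above (TILE 16x16, HALO = 1, CCHUNK = 4) and a 4-byte opmath_t, the LDS
// footprint is CCHUNK * SMEM_H * SMEM_W_PAD * 4 B = 4 * 18 * 19 * 4 B = 5472 B (~5.5 KB).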
// Dynamic shared memory layout: [CCHUNK][SMEM_H][SMEM_W_PAD]
// The buffer is declared as `unsigned char` rather than the template type `opmath_t` so
// that different template instantiations do not redeclare the extern shared symbol with
// conflicting types.
extern __shared__ unsigned char smem_raw[];
opmath_t* smem = reinterpret_cast<opmath_t*>(smem_raw);
// Maps a 3D [c_chunk_idx][y][x] index to the 1D smem buffer;
// x is the fastest-varying coordinate and indexes into the padded dimension.
auto smem_idx = [&](int c_chunk_idx, int y, int x) -> opmath_t& {
return smem[(c_chunk_idx * SMEM_H + y) * SMEM_W_PAD + x];
};
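// For example, with SMEM_H = 18 and SMEM_W_PAD = 19, smem_idx(1, 2, 3) resolves to
// smem[(1 * 18 + 2) * 19 + 3] = smem[383].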
// `blockIdx.z` maps to the batch index of the 3D grid of blocks.
const index_t n = blockIdx.z;
const index_t inp_H = input.sizes[2];
const index_t inp_W = input.sizes[3];
const index_t out_H = grid.sizes[1];
const index_t out_W = grid.sizes[2];
// The top-left coord of the output tile the current block is working on.
const index_t tile_h_out = blockIdx.y * TILE_H;
const index_t tile_w_out = blockIdx.x * TILE_W;
// Channel-chunking loop.
// c0 is the starting channel index of the current chunk.
for (index_t c0 = 0; c0 < input.sizes[1]; c0 += CCHUNK) {
// The number of channels to process in the current iteration.
// The last, partial chunk may be smaller than CCHUNK.
const int cmax = min((int)CCHUNK, (int)(input.sizes[1] - c0));
// Cooperatively zero out shared memory for the channel chunk.
for (int c_idx = 0; c_idx < cmax; ++c_idx) {
for (int i = threadIdx.y; i < SMEM_H; i += blockDim.y) {
for (int j = threadIdx.x; j < SMEM_W_PAD; j += blockDim.x) {
smem_idx(c_idx, i, j) = 0;
}
}
}
__syncthreads();
// Each thread processes one output pixel.
// The unique global coord of the pixel:
const index_t h_out = tile_h_out + threadIdx.y; // Range: [tile_h_out, tile_h_out + TILE_H)
const index_t w_out = tile_w_out + threadIdx.x; // Range: [tile_w_out, tile_w_out + TILE_W)
// Ensure the thread is within the bounds of the output grid.
if (h_out < out_H && w_out < out_W) {
const auto grid_offset = n * grid.strides[0] + h_out * grid.strides[1] + w_out * grid.strides[2];
// The sampling coord (x, y) from the `grid` tensor.
const opmath_t x = grid.data[grid_offset];
const opmath_t y = grid.data[grid_offset + grid.strides[3]];
// The gradient multipliers needed for the `grad_grid` calculation.
opmath_t gix_mult, giy_mult;
// The corresponding input coords.
opmath_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, PADDING_MODE, ALIGN_CORNERS, &gix_mult);
opmath_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, PADDING_MODE, ALIGN_CORNERS, &giy_mult);
// Thread-local register accumulators for `grad_grid`; each thread owns its own output
// pixel, so these are not contended and need no atomics.
opmath_t gix_agg = 0, giy_agg = 0;
// NW anchor and fractional parts
const index_t ix_nw = static_cast<index_t>(std::floor(ix));
const index_t iy_nw = static_cast<index_t>(std::floor(iy));
const opmath_t tx = ix - ix_nw;
const opmath_t ty = iy - iy_nw;
// Pure geometric bilinear interpolation weights for the NW, NE, SW, SE corners.
const opmath_t nw = (1 - tx) * (1 - ty);
const opmath_t ne = tx * (1 - ty);
const opmath_t sw = (1 - tx) * ty;
const opmath_t se = tx * ty;
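// For reference, the forward bilinear sample at (ix, iy) is
//   out = nw * v_nw + ne * v_ne + sw * v_sw + se * v_se,
// which is what the gradient expressions further below differentiate.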
// The inner loop over the channels within the current chunk.
for (int c_idx = 0; c_idx < cmax; ++c_idx) {
const index_t c = c0 + c_idx;
const opmath_t gOut = static_cast<opmath_t>(grad_output.data[n * grad_output.strides[0] + c * grad_output.strides[1] + h_out * grad_output.strides[2] + w_out * grad_output.strides[3]]);
const scalar_t* inp_ptr_NC = input.data + n * input.strides[0] + c * input.strides[1];
// Compile-time constant condition: when INPUT_REQ_GRAD is false, the compiler
// removes this entire `if` block.
if (INPUT_REQ_GRAD) {
// Typed LDS atomic add (supports `float` and `double` `opmath_t`).
// `opmath_t` already promotes half-precision inputs to full precision, so no further
// casting is needed here; for `double`, casting down to `float` would lose precision.
auto atomic_add_smem = [] __device__ (opmath_t* addr, opmath_t v) {
if constexpr (std::is_same<opmath_t, float>::value) {
atomicAdd(reinterpret_cast<float*>(addr), static_cast<float>(v));
} else if constexpr (std::is_same<opmath_t, double>::value) {
atomicAdd(reinterpret_cast<double*>(addr), static_cast<double>(v));
} else {
static_assert(!std::is_same<opmath_t, opmath_t>::value, "Unsupported opmath_t for LDS atomicAdd");
}
};
auto accumulate = [&](index_t iy_g, index_t ix_g, opmath_t weight) {
// Global bounds check, same as the original implementation,
// to check if the target input coord is valid.
if (within_bounds_2d(iy_g, ix_g, inp_H, inp_W)) {
// Map global input coord to local LDS coord.
const index_t s_iy = iy_g - tile_h_out + HALO;
const index_t s_ix = ix_g - tile_w_out + HALO;
// Check if the local coord falls within the LDS tile with halo.
if (s_iy >= 0 && s_iy < SMEM_H && s_ix >= 0 && s_ix < SMEM_W) {
// Core optimization: `atomicAdd` to LDS instead of global memory.
// Grid sampling is overwhelmingly used with `float`, `half`, or BFloat16 tensors;
// `double` is rarely needed for this operation, but the typed helper above handles it.
//atomicAdd(reinterpret_cast<float*>(&smem_idx(c_idx, s_iy, s_ix)), static_cast<float>(weight * gOut));
atomic_add_smem(&smem_idx(c_idx, s_iy, s_ix), weight * gOut);
} else {
// FIXME: contributions whose target input pixel falls outside the LDS tile (plus halo),
// i.e. when the grid maps far away from the block's output tile, are currently
// dropped rather than added to global memory.
}
}
};
accumulate(iy_nw, ix_nw, nw);
accumulate(iy_nw, ix_nw + 1, ne);
accumulate(iy_nw + 1, ix_nw, sw);
accumulate(iy_nw + 1, ix_nw + 1, se);
}
// Accumulate grad_grid contributions in registers.
opmath_t v_nw = 0, v_ne = 0, v_sw = 0, v_se = 0;
// Perform manual bounds checking, just like the original kernel.
// Note: the host-side dispatcher routes only Bilinear interpolation with Zeros padding
// to this kernel; other modes fall back to the original implementation.
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
v_nw = static_cast<opmath_t>(inp_ptr_NC[iy_nw * input.strides[2] + ix_nw * input.strides[3]]);
}
if (within_bounds_2d(iy_nw, ix_nw + 1, inp_H, inp_W)) {
v_ne = static_cast<opmath_t>(inp_ptr_NC[iy_nw * input.strides[2] + (ix_nw + 1) * input.strides[3]]);
}
if (within_bounds_2d(iy_nw + 1, ix_nw, inp_H, inp_W)) {
v_sw = static_cast<opmath_t>(inp_ptr_NC[(iy_nw + 1) * input.strides[2] + ix_nw * input.strides[3]]);
}
if (within_bounds_2d(iy_nw + 1, ix_nw + 1, inp_H, inp_W)) {
v_se = static_cast<opmath_t>(inp_ptr_NC[(iy_nw + 1) * input.strides[2] + (ix_nw + 1) * input.strides[3]]);
}
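// Differentiating out = nw*v_nw + ne*v_ne + sw*v_sw + se*v_se with respect to the
// source coords (with tx = ix - ix_nw, ty = iy - iy_nw) gives
//   d(out)/d(ix) = (v_ne - v_nw) * (1 - ty) + (v_se - v_sw) * ty
//   d(out)/d(iy) = (v_sw - v_nw) * (1 - tx) + (v_se - v_ne) * tx
// each scaled below by gOut (chain rule through grad_output).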
gix_agg += (v_ne - v_nw) * (1 - ty) * gOut + (v_se - v_sw) * ty * gOut;
giy_agg += (v_sw - v_nw) * (1 - tx) * gOut + (v_se - v_ne) * tx * gOut;
}
scalar_t* gGrid_ptr_NHW = grad_grid.data + n * grad_grid.strides[0] + h_out * grad_grid.strides[1] + w_out * grad_grid.strides[2];
gGrid_ptr_NHW[0] = static_cast<scalar_t>(gix_mult * gix_agg);
gGrid_ptr_NHW[1] = static_cast<scalar_t>(giy_mult * giy_agg);
}
__syncthreads();
if (INPUT_REQ_GRAD) {
for (int c_idx = 0; c_idx < cmax; ++c_idx) {
const index_t c = c0 + c_idx;
const index_t NC_offset = n * grad_input.strides[0] + c * grad_input.strides[1];
// Cooperatively write back from shared to global memory.
for (int i = threadIdx.y; i < SMEM_H; i += blockDim.y) {
for (int j = threadIdx.x; j < SMEM_W; j += blockDim.x) {
opmath_t val = smem_idx(c_idx, i, j);
if (val != 0) {
// Back to the global input coord.
index_t h_in = tile_h_out - HALO + i;
index_t w_in = tile_w_out - HALO + j;
// Atomics to global memory: per channel, their number is reduced from up to
// 4*TILE_W*TILE_H (1024) to at most SMEM_H*SMEM_W (324).
safe_add_2d(grad_input.data, h_in, w_in, grad_input.strides[2], grad_input.strides[3],
inp_H, inp_W, static_cast<scalar_t>(val), NC_offset, grad_input_memory_span);
}
}
}
}
}
__syncthreads();
}
}
template <typename scalar_t, typename index_t>
void launch_grid_sampler_2d_backward_dispatcher(
const TensorBase &grad_input, const TensorBase &grad_grid,
const TensorBase &grad_output, const TensorBase &input,
const TensorBase &grid, GridSamplerInterpolation interpolation_mode,
GridSamplerPadding padding_mode, bool align_corners,
bool input_requires_grad) {
auto stream = at::cuda::getCurrentCUDAStream();
const auto grad_input_memory_span = input_requires_grad ? static_cast<index_t>(grad_input.numel()) : 0;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
constexpr int TILE_W = 16;
constexpr int TILE_H = 16;
dim3 threads(TILE_W, TILE_H);
dim3 blocks(
ceil_div(grid.size(2), (long)TILE_W),
ceil_div(grid.size(1), (long)TILE_H),
grid.size(0)); // Launch per-N
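// For example (illustrative numbers), a 1024x1024 output grid launches a
// (64, 64, N) grid of blocks, each with 16x16 = 256 threads.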
// Calculate the dynamic shared memory size.
// TODO: if CCHUNK becomes adaptive to the channel count, mirror the kernel's selection
// here, e.g.:
//   const auto C = input.size(1);
//   const int CCHUNK = C < 64 ? 8 : 4;
constexpr int CCHUNK = 4;
constexpr int HALO = 1;
constexpr int SMEM_H = TILE_H + 2 * HALO;
constexpr int SMEM_W = TILE_W + 2 * HALO;
constexpr int SMEM_W_PAD = SMEM_W + 1;
// The size calculation uses opmath_type<scalar_t> of the current dispatcher
// instantiation, matching the kernel's LDS accumulator type.
const size_t smem_bytes = CCHUNK * SMEM_H * SMEM_W_PAD * sizeof(at::opmath_type<scalar_t>);
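// Note: these constants must stay in sync with the layout constants inside
// grid_sampler_2d_backward_kernel_optimized, since the kernel indexes the dynamic
// buffer with the same CCHUNK / SMEM_H / SMEM_W_PAD values.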
// Dispatcher for template parameters
if (padding_mode == GridSamplerPadding::Zeros) {
if (align_corners) {
if (input_requires_grad) {
grid_sampler_2d_backward_kernel_optimized<scalar_t, index_t, GridSamplerInterpolation::Bilinear, GridSamplerPadding::Zeros, true, true>
<<<blocks, threads, smem_bytes, stream>>>(
getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input), getTensorInfo<const scalar_t, index_t>(grid),
getTensorInfo<scalar_t, index_t>(grad_input), getTensorInfo<scalar_t, index_t>(grad_grid), grad_input_memory_span);
} else { // !input_requires_grad
grid_sampler_2d_backward_kernel_optimized<scalar_t, index_t, GridSamplerInterpolation::Bilinear, GridSamplerPadding::Zeros, true, false>
<<<blocks, threads, smem_bytes, stream>>>(
getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input), getTensorInfo<const scalar_t, index_t>(grid),
getTensorInfo<scalar_t, index_t>(grad_input), getTensorInfo<scalar_t, index_t>(grad_grid), grad_input_memory_span);
}
} else { // !align_corners
if (input_requires_grad) {
grid_sampler_2d_backward_kernel_optimized<scalar_t, index_t, GridSamplerInterpolation::Bilinear, GridSamplerPadding::Zeros, false, true>
<<<blocks, threads, smem_bytes, stream>>>(
getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input), getTensorInfo<const scalar_t, index_t>(grid),
getTensorInfo<scalar_t, index_t>(grad_input), getTensorInfo<scalar_t, index_t>(grad_grid), grad_input_memory_span);
} else { // !input_requires_grad
grid_sampler_2d_backward_kernel_optimized<scalar_t, index_t, GridSamplerInterpolation::Bilinear, GridSamplerPadding::Zeros, false, false>
<<<blocks, threads, smem_bytes, stream>>>(
getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input), getTensorInfo<const scalar_t, index_t>(grid),
getTensorInfo<scalar_t, index_t>(grad_input), getTensorInfo<scalar_t, index_t>(grad_grid), grad_input_memory_span);
}
}
} else {
// Fallback for other padding modes to the original kernel
const auto count = grid.size(0) * grid.size(1) * grid.size(2);
grid_sampler_2d_backward_kernel<scalar_t, index_t><<<GET_BLOCKS(count, 256), 256, 0, stream>>>(
static_cast<index_t>(count), getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input),
getTensorInfo<const scalar_t, index_t>(grid), input_requires_grad ? getTensorInfo<scalar_t, index_t>(grad_input) : TensorInfo<scalar_t, index_t>(),
getTensorInfo<scalar_t, index_t>(grad_grid), interpolation_mode, padding_mode, align_corners,
grad_input_memory_span, input_requires_grad);
}
} else {
// Fallback for Nearest and Bicubic modes to the original kernel
const auto count = grid.size(0) * grid.size(1) * grid.size(2);
grid_sampler_2d_backward_kernel<scalar_t, index_t><<<GET_BLOCKS(count, 256), 256, 0, stream>>>(
static_cast<index_t>(count), getTensorInfo<const scalar_t, index_t>(grad_output), getTensorInfo<const scalar_t, index_t>(input),
getTensorInfo<const scalar_t, index_t>(grid), input_requires_grad ? getTensorInfo<scalar_t, index_t>(grad_input) : TensorInfo<scalar_t, index_t>(),
getTensorInfo<scalar_t, index_t>(grad_grid), interpolation_mode, padding_mode, align_corners,
grad_input_memory_span, input_requires_grad);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif // USE_ROCM
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void grid_sampler_3d_backward_kernel(
@ -840,15 +1166,44 @@ void launch_grid_sampler_2d_backward_kernel(
const TensorBase &grad_input, const TensorBase &grad_grid,
const TensorBase &grad_output, const TensorBase &input,
const TensorBase &grid, int64_t interpolation_mode, int64_t padding_mode,
bool align_corners, std::array<bool,2> output_mask) {
bool align_corners, std::array<bool, 2> output_mask) {
// See NOTE [ grid_sampler Native Functions ].
// Add checks here in case this is called instead of grid_sampler.
check_grid_sampler_common(input, grid);
check_grid_sampler_2d(input, grid);
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
// Nondeterministic because of atomicAdd usage in the underlying kernel:
// When multiple threads add to the same memory location concurrently, the order of the
// additions is not guaranteed; since floating-point addition is not associative, this
// can lead to tiny, bit-level differences in the result across runs.
globalContext().alertNotDeterministic("grid_sampler_2d_backward_cuda");
#ifdef USE_ROCM
auto input_requires_grad = output_mask[0];
// Neither grad_input nor grad_grid is required.
if (!output_mask[0] && !output_mask[1]) {
return;
}
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
input.scalar_type(), "grid_sampler_2d_backward_cuda", [&] {
// Use 32-bit indexing when all tensors can be indexed with 32-bit math.
if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) && canUse32BitIndexMath(grad_output)) {
launch_grid_sampler_2d_backward_dispatcher<scalar_t, int>(
grad_input, grad_grid, grad_output, input, grid,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners, input_requires_grad);
} else {
launch_grid_sampler_2d_backward_dispatcher<scalar_t, int64_t>(
grad_input, grad_grid, grad_output, input, grid,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<GridSamplerPadding>(padding_mode),
align_corners, input_requires_grad);
}
});
#else // CUDA Path
auto N = input.size(0);
auto H = grid.size(1);
auto W = grid.size(2);
@ -897,6 +1252,7 @@ void launch_grid_sampler_2d_backward_kernel(
}
});
}
#endif // USE_ROCM
}
void launch_grid_sampler_3d_backward_kernel(