Revert "[ROCm] new implementation of upsample_bilinear2d_backward (#164572)"

This reverts commit 53f9ae0e50d4dcc47f2ca4bf854803f9d4f875ae.

Reverted https://github.com/pytorch/pytorch/pull/164572 on behalf of https://github.com/seemethere due to Looks like this is failing in our internal builds, will post a suggestion for a fix but want you to double verify that this behavior is correct ([comment](https://github.com/pytorch/pytorch/pull/164572#issuecomment-3416262676))
This commit is contained in:
PyTorch MergeBot
2025-10-17 16:27:59 +00:00
parent 85c5433d38
commit faff826a46

View File

@ -127,29 +127,6 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
#ifdef USE_ROCM
// Helper function to compute output pixel range that can contribute to input pixel
template <typename accscalar_t>
__device__ __forceinline__ void compute_output_range(
int input_pos,
accscalar_t scale,
int output_size,
bool align_corners,
int& min_output,
int& max_output) {
accscalar_t lo, hi;
if (align_corners) {
lo = static_cast<accscalar_t>(input_pos - 1) / scale;
hi = static_cast<accscalar_t>(input_pos + 1) / scale;
} else {
lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
}
min_output = max(0, static_cast<int>(ceil(lo)));
max_output = min(output_size - 1, static_cast<int>(floor(hi)));
}
#endif
// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@ -164,74 +141,8 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
// In C++, integer multiplication, like in standard arithmetic, is generally commutative.
const size_t i_numel = nc * width1 * height1;
#ifdef USE_ROCM
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
index += blockDim.x * gridDim.x) {
// Decode input pixel coordinates
size_t index_temp = index;
const int w1 = index_temp % width1;
index_temp /= width1;
const int h1 = index_temp % height1;
const size_t nc_idx = index_temp / height1;
accscalar_t grad_sum = 0;
// Find range of output pixels that could interpolate from this input pixel
int h2_min, h2_max, w2_min, w2_max;
compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
// Iterate over potential output pixels
for (int h2 = h2_min; h2 <= h2_max; h2++) {
for (int w2 = w2_min; w2 <= w2_max; w2++) {
// Compute source coordinates for this output pixel
const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
rheight, h2, align_corners, /*cubic=*/false);
const int h1_base = (int)h1r;
const int h1p = (h1_base < height1 - 1) ? 1 : 0;
const accscalar_t h1lambda = h1r - h1_base;
const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
rwidth, w2, align_corners, /*cubic=*/false);
const int w1_base = (int)w1r;
const int w1p = (w1_base < width1 - 1) ? 1 : 0;
const accscalar_t w1lambda = w1r - w1_base;
const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
// Check if our input pixel participates in this interpolation and accumulate all weights
// At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
// to the same pixel, so we need to accumulate weights from all matching positions
accscalar_t weight = 0;
// Check all four interpolation positions and accumulate weights
if (h1 == h1_base && w1 == w1_base) {
weight += h0lambda * w0lambda; // top-left
}
if (h1 == h1_base && w1 == w1_base + w1p) {
weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base) {
weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
}
if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
}
if (weight > 0) {
const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
}
}
}
// Write accumulated gradient (no atomics needed)
idata[index] = static_cast<scalar_t>(grad_sum);
}
#else
const size_t o_numel = nc * width2 * height2;
const size_t i_numel = nc * width1 * height1;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@ -280,7 +191,6 @@ __global__ void upsample_bilinear2d_backward_out_frame(
static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
#endif
}
template <typename scalar_t, typename accscalar_t>
@ -477,6 +387,7 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -486,12 +397,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
#ifdef USE_ROCM
constexpr bool use_input = true;
#else
constexpr bool use_input = false;
#endif
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@ -509,8 +414,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * output_height * output_width;
upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@ -541,8 +444,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
<<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,