[cuda] Add new faster gammabeta backward kernel (#148605) (Reapply with launch bounds) (#150625)

# Changes over the previous PR

This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel.

Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266.

This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added:

```
git diff HEAD^
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 0d63a2f979c..3ce2c24c18e 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -657,6 +657,7 @@ bool aligned_grid
 >
 __global__
 void
+__launch_bounds__(block_dim_x * block_dim_y)
  GammaBetaBackwardCUDAKernelTemplate(
     int64_t M,
     int64_t N,
```

I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg

<details>
<summary> Repro script that fails on Blackwell </summary>

```
import torch
from torch.nn import init
# from transformer_nuggets import init_logging
# from transformer_nuggets.utils.benchmark import profiler
# from pathlib import Path

# init_logging()

class PermuteModule(torch.nn.Module):
    def __init__(self, permutation):
        super(PermuteModule, self).__init__()
        self.permutation = permutation
    def forward(self, x:torch.Tensor) -> torch.Tensor:
        assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!"
        return x.permute(*self.permutation)

def test(n_layers:int, conv_stride:int):
    _sequence = []
    for _ in range(n_layers):
        # Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules.
        _sequence += [
            PermuteModule((0,2,1)),
            torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False),
            PermuteModule((0,2,1)),
            torch.nn.LayerNorm(512),
            torch.nn.ReLU()
        ]
    model = torch.nn.Sequential(*_sequence).to(device="cuda")
    data = torch.randn((100,2048,512), device="cuda")
    out = model(data)
    loss = torch.nn.functional.mse_loss(out, torch.rand_like(out))
    loss.backward()

torch.autograd.set_detect_anomaly(True)
print(f"Torch version: {torch.__version__}")

# with profiler(Path("conv")):
#     # print(f"layers=1, stride=1")
#     # test(n_layers=1, conv_stride=1)
#     # print(f"layers=2, stride=1")
#     # test(n_layers=2, conv_stride=1)
#     # print(f"layers=1, stride=2")
#     # test(n_layers=1, conv_stride=2)
#     print(f"layers=2, stride=2")
#     test(n_layers=2, conv_stride=2)

print(f"layers=2, stride=2")
test(n_layers=2, conv_stride=2)
# we will not reach this print statement.
print("DONE.")
```

</details>

I also re-ran my performance benchmark and found no regressions over the previous PR.

# Full description of the old PR

Original PR: https://github.com/pytorch/pytorch/pull/148605

This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way.

To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables:

1. dtype in {half, float}
2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)`
4. Whether we flush the L2 cache before running the backward pass

Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster).

In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*.

Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old:

M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary:
```
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.35
    Elapsed Cycles                cycle       27,526
    Memory Throughput                 %         2.21
    DRAM Throughput                   %         0.54
    Duration                         us        20.42
    L1/TEX Cache Throughput           %         4.31
    L2 Cache Throughput               %         2.62
    SM Active Cycles              cycle     1,475.02
    Compute (SM) Throughput           %         0.29
    ----------------------- ----------- ------------
```

M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary:
```
    ----------------------- ----------- ------------
    Metric Name             Metric Unit Metric Value
    ----------------------- ----------- ------------
    DRAM Frequency                  Ghz         1.59
    SM Frequency                    Ghz         1.34
    Elapsed Cycles                cycle       10,920
    Memory Throughput                 %         5.64
    DRAM Throughput                   %         1.35
    Duration                         us         8.13
    L1/TEX Cache Throughput           %         1.92
    L2 Cache Throughput               %         6.89
    SM Active Cycles              cycle     3,554.41
    Compute (SM) Throughput           %         0.67
    ----------------------- ----------- ------------
```

Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following:

<img width="1508" alt="image" src="https://github.com/user-attachments/assets/06179599-b2f0-4a45-8664-247a1067950b" />

There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made:

![image](https://github.com/user-attachments/assets/90c26f7c-e3ad-46d2-a6ce-fe4b5fb3d738)

For dtype=float32, we get a similar chart:

<img width="1499" alt="image" src="https://github.com/user-attachments/assets/c4d31a76-03b0-426c-9114-e1bfad29b530" />

The new code performs especially well for m >> n cases, and also where m and n are small. The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension).

The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough.

I am including the regressions here for completeness' sake:

<img width="1500" alt="image" src="https://github.com/user-attachments/assets/31c17cfb-ed9b-4106-b9c8-5c359751f530" />

To see this better:

1. Click the image
2. Right click the expanded image and open in a new tab
3. Go to that tab and left click once to zoom in

If you want to see the full data, here it is:

![image](https://github.com/user-attachments/assets/54fb60c9-8c0c-4530-a1dd-79ecda1a69a1)

I also measured binary size and compile time since those are important for developers:

Binary size comparison

![image](https://github.com/user-attachments/assets/ceef5073-1036-47f6-b9dc-cea088beda51)

```
# Original
-rwxr-xr-x 1 ahmads users 307193112 Mar  6 08:46 ./torch/lib/libtorch_cuda.so

# This PR
-rwxr-xr-x 1 ahmads users 307193112 Mar  6 08:46 ./torch/lib/libtorch_cuda.so
```

The diff in bytes is 302kB which is about a 0.1% increase.

Compile time difference:

```
# Original

real    0m10.931s
user    0m9.676s
sys     0m1.004s

# this PR

real    0m16.720s
user    0m15.514s
sys     0m1.066s

# Command I ran
time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda  -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o

```

So the new PR is 6 seconds longer compile time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625
Approved by: https://github.com/ngimel, https://github.com/atalman
This commit is contained in:
Ahmad Sharif
2025-04-08 02:39:41 +00:00
committed by PyTorch MergeBot
parent c0991b0316
commit 73b4938f7c
2 changed files with 354 additions and 196 deletions

View File

@ -540,191 +540,365 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(
}
}
// This implementation gets called if M and N divide with 32. This case should
// be the most common. We can then make better use of warp level intrinsics
// to improve performance.
template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel_32x32(
template <typename T, typename T_ACC,
unsigned int block_dim_x,
unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool check_x,
bool check_y>
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsHelper(
int64_t M_start,
int64_t M,
int64_t N,
const T* dY,
const T* X,
const T_ACC* mean,
const T_ACC* rstd,
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db,
T_ACC &dg_sum,
T_ACC &db_sum
) {
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x;
int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y;
T_ACC warp_mean = 0, warp_rstd = 0;
if (lane_id < rows_per_thread_y && mean_index + lane_id < M) {
warp_mean = mean[mean_index + lane_id];
warp_rstd = rstd[mean_index + lane_id];
}
// We do a WARP_SYNC() here because we use WARP_SHFL below to access
// warp_mean and warp_rstd.
WARP_SYNC();
T_ACC dY_regs[rows_per_thread_y] = {0};
T_ACC X_regs[rows_per_thread_y] = {0};
#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i;
bool active = true;
if (check_x && thread_x >= N) {
active = false;
}
if (check_y && current_y >= M) {
active = false;
}
if (active) {
dY_regs[i] = dY[current_y * N + thread_x];
X_regs[i] = X[current_y * N + thread_x];
}
}
#pragma unroll
for (int i = 0; i < rows_per_thread_y; ++i) {
T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize);
T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize);
dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg;
db_sum += dY_regs[i];
}
}
template <typename T, typename T_ACC,
unsigned int block_dim_x,
unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool check_x,
bool check_y>
__device__
__forceinline__
void
blockReduceGammaBetaBackwardsWithChecks(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db,
T_ACC &dg_sum,
T_ACC &db_sum
) {
for (int64_t M_start = blockIdx.y * rows_per_block_y;
M_start < M;
M_start += rows_per_block_y * gridDim.y) {
int64_t M_end = M_start + rows_per_block_y - 1;
if (!check_y || M_end < M) {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, false>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, true>
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
}
// block_dim_x is the number of threads in the x dimension per block.
// block_dim_y is the number of threads in the y dimension per block.
// rows_per_block_y is the size of the tile (number of data elements)
// in the y dimension per block.
// partial_reduction indicates whether we need to reduce across threads
// or not. If set to true, we will not reduce across threads. This can
// be faster in the M >> N case but requires another kernel to do a full
// final reduction.
// aligned_grid means the data size is a multiple of tile size. In that
// case we don't need to check for boundary conditions which can provide
// a further speedup by not needing instructions to check for edge cases
// and not needing predicate registers.
template <typename T, typename T_ACC,
unsigned int block_dim_x, unsigned int block_dim_y,
unsigned int rows_per_block_y,
bool partial_reduction,
bool aligned_grid
>
__global__
void
__launch_bounds__(block_dim_x * block_dim_y)
GammaBetaBackwardCUDAKernelTemplate(
int64_t M,
int64_t N,
const T* __restrict__ dY,
const T* __restrict__ X,
const T_ACC* __restrict__ mean,
const T_ACC* __restrict__ rstd,
T* __restrict__ dg,
T* __restrict__ db) {
// This assert is a compile-time check only.
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
static_assert(rows_per_thread_y <= kWarpSize);
T_ACC dg_sum = 0;
T_ACC db_sum = 0;
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
if (aligned_grid) {
// When N and M align perfectly with block_dim_x and block_dim_y, we
// can skip boundary condition checks that waste instruction issue slots.
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, false>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
// In the general case we need to check boundary conditions in the M
// dimension. However, we can still avoid boundary checks in the N dimension
// for the inner blocks. So try to avoid those checks when possible.
if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
} else {
blockReduceGammaBetaBackwardsWithChecks
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true, true>
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
}
}
if (j < N) {
constexpr int unroll_factor = 8;
int laneId = threadIdx.x & (C10_WARP_SIZE - 1);
int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x;
T_ACC mean_reg, mean_reg_tmp;
T_ACC rstd_reg, rstd_reg_tmp;
T dY_reg;
T X_reg;
// Main loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
bcounter++) {
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
if (laneId < unroll_factor) {
mean_reg_tmp = mean[offset + laneId];
rstd_reg_tmp = rstd[offset + laneId];
// When partial_reduction is requested, we don't reduce within a block.
// We also don't reduce if we are only a single block in the y dimension.
if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) {
if (aligned_grid || thread_x < N) {
int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y;
if (dg) {
dg[thread_y * N + thread_x] = dg_sum;
}
WARP_SYNC();
#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
if (db) {
db[thread_y * N + thread_x] = db_sum;
}
}
// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++) {
if ((offset + ii) < M) {
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
}
// This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
// gets called when M; N divide by 32. We can use warp shuffles
// for the final reduction step. This removes 4 shmem loads and
// stores with their corresponding __syncthreads()
// This greatly reduces bank conflicts at the expense of a little
// extra shared memory. It does not impact occupancy
int padded_bx = (1 + blockDim.x);
} else {
// The caller requested a full reduction so we must reduce across
// warps using shared memory and warp shuffles.
static_assert(rows_per_thread_y <= C10_WARP_SIZE);
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;
int padded_bx = (block_dim_x + 1);
// Transpose dg and db.
s_dg = s_data_typed;
s_db = s_data_typed + (padded_bx * blockDim.y);
s_db = s_data_typed + (padded_bx * block_dim_y);
s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
__syncthreads();
// Load transposed so that a warp holds an entire column
T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
}
if (threadIdx.x == 0) {
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
if (dg) {
dg[j] = reg_dg;
// Because block_dim_x != block_dim_y in the general case, we need
// some code to handle the general case.
static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0);
constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE;
int thread_id = threadIdx.y * block_dim_x + threadIdx.x;
int warp_id = thread_id / C10_WARP_SIZE;
int lane_id = thread_id & (C10_WARP_SIZE - 1);
#pragma unroll
for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) {
T_ACC reg_db, reg_dg;
if (lane_id < block_dim_y) {
reg_dg = s_dg[lane_id * padded_bx + i];
reg_db = s_db[lane_id * padded_bx + i];
}
if (db) {
db[j] = reg_db;
#pragma unroll
for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) {
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
}
// Reduce is done. Now write it out to global memory.
int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i;
if (threadIdx.x == 0 && (aligned_grid || out_index < N)) {
if (dg) {
dg[out_index] = reg_dg;
}
if (db) {
db[out_index] = reg_db;
}
}
}
}
}
template <typename T, typename T_ACC>
__global__ void GammaBetaBackwardCUDAKernel(
template<typename T, typename T_ACC,
int block_dim_x, int block_dim_y,
int rows_per_block_y,
bool partial_reduction>
void LaunchAndCheckGammaBetaBackwardKernel(
bool aligned_grid,
dim3 blocks,
dim3 threads,
size_t shmem_sz,
cudaStream_t cuda_stream,
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
T* dgamma_data,
T* dbeta_data) {
if (aligned_grid) {
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, true>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
} else {
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, false>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename T, typename T_ACC,
int block_dim_x, int block_dim_y,
int rows_per_block_y>
void ConfigureAndLaunchGammaBetaBackwardKernel(
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
const T* dY,
const T* X,
const T_ACC* mean,
const T_ACC* rstd,
T* dg,
T* db) {
alignas(sizeof(double)) extern __shared__ char s_data1[];
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
T_ACC* s_dg;
T_ACC* s_db;
Tensor* dgamma,
Tensor* dbeta,
cudaStream_t cuda_stream) {
T* dgamma_data =
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks;
blocks.x = (N + block_dim_x - 1) / block_dim_x;
blocks.y = 1;
size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2;
if (blocks.y == 1 && threads.y == 1) {
// Optimization: since there is just one thread doing all the summation, we don't need a reduction
// across threads. So we set partial_reduction to true.
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
} else {
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false>(
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
}
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
}
T_ACC dg_sum = 0;
T_ACC db_sum = 0;
if (j < N) {
constexpr int unroll_factor = 8;
T_ACC mean_reg;
T_ACC rstd_reg;
T dY_reg;
T X_reg;
// Main Loop
int bcounter;
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
#pragma unroll
for (int ii = 0; ii < unroll_factor; ++ii) {
dY_reg = dY[(offset + ii) * N + j];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
template<typename T, typename T_ACC>
void LaunchGammaBetaBackwardCUDAKernel(
const T* dY_data,
const T* X_data,
const T_ACC* mean_data,
const T_ACC* rstd_data,
int64_t M,
int64_t N,
Tensor* dgamma,
Tensor* dbeta,
cudaStream_t cuda_stream) {
constexpr int block_dim_x = 32;
const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
// We have a situation where M >> N and N is small.
// In this case we can speed up the computation by parallelizing in the M dimension.
// We launch multiple blocks in the y-dimension, and compute partial sums for the
// gradient in the first pass. Then we do a .sum(0) to do a final reduction.
// Although we launch 2 kernels, we can get up to a 10x speedup for large M.
constexpr int block_dim_y = 1;
constexpr int rows_per_block_y = 32;
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
dim3 threads{block_dim_x, block_dim_y};
dim3 blocks;
blocks.x = (N + block_dim_x - 1) / block_dim_x;
// int rows_per_block = my_gamma_beta_unroll_factor *
blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
constexpr int max_grid_size = 64 * 1024 / 2;
blocks.y = std::min<unsigned int>(max_grid_size / blocks.x, blocks.y);
Tensor dgamma_blocks;
Tensor dbeta_blocks;
T * dgamma_blocks_ptr = nullptr;
T * dbeta_blocks_ptr = nullptr;
if (dgamma->defined()) {
auto options = dgamma->options();
dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dgamma_blocks_ptr = dgamma_blocks.data_ptr<T>();
}
// Remainder loop
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
for (int ii = 0; ii < unroll_factor; ii++ ){
if ((offset + ii) < M) {
dY_reg = dY[(offset + ii) * N + j ];
X_reg = X[(offset + ii) * N + j];
mean_reg = mean[offset + ii];
rstd_reg = rstd[offset + ii];
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
db_sum += dY_reg;
}
if (dbeta->defined()) {
auto options = dbeta->options();
dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
dbeta_blocks_ptr = dbeta_blocks.data_ptr<T>();
}
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);
// Do the final reduction in shared memory
s_dg = s_data_typed;
s_db = s_data_typed + blockDim.x * blockDim.y;
s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
__syncthreads();
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
if (threadIdx.y < offset) {
s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
}
__syncthreads();
}
if (threadIdx.y == 0) {
if (dg) {
dg[j] = s_dg[threadIdx.x];
}
if (db) {
db[j] = s_db[threadIdx.x];
}
*dgamma = dgamma_blocks.sum(0);
*dbeta = dbeta_blocks.sum(0);
} else {
// We are in the normal case where M is not that large.
// We can change the tile shape (which is the last template parameter) in accordance with M.
// For small M it is faster to have a smaller tile, otherwise we could have idle threads.
// For larger M we use a bigger tile size.
if (M < 64) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 1, 8>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else if (M < 128) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 8, 64>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else if (M < 256) {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 16, 128>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
} else {
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 32, 256>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
}
}
}
@ -1250,6 +1424,7 @@ void LayerNormBackwardKernelImplInternal(
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
#if defined(USE_ROCM)
if (M < 128) {
// For small batch size, do colwise reduce directly.
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
@ -1265,7 +1440,6 @@ void LayerNormBackwardKernelImplInternal(
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
#if defined(USE_ROCM)
// For small batch size, do colwise reduce directly.
const int part_size = warp_size;
const dim3 threads2(warp_size, 4, 1);
@ -1300,47 +1474,11 @@ void LayerNormBackwardKernelImplInternal(
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
// This implementation relies on warp primitives and requires that M and N divide
// exactly to warp size.
dim3 threads{kWarpSize, kWarpSize};
int blocks = (N + threads.x - 1) / threads.x;
// If M and N divide by warp_size, we can use warp shuffles for the final reduction.
// That requires transposing values in shared memory, so we apply a padding to
// reduce bank conflicts.
size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
dim3 threads{16, 32};
int blocks = (N + threads.x - 1) / threads.x;
size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
GammaBetaBackwardCUDAKernel<T, T_ACC>
<<<blocks, threads, shmem_sz, cuda_stream>>>(
M,
N,
dY_data,
X_data,
mean_data,
rstd_data,
dgamma_data,
dbeta_data);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}
#else
LaunchGammaBetaBackwardCUDAKernel(
dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
#endif
}
}

View File

@ -7195,6 +7195,26 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
self.assertEqual(ln.forward(x), torch.zeros_like(x))
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
def test_layer_norm_backwards_eps(self):
dtype = torch.float
m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
(32, 32), (1024, 32), (1024, 1024),
(33, 33), (1025, 33), (1025, 1025)]
for m, n in m_x_n_list:
x = torch.randn((m, n), dtype=dtype, requires_grad=True)
grad_output = torch.rand_like(x)
x_cuda = x.clone().detach().to("cuda").requires_grad_()
grad_output_cuda = grad_output.clone().detach().to("cuda")
ln = nn.LayerNorm(n, dtype=dtype)
ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
ln_out = ln(x)
ln_out_cuda = ln_cuda(x_cuda)
ln_out.backward(grad_output)
ln_out_cuda.backward(grad_output_cuda)
self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
@largeTensorTest("40GB", device="cuda")
def test_layer_norm_large_tensor(self):
# test for https://github.com/pytorch/pytorch/issues/136291