mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
# Changes over the previous PR This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel. Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266. This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added: ``` git diff HEAD^ diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 0d63a2f979c..3ce2c24c18e 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -657,6 +657,7 @@ bool aligned_grid > __global__ void +__launch_bounds__(block_dim_x * block_dim_y) GammaBetaBackwardCUDAKernelTemplate( int64_t M, int64_t N, ``` I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg <details> <summary> Repro script that fails on Blackwell </summary> ``` import torch from torch.nn import init # from transformer_nuggets import init_logging # from transformer_nuggets.utils.benchmark import profiler # from pathlib import Path # init_logging() class PermuteModule(torch.nn.Module): def __init__(self, permutation): super(PermuteModule, self).__init__() self.permutation = permutation def forward(self, x:torch.Tensor) -> torch.Tensor: assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!" return x.permute(*self.permutation) def test(n_layers:int, conv_stride:int): _sequence = [] for _ in range(n_layers): # Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules. _sequence += [ PermuteModule((0,2,1)), torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False), PermuteModule((0,2,1)), torch.nn.LayerNorm(512), torch.nn.ReLU() ] model = torch.nn.Sequential(*_sequence).to(device="cuda") data = torch.randn((100,2048,512), device="cuda") out = model(data) loss = torch.nn.functional.mse_loss(out, torch.rand_like(out)) loss.backward() torch.autograd.set_detect_anomaly(True) print(f"Torch version: {torch.__version__}") # with profiler(Path("conv")): # # print(f"layers=1, stride=1") # # test(n_layers=1, conv_stride=1) # # print(f"layers=2, stride=1") # # test(n_layers=2, conv_stride=1) # # print(f"layers=1, stride=2") # # test(n_layers=1, conv_stride=2) # print(f"layers=2, stride=2") # test(n_layers=2, conv_stride=2) print(f"layers=2, stride=2") test(n_layers=2, conv_stride=2) # we will not reach this print statement. print("DONE.") ``` </details> I also re-ran my performance benchmark and found no regressions over the previous PR. # Full description of the old PR Original PR: https://github.com/pytorch/pytorch/pull/148605 This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way. To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables: 1. dtype in {half, float} 2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 4. Whether we flush the L2 cache before running the backward pass Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster). In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*. Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old: M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.35 Elapsed Cycles cycle 27,526 Memory Throughput % 2.21 DRAM Throughput % 0.54 Duration us 20.42 L1/TEX Cache Throughput % 4.31 L2 Cache Throughput % 2.62 SM Active Cycles cycle 1,475.02 Compute (SM) Throughput % 0.29 ----------------------- ----------- ------------ ``` M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.34 Elapsed Cycles cycle 10,920 Memory Throughput % 5.64 DRAM Throughput % 1.35 Duration us 8.13 L1/TEX Cache Throughput % 1.92 L2 Cache Throughput % 6.89 SM Active Cycles cycle 3,554.41 Compute (SM) Throughput % 0.67 ----------------------- ----------- ------------ ``` Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following: <img width="1508" alt="image" src="https://github.com/user-attachments/assets/06179599-b2f0-4a45-8664-247a1067950b" /> There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made:  For dtype=float32, we get a similar chart: <img width="1499" alt="image" src="https://github.com/user-attachments/assets/c4d31a76-03b0-426c-9114-e1bfad29b530" /> The new code performs especially well for m >> n cases, and also where m and n are small. The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension). The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough. I am including the regressions here for completeness' sake: <img width="1500" alt="image" src="https://github.com/user-attachments/assets/31c17cfb-ed9b-4106-b9c8-5c359751f530" /> To see this better: 1. Click the image 2. Right click the expanded image and open in a new tab 3. Go to that tab and left click once to zoom in If you want to see the full data, here it is:  I also measured binary size and compile time since those are important for developers: Binary size comparison  ``` # Original -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so # This PR -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so ``` The diff in bytes is 302kB which is about a 0.1% increase. Compile time difference: ``` # Original real 0m10.931s user 0m9.676s sys 0m1.004s # this PR real 0m16.720s user 0m15.514s sys 0m1.066s # Command I ran time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. -I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o ``` So the new PR is 6 seconds longer compile time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625 Approved by: https://github.com/ngimel, https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
c0991b0316
commit
73b4938f7c
@ -540,191 +540,365 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel(
|
||||
}
|
||||
}
|
||||
|
||||
// This implementation gets called if M and N divide with 32. This case should
|
||||
// be the most common. We can then make better use of warp level intrinsics
|
||||
// to improve performance.
|
||||
|
||||
template <typename T, typename T_ACC>
|
||||
__global__ void GammaBetaBackwardCUDAKernel_32x32(
|
||||
template <typename T, typename T_ACC,
|
||||
unsigned int block_dim_x,
|
||||
unsigned int block_dim_y,
|
||||
unsigned int rows_per_block_y,
|
||||
bool check_x,
|
||||
bool check_y>
|
||||
__device__
|
||||
__forceinline__
|
||||
void
|
||||
blockReduceGammaBetaBackwardsHelper(
|
||||
int64_t M_start,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
const T* dY,
|
||||
const T* X,
|
||||
const T_ACC* mean,
|
||||
const T_ACC* rstd,
|
||||
T* dg,
|
||||
T* db) {
|
||||
alignas(sizeof(double)) extern __shared__ char s_data1[];
|
||||
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
|
||||
T_ACC* s_dg;
|
||||
T_ACC* s_db;
|
||||
const T* __restrict__ dY,
|
||||
const T* __restrict__ X,
|
||||
const T_ACC* __restrict__ mean,
|
||||
const T_ACC* __restrict__ rstd,
|
||||
T* __restrict__ dg,
|
||||
T* __restrict__ db,
|
||||
T_ACC &dg_sum,
|
||||
T_ACC &db_sum
|
||||
) {
|
||||
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
|
||||
int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x;
|
||||
|
||||
int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1);
|
||||
int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y;
|
||||
T_ACC warp_mean = 0, warp_rstd = 0;
|
||||
if (lane_id < rows_per_thread_y && mean_index + lane_id < M) {
|
||||
warp_mean = mean[mean_index + lane_id];
|
||||
warp_rstd = rstd[mean_index + lane_id];
|
||||
}
|
||||
// We do a WARP_SYNC() here because we use WARP_SHFL below to access
|
||||
// warp_mean and warp_rstd.
|
||||
WARP_SYNC();
|
||||
|
||||
T_ACC dY_regs[rows_per_thread_y] = {0};
|
||||
T_ACC X_regs[rows_per_thread_y] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_thread_y; ++i) {
|
||||
int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i;
|
||||
bool active = true;
|
||||
if (check_x && thread_x >= N) {
|
||||
active = false;
|
||||
}
|
||||
if (check_y && current_y >= M) {
|
||||
active = false;
|
||||
}
|
||||
if (active) {
|
||||
dY_regs[i] = dY[current_y * N + thread_x];
|
||||
X_regs[i] = X[current_y * N + thread_x];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_thread_y; ++i) {
|
||||
T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize);
|
||||
T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize);
|
||||
dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg;
|
||||
db_sum += dY_regs[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T_ACC,
|
||||
unsigned int block_dim_x,
|
||||
unsigned int block_dim_y,
|
||||
unsigned int rows_per_block_y,
|
||||
bool check_x,
|
||||
bool check_y>
|
||||
__device__
|
||||
__forceinline__
|
||||
void
|
||||
blockReduceGammaBetaBackwardsWithChecks(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
const T* __restrict__ dY,
|
||||
const T* __restrict__ X,
|
||||
const T_ACC* __restrict__ mean,
|
||||
const T_ACC* __restrict__ rstd,
|
||||
T* __restrict__ dg,
|
||||
T* __restrict__ db,
|
||||
T_ACC &dg_sum,
|
||||
T_ACC &db_sum
|
||||
) {
|
||||
for (int64_t M_start = blockIdx.y * rows_per_block_y;
|
||||
M_start < M;
|
||||
M_start += rows_per_block_y * gridDim.y) {
|
||||
int64_t M_end = M_start + rows_per_block_y - 1;
|
||||
if (!check_y || M_end < M) {
|
||||
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, false>
|
||||
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
|
||||
} else {
|
||||
blockReduceGammaBetaBackwardsHelper<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, check_x, true>
|
||||
(M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// block_dim_x is the number of threads in the x dimension per block.
|
||||
// block_dim_y is the number of threads in the y dimension per block.
|
||||
// rows_per_block_y is the size of the tile (number of data elements)
|
||||
// in the y dimension per block.
|
||||
// partial_reduction indicates whether we need to reduce across threads
|
||||
// or not. If set to true, we will not reduce across threads. This can
|
||||
// be faster in the M >> N case but requires another kernel to do a full
|
||||
// final reduction.
|
||||
// aligned_grid means the data size is a multiple of tile size. In that
|
||||
// case we don't need to check for boundary conditions which can provide
|
||||
// a further speedup by not needing instructions to check for edge cases
|
||||
// and not needing predicate registers.
|
||||
template <typename T, typename T_ACC,
|
||||
unsigned int block_dim_x, unsigned int block_dim_y,
|
||||
unsigned int rows_per_block_y,
|
||||
bool partial_reduction,
|
||||
bool aligned_grid
|
||||
>
|
||||
__global__
|
||||
void
|
||||
__launch_bounds__(block_dim_x * block_dim_y)
|
||||
GammaBetaBackwardCUDAKernelTemplate(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
const T* __restrict__ dY,
|
||||
const T* __restrict__ X,
|
||||
const T_ACC* __restrict__ mean,
|
||||
const T_ACC* __restrict__ rstd,
|
||||
T* __restrict__ dg,
|
||||
T* __restrict__ db) {
|
||||
// This assert is a compile-time check only.
|
||||
constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y;
|
||||
static_assert(rows_per_thread_y <= kWarpSize);
|
||||
|
||||
T_ACC dg_sum = 0;
|
||||
T_ACC db_sum = 0;
|
||||
|
||||
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
if (aligned_grid) {
|
||||
// When N and M align perfectly with block_dim_x and block_dim_y, we
|
||||
// can skip boundary condition checks that waste instruction issue slots.
|
||||
blockReduceGammaBetaBackwardsWithChecks
|
||||
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, false>
|
||||
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
|
||||
} else {
|
||||
// In the general case we need to check boundary conditions in the M
|
||||
// dimension. However, we can still avoid boundary checks in the N dimension
|
||||
// for the inner blocks. So try to avoid those checks when possible.
|
||||
if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) {
|
||||
blockReduceGammaBetaBackwardsWithChecks
|
||||
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false, true>
|
||||
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
|
||||
} else {
|
||||
blockReduceGammaBetaBackwardsWithChecks
|
||||
<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true, true>
|
||||
(M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum);
|
||||
}
|
||||
}
|
||||
|
||||
if (j < N) {
|
||||
constexpr int unroll_factor = 8;
|
||||
int laneId = threadIdx.x & (C10_WARP_SIZE - 1);
|
||||
int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x;
|
||||
|
||||
T_ACC mean_reg, mean_reg_tmp;
|
||||
T_ACC rstd_reg, rstd_reg_tmp;
|
||||
T dY_reg;
|
||||
T X_reg;
|
||||
|
||||
// Main loop
|
||||
int bcounter;
|
||||
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor);
|
||||
bcounter++) {
|
||||
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
|
||||
|
||||
if (laneId < unroll_factor) {
|
||||
mean_reg_tmp = mean[offset + laneId];
|
||||
rstd_reg_tmp = rstd[offset + laneId];
|
||||
// When partial_reduction is requested, we don't reduce within a block.
|
||||
// We also don't reduce if we are only a single block in the y dimension.
|
||||
if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) {
|
||||
if (aligned_grid || thread_x < N) {
|
||||
int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y;
|
||||
if (dg) {
|
||||
dg[thread_y * N + thread_x] = dg_sum;
|
||||
}
|
||||
WARP_SYNC();
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < unroll_factor; ++ii) {
|
||||
dY_reg = dY[(offset + ii) * N + j];
|
||||
X_reg = X[(offset + ii) * N + j];
|
||||
mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize);
|
||||
rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize);
|
||||
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
|
||||
db_sum += dY_reg;
|
||||
if (db) {
|
||||
db[thread_y * N + thread_x] = db_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Remainder loop
|
||||
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
|
||||
for (int ii = 0; ii < unroll_factor; ii++) {
|
||||
if ((offset + ii) < M) {
|
||||
mean_reg = mean[offset + ii];
|
||||
rstd_reg = rstd[offset + ii];
|
||||
dY_reg = dY[(offset + ii) * N + j];
|
||||
X_reg = X[(offset + ii) * N + j];
|
||||
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
|
||||
db_sum += dY_reg;
|
||||
}
|
||||
}
|
||||
|
||||
// This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and
|
||||
// gets called when M; N divide by 32. We can use warp shuffles
|
||||
// for the final reduction step. This removes 4 shmem loads and
|
||||
// stores with their corresponding __syncthreads()
|
||||
|
||||
// This greatly reduces bank conflicts at the expense of a little
|
||||
// extra shared memory. It does not impact occupancy
|
||||
int padded_bx = (1 + blockDim.x);
|
||||
|
||||
} else {
|
||||
// The caller requested a full reduction so we must reduce across
|
||||
// warps using shared memory and warp shuffles.
|
||||
static_assert(rows_per_thread_y <= C10_WARP_SIZE);
|
||||
alignas(sizeof(double)) extern __shared__ char s_data1[];
|
||||
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
|
||||
T_ACC* s_dg;
|
||||
T_ACC* s_db;
|
||||
int padded_bx = (block_dim_x + 1);
|
||||
// Transpose dg and db.
|
||||
s_dg = s_data_typed;
|
||||
s_db = s_data_typed + (padded_bx * blockDim.y);
|
||||
s_db = s_data_typed + (padded_bx * block_dim_y);
|
||||
s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum;
|
||||
s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum;
|
||||
__syncthreads();
|
||||
|
||||
// Load transposed so that a warp holds an entire column
|
||||
T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y];
|
||||
T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y];
|
||||
for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) {
|
||||
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
|
||||
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
|
||||
if (dg) {
|
||||
dg[j] = reg_dg;
|
||||
// Because block_dim_x != block_dim_y in the general case, we need
|
||||
// some code to handle the general case.
|
||||
static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0);
|
||||
constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE;
|
||||
int thread_id = threadIdx.y * block_dim_x + threadIdx.x;
|
||||
int warp_id = thread_id / C10_WARP_SIZE;
|
||||
int lane_id = thread_id & (C10_WARP_SIZE - 1);
|
||||
#pragma unroll
|
||||
for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) {
|
||||
T_ACC reg_db, reg_dg;
|
||||
if (lane_id < block_dim_y) {
|
||||
reg_dg = s_dg[lane_id * padded_bx + i];
|
||||
reg_db = s_db[lane_id * padded_bx + i];
|
||||
}
|
||||
if (db) {
|
||||
db[j] = reg_db;
|
||||
#pragma unroll
|
||||
for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) {
|
||||
reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize);
|
||||
reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize);
|
||||
}
|
||||
// Reduce is done. Now write it out to global memory.
|
||||
int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i;
|
||||
if (threadIdx.x == 0 && (aligned_grid || out_index < N)) {
|
||||
if (dg) {
|
||||
dg[out_index] = reg_dg;
|
||||
}
|
||||
if (db) {
|
||||
db[out_index] = reg_db;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename T_ACC>
|
||||
__global__ void GammaBetaBackwardCUDAKernel(
|
||||
template<typename T, typename T_ACC,
|
||||
int block_dim_x, int block_dim_y,
|
||||
int rows_per_block_y,
|
||||
bool partial_reduction>
|
||||
void LaunchAndCheckGammaBetaBackwardKernel(
|
||||
bool aligned_grid,
|
||||
dim3 blocks,
|
||||
dim3 threads,
|
||||
size_t shmem_sz,
|
||||
cudaStream_t cuda_stream,
|
||||
const T* dY_data,
|
||||
const T* X_data,
|
||||
const T_ACC* mean_data,
|
||||
const T_ACC* rstd_data,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
T* dgamma_data,
|
||||
T* dbeta_data) {
|
||||
if (aligned_grid) {
|
||||
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, true>
|
||||
<<<blocks, threads, shmem_sz, cuda_stream>>>(
|
||||
M,
|
||||
N,
|
||||
dY_data,
|
||||
X_data,
|
||||
mean_data,
|
||||
rstd_data,
|
||||
dgamma_data,
|
||||
dbeta_data);
|
||||
} else {
|
||||
GammaBetaBackwardCUDAKernelTemplate<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, partial_reduction, false>
|
||||
<<<blocks, threads, shmem_sz, cuda_stream>>>(
|
||||
M,
|
||||
N,
|
||||
dY_data,
|
||||
X_data,
|
||||
mean_data,
|
||||
rstd_data,
|
||||
dgamma_data,
|
||||
dbeta_data);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename T, typename T_ACC,
|
||||
int block_dim_x, int block_dim_y,
|
||||
int rows_per_block_y>
|
||||
void ConfigureAndLaunchGammaBetaBackwardKernel(
|
||||
const T* dY_data,
|
||||
const T* X_data,
|
||||
const T_ACC* mean_data,
|
||||
const T_ACC* rstd_data,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
const T* dY,
|
||||
const T* X,
|
||||
const T_ACC* mean,
|
||||
const T_ACC* rstd,
|
||||
T* dg,
|
||||
T* db) {
|
||||
alignas(sizeof(double)) extern __shared__ char s_data1[];
|
||||
T_ACC* s_data_typed = reinterpret_cast<T_ACC*>(&s_data1);
|
||||
T_ACC* s_dg;
|
||||
T_ACC* s_db;
|
||||
Tensor* dgamma,
|
||||
Tensor* dbeta,
|
||||
cudaStream_t cuda_stream) {
|
||||
T* dgamma_data =
|
||||
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
|
||||
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
|
||||
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
|
||||
dim3 threads{block_dim_x, block_dim_y};
|
||||
dim3 blocks;
|
||||
blocks.x = (N + block_dim_x - 1) / block_dim_x;
|
||||
blocks.y = 1;
|
||||
size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2;
|
||||
if (blocks.y == 1 && threads.y == 1) {
|
||||
// Optimization: since there is just one thread doing all the summation, we don't need a reduction
|
||||
// across threads. So we set partial_reduction to true.
|
||||
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
|
||||
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
|
||||
} else {
|
||||
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, false>(
|
||||
aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data);
|
||||
}
|
||||
|
||||
const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x;
|
||||
}
|
||||
|
||||
T_ACC dg_sum = 0;
|
||||
T_ACC db_sum = 0;
|
||||
|
||||
if (j < N) {
|
||||
constexpr int unroll_factor = 8;
|
||||
|
||||
T_ACC mean_reg;
|
||||
T_ACC rstd_reg;
|
||||
T dY_reg;
|
||||
T X_reg;
|
||||
|
||||
// Main Loop
|
||||
int bcounter;
|
||||
for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){
|
||||
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < unroll_factor; ++ii) {
|
||||
dY_reg = dY[(offset + ii) * N + j];
|
||||
X_reg = X[(offset + ii) * N + j];
|
||||
mean_reg = mean[offset + ii];
|
||||
rstd_reg = rstd[offset + ii];
|
||||
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
|
||||
db_sum += dY_reg;
|
||||
}
|
||||
template<typename T, typename T_ACC>
|
||||
void LaunchGammaBetaBackwardCUDAKernel(
|
||||
const T* dY_data,
|
||||
const T* X_data,
|
||||
const T_ACC* mean_data,
|
||||
const T_ACC* rstd_data,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
Tensor* dgamma,
|
||||
Tensor* dbeta,
|
||||
cudaStream_t cuda_stream) {
|
||||
constexpr int block_dim_x = 32;
|
||||
const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) {
|
||||
// We have a situation where M >> N and N is small.
|
||||
// In this case we can speed up the computation by parallelizing in the M dimension.
|
||||
// We launch multiple blocks in the y-dimension, and compute partial sums for the
|
||||
// gradient in the first pass. Then we do a .sum(0) to do a final reduction.
|
||||
// Although we launch 2 kernels, we can get up to a 10x speedup for large M.
|
||||
constexpr int block_dim_y = 1;
|
||||
constexpr int rows_per_block_y = 32;
|
||||
bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0);
|
||||
dim3 threads{block_dim_x, block_dim_y};
|
||||
dim3 blocks;
|
||||
blocks.x = (N + block_dim_x - 1) / block_dim_x;
|
||||
// int rows_per_block = my_gamma_beta_unroll_factor *
|
||||
blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y;
|
||||
constexpr int max_grid_size = 64 * 1024 / 2;
|
||||
blocks.y = std::min<unsigned int>(max_grid_size / blocks.x, blocks.y);
|
||||
Tensor dgamma_blocks;
|
||||
Tensor dbeta_blocks;
|
||||
T * dgamma_blocks_ptr = nullptr;
|
||||
T * dbeta_blocks_ptr = nullptr;
|
||||
if (dgamma->defined()) {
|
||||
auto options = dgamma->options();
|
||||
dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
|
||||
dgamma_blocks_ptr = dgamma_blocks.data_ptr<T>();
|
||||
}
|
||||
|
||||
// Remainder loop
|
||||
int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor;
|
||||
for (int ii = 0; ii < unroll_factor; ii++ ){
|
||||
if ((offset + ii) < M) {
|
||||
dY_reg = dY[(offset + ii) * N + j ];
|
||||
X_reg = X[(offset + ii) * N + j];
|
||||
mean_reg = mean[offset + ii];
|
||||
rstd_reg = rstd[offset + ii];
|
||||
dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg;
|
||||
db_sum += dY_reg;
|
||||
}
|
||||
if (dbeta->defined()) {
|
||||
auto options = dbeta->options();
|
||||
dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options);
|
||||
dbeta_blocks_ptr = dbeta_blocks.data_ptr<T>();
|
||||
}
|
||||
LaunchAndCheckGammaBetaBackwardKernel<T, T_ACC, block_dim_x, block_dim_y, rows_per_block_y, true>(
|
||||
aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr);
|
||||
|
||||
// Do the final reduction in shared memory
|
||||
s_dg = s_data_typed;
|
||||
s_db = s_data_typed + blockDim.x * blockDim.y;
|
||||
s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum;
|
||||
s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum;
|
||||
__syncthreads();
|
||||
|
||||
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
|
||||
if (threadIdx.y < offset) {
|
||||
s_dg[threadIdx.y * blockDim.x + threadIdx.x] +=
|
||||
s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
|
||||
s_db[threadIdx.y * blockDim.x + threadIdx.x] +=
|
||||
s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.y == 0) {
|
||||
if (dg) {
|
||||
dg[j] = s_dg[threadIdx.x];
|
||||
}
|
||||
if (db) {
|
||||
db[j] = s_db[threadIdx.x];
|
||||
}
|
||||
*dgamma = dgamma_blocks.sum(0);
|
||||
*dbeta = dbeta_blocks.sum(0);
|
||||
} else {
|
||||
// We are in the normal case where M is not that large.
|
||||
// We can change the tile shape (which is the last template parameter) in accordance with M.
|
||||
// For small M it is faster to have a smaller tile, otherwise we could have idle threads.
|
||||
// For larger M we use a bigger tile size.
|
||||
if (M < 64) {
|
||||
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 1, 8>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
|
||||
} else if (M < 128) {
|
||||
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 8, 64>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
|
||||
} else if (M < 256) {
|
||||
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 16, 128>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
|
||||
} else {
|
||||
ConfigureAndLaunchGammaBetaBackwardKernel<T, T_ACC, block_dim_x, 32, 256>(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1250,6 +1424,7 @@ void LayerNormBackwardKernelImplInternal(
|
||||
dgamma->defined() ? dgamma->template data_ptr<T>() : nullptr;
|
||||
T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T>() : nullptr;
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
if (M < 128) {
|
||||
// For small batch size, do colwise reduce directly.
|
||||
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads;
|
||||
@ -1265,7 +1440,6 @@ void LayerNormBackwardKernelImplInternal(
|
||||
dbeta_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
#if defined(USE_ROCM)
|
||||
// For small batch size, do colwise reduce directly.
|
||||
const int part_size = warp_size;
|
||||
const dim3 threads2(warp_size, 4, 1);
|
||||
@ -1300,47 +1474,11 @@ void LayerNormBackwardKernelImplInternal(
|
||||
dgamma_data,
|
||||
dbeta_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
#else
|
||||
if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) {
|
||||
// This implementation relies on warp primitives and requires that M and N divide
|
||||
// exactly to warp size.
|
||||
dim3 threads{kWarpSize, kWarpSize};
|
||||
int blocks = (N + threads.x - 1) / threads.x;
|
||||
|
||||
// If M and N divide by warp_size, we can use warp shuffles for the final reduction.
|
||||
// That requires transposing values in shared memory, so we apply a padding to
|
||||
// reduce bank conflicts.
|
||||
|
||||
size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y;
|
||||
GammaBetaBackwardCUDAKernel_32x32<T, T_ACC>
|
||||
<<<blocks, threads, shmem_sz, cuda_stream>>>(
|
||||
M,
|
||||
N,
|
||||
dY_data,
|
||||
X_data,
|
||||
mean_data,
|
||||
rstd_data,
|
||||
dgamma_data,
|
||||
dbeta_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
dim3 threads{16, 32};
|
||||
int blocks = (N + threads.x - 1) / threads.x;
|
||||
size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y;
|
||||
GammaBetaBackwardCUDAKernel<T, T_ACC>
|
||||
<<<blocks, threads, shmem_sz, cuda_stream>>>(
|
||||
M,
|
||||
N,
|
||||
dY_data,
|
||||
X_data,
|
||||
mean_data,
|
||||
rstd_data,
|
||||
dgamma_data,
|
||||
dbeta_data);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
LaunchGammaBetaBackwardCUDAKernel(
|
||||
dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7195,6 +7195,26 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
||||
ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False)
|
||||
self.assertEqual(ln.forward(x), torch.zeros_like(x))
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not available")
|
||||
def test_layer_norm_backwards_eps(self):
|
||||
dtype = torch.float
|
||||
m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55),
|
||||
(32, 32), (1024, 32), (1024, 1024),
|
||||
(33, 33), (1025, 33), (1025, 1025)]
|
||||
for m, n in m_x_n_list:
|
||||
x = torch.randn((m, n), dtype=dtype, requires_grad=True)
|
||||
grad_output = torch.rand_like(x)
|
||||
x_cuda = x.clone().detach().to("cuda").requires_grad_()
|
||||
grad_output_cuda = grad_output.clone().detach().to("cuda")
|
||||
ln = nn.LayerNorm(n, dtype=dtype)
|
||||
ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype)
|
||||
ln_out = ln(x)
|
||||
ln_out_cuda = ln_cuda(x_cuda)
|
||||
ln_out.backward(grad_output)
|
||||
ln_out_cuda.backward(grad_output_cuda)
|
||||
self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
|
||||
self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4)
|
||||
|
||||
@largeTensorTest("40GB", device="cuda")
|
||||
def test_layer_norm_large_tensor(self):
|
||||
# test for https://github.com/pytorch/pytorch/issues/136291
|
||||
|
Reference in New Issue
Block a user