diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 4b66b30b62e7..d58d436c511d 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); + m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index bdb169e26b14..f765b515cd0b 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -50,7 +50,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,12 +84,17 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + if constexpr (!rms_norm){ + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + } else { + rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); + } + } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -103,11 +108,15 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; + if constexpr (!rms_norm){ + const T_ACC beta_v = + beta == nullptr ? 
T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; + } else { + Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; + } } } @@ -119,40 +128,48 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + if constexpr (!rms_norm){ + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + } else{ + return {0.f, curr_sum.sigma2 + val * val, 0}; + } } -__device__ +template __device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + if constexpr (!rms_norm){ + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; } else { - mean = U(0); - sigma2 = U(0); + return {0.f, dataB.sigma2 + dataA.sigma2, 0}; } - return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -171,14 +188,13 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), - WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -199,7 +215,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = 
cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -216,7 +232,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -231,7 +247,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -254,34 +270,48 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + if constexpr (!rms_norm){ + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + } else { + out.val[ii] = rstd_val * static_cast(data.val[ii]); + } } } Y_vec[i] = out; } if (thrx == 0) { - mean[i1] = wd.mean; + if constexpr (!rms_norm){ + mean[i1] = wd.mean; + } rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -296,7 +326,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -306,11 +336,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -321,7 +351,10 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - const T_ACC mean_val = mean[i1]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[i1]; + } const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -337,26 +370,39 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? 
static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } + } + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); } - - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - buf[0] = stats_x1; + if constexpr (!rms_norm){ + buf[0] = stats_x1; + } buf[1] = stats_x2; } __syncthreads(); - stats_x1 = buf[0]; + if constexpr (!rms_norm){ + stats_x1 = buf[0]; + } stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -367,15 +413,20 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } + f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -387,7 +438,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -396,7 +447,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. 
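// A minimal host-side sketch (not part of this diff's kernels) of the input
// gradient that the rms_norm branches of compute_gI and the vectorized kernel
// below accumulate per row; the helper name rms_grad_input_reference is
// illustrative only and float accumulation is assumed:
//
//   dX[l] = (rstd / N) * (N * gamma[l] * dY[l]
//                          - X[l] * rstd * sum_k(dY[k] * gamma[k] * X[k] * rstd))
static inline void rms_grad_input_reference(
    const float* dY, const float* X, const float* gamma,
    float rstd, int N, float* dX) {
  // same per-row reduction as stats_x2 in the kernels
  float stats_x2 = 0.f;
  for (int l = 0; l < N; ++l) {
    const float g = gamma != nullptr ? gamma[l] : 1.f;
    stats_x2 += dY[l] * g * X[l] * rstd;
  }
  const float term1 = rstd / static_cast<float>(N);
  for (int l = 0; l < N; ++l) {
    const float g = gamma != nullptr ? gamma[l] : 1.f;
    dX[l] = term1 * (static_cast<float>(N) * g * dY[l] - X[l] * rstd * stats_x2);
  }
}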
-template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -409,7 +460,10 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - const T_ACC mean_val = mean[bIdx]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[bIdx]; + } const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -441,8 +495,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } @@ -451,19 +509,29 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else{ + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } // Reduction in Shared Memory - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + } stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - reduce_buf[0] = stats_x1; + if constexpr (!rms_norm){ + reduce_buf[0] = stats_x1; + } reduce_buf[1] = stats_x2; } __syncthreads(); - stats_x1 = reduce_buf[0]; + if constexpr (!rms_norm){ + stats_x1 = reduce_buf[0]; + } stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -485,8 +553,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -501,15 +573,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -525,17 +601,25 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? 
T_ACC(0) : static_cast(dY[index]); + if constexpr (!rms_norm){ + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + } else { + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index])) * static_cast(rstd[i]); + } } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - db[j] = sum2; + if constexpr (!rms_norm){ + db[j] = sum2; + } } } } @@ -545,7 +629,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -569,7 +654,9 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; + if constexpr (!rms_norm){ + warp_mean = mean[mean_index + lane_id]; + } warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -596,10 +683,14 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; + if constexpr (!rms_norm){ + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } else{ + dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } } } @@ -608,7 +699,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -629,10 +721,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -654,7 +746,8 @@ template __global__ void @@ -679,7 +772,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -687,11 +780,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. 
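// (Whichever branch is taken, the per-column sums being reduced are
//   dgamma[j] = sum_i dY[i][j] * (X[i][j] - mean[i]) * rstd[i]
//   dbeta[j]  = sum_i dY[i][j]
// for layer norm; with rms_norm the mean term is dropped, so
//   dgamma[j] = sum_i dY[i][j] * X[i][j] * rstd[i]
// and dbeta is never produced.)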
if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -706,7 +799,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db) { + if (db && !rms_norm) { db[thread_y * N + thread_x] = db_sum; } } @@ -752,7 +845,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db) { + if (db && !rms_norm) { db[out_index] = reg_db; } } @@ -763,7 +856,8 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction, +bool rms_norm> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -779,7 +873,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -790,7 +884,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -806,7 +900,7 @@ if (aligned_grid) { template +int rows_per_block_y, bool rms_norm> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -829,16 +923,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -876,19 +970,21 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined()) { + if (dbeta->defined() && !rms_norm) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); + if constexpr (!rms_norm){ + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); + } } } else { // We are in the normal case where M is not that large. @@ -896,18 +992,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. 
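// Note that rms_norm is only forwarded through the (templated) launches below;
// the tile-size heuristic itself is unchanged from the layer_norm path, and
// when rms_norm is set the selected kernel never reads mean and never writes
// dbeta.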
if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -936,7 +1032,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -958,7 +1054,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -968,7 +1064,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -987,7 +1083,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = mean->data_ptr(); + T_ACC* mean_data = !rms_norm ? 
mean->data_ptr() : nullptr; T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1002,14 +1098,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1037,7 +1133,29 @@ void LayerNormKernelImpl( }); } -template __device__ +void RmsNormKernelImpl( + const Tensor& X, + const Tensor& gamma, + int64_t M, + int64_t N, + double eps, + Tensor* Y, + Tensor* rstd) { +AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormKernelImpl", + [&]() { + using acc_t = acc_type; + // rms_norm = true + LayerNormKernelImplInternal( + // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True + X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); + }); +} + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1055,7 +1173,10 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1080,7 +1201,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1098,7 +1219,11 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1114,7 +1239,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1140,9 +1265,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + 
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1181,7 +1306,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1206,7 +1331,9 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - sum_beta += part_grad_beta_ptr[warp_offset*N]; + if constexpr (!rms_norm){ + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } } } @@ -1224,7 +1351,9 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx+nbsize3]; + if constexpr (!rms_norm){ + sum_beta += buf[read_idx+nbsize3]; + } } __syncthreads(); } @@ -1235,12 +1364,14 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - grad_beta[i2] = sum_beta; + if constexpr (!rms_norm){ + grad_beta[i2] = sum_beta; + } } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1254,7 +1385,10 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = mean[i1]; + T_ACC c_mean = 0; + if constexpr (!rms_norm){ + c_mean = mean[i1]; + } const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1267,21 +1401,31 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + if constexpr (!rms_norm){ + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1292,25 +1436,33 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2*wrt_i] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*wrt_i] = sum_loss1; + } buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2*read_i]; + if constexpr (!rms_norm){ + sum_loss1 += buf[2*read_i]; + } sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - buf[2*threadIdx.x] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*threadIdx.x] = sum_loss1; + } buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - sum_loss1 = buf[2*threadIdx.x]; + if constexpr (!rms_norm){ + sum_loss1 = buf[2*threadIdx.x]; + } sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1323,8 +1475,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1333,8 +1489,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = 
static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1344,7 +1504,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1358,7 +1518,9 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - TORCH_CHECK(mean.numel() == M); + if constexpr (!rms_norm){ + TORCH_CHECK(mean.numel() == M); + } TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1384,7 +1546,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1396,7 +1558,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1410,13 +1572,12 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); - if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1432,7 +1593,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. 
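// (With rms_norm the colwise kernel only accumulates dgamma; every dbeta store
// is guarded by !rms_norm. That is also why RMSNormBackwardKernelImpl further
// down can reuse this function with rstd passed in the mean slot and dgamma in
// the dbeta slot: the mean argument is never read and the dbeta argument is
// never written on the rms_norm path.)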
const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1456,7 +1617,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1470,7 +1631,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1480,7 +1641,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1508,8 +1669,29 @@ void LayerNormBackwardKernelImpl( }); } +void RMSNormBackwardKernelImpl( + const Tensor& dY, + const Tensor& X, + const Tensor& rstd, + const Tensor& gamma, + int64_t M, + int64_t N, + Tensor* dX, + Tensor* dgamma) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormBackwardKernelImpl", + [&]() { + LayerNormBackwardKernelImplInternal( + dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); + }); +} + } // namespace + std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1638,6 +1820,108 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } +/* RMSNorm is implemented by reusing layer_norm's kernels */ +std::tuple _fused_rms_norm_cuda( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps){ + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + double eps_val = eps.value_or(std::numeric_limits::epsilon()); + + Tensor Y = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); + Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); + + if (M > 0) { + RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); + } + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (const auto idx: c10::irange(axis)) { + stat_shape.push_back(input_shape[idx]); + } + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { + stat_shape.push_back(1); + } + + rstd = rstd.view(stat_shape); + + return std::make_tuple(std::move(Y), std::move(rstd)); +} + + +std::tuple _fused_rms_norm_backward_cuda( + const Tensor& dY, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt /* optional */, + std::array grad_input_mask) { + + c10::MaybeOwned weight_maybe_owned = + 
at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + Tensor dX; + Tensor dgamma; + if (grad_input_mask[0]) { + dX = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT) + : at::native::zeros_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + if (M > 0 && N > 0) { + RMSNormBackwardKernelImpl( + dY, *X, rstd, *gamma, M, N, &dX, &dgamma); + } + return std::make_tuple(std::move(dX), std::move(dgamma)); +} + REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index da6bb5fec39e..207f092a676a 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,30 +261,11 @@ std::tuple math_native_layer_norm( return outputs; } -Tensor rms_norm_symint( +std::tuple rms_norm_composite( const Tensor& input, - c10::SymIntArrayRef normalized_shape, + IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - -#ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_nested = input.is_nested() || weight.is_nested(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); - const bool is_input_fp = isFloatingType(input.scalar_type()); - const bool is_weight_fp = isFloatingType(weight.scalar_type()); - - if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); - } - } -#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -321,10 +302,60 @@ Tensor rms_norm_symint( upcasted_result = upcasted_result.mul(weight_opt.value()); } - return upcasted_result; + // if nested do not make contiguous + if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); }); - - return result.type_as(input); - + return std::make_tuple( + 
std::get<0>(result).type_as(input), // Cast normalized result to original input type + std::get<1>(result) // rsqrt_val + ); } + + +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + const std::optional eps) { + + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + + // composite fallback for channels last + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + // composite fallback for complex datatypes + if(input.is_complex()){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + #ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + + if (!(GradMode::is_enabled() && any_inputs_require_grad)) { + return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + } + + if (input.device().type() == DeviceType::MPS){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + #endif + + return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); +} + } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0181f35fd6ed..0debe942dd0a 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,6 +106,12 @@ void layer_norm_cpu_out( int64_t M, int64_t N); +std::tuple rms_norm_composite( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 71128297d5bf..7948b5acd8e9 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,7 +19,14 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary(); #include #endif -Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { +std::tuple _fused_rms_norm_mps(const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt, + const std::optional eps) { + const Tensor weight = weight_opt.value().contiguous(); + const int64_t normalized_ndim = normalized_shape.size(); + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c const std::string kernel = 
fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c } }); - return output; + return std::make_tuple(output, Tensor()); } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 79b7e07e2284..ce13e03fb9f6 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3314,9 +3314,15 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor +- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) dispatch: + CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + CompositeImplicitAutograd: rms_norm_composite + +- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 042959c22cd4..a590713ad0f8 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -374,7 +374,6 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional -aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index d6cf2df4343f..5a962dfa57c0 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,6 +139,8 @@ ALLOW_LIST = [ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), + # Previously MPS_only did not support backward + ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_decomp.py b/test/test_decomp.py index 5d641e32e422..dcd6e69af997 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import tf32_off +from torch.testing._internal.common_cuda import SM70OrLater, tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,6 +1226,33 @@ class DecompOneOffTests(TestCase): for o_ref, o in zip(out_ref, out): 
self.assertEqual(o_ref.dtype, o.dtype) + @onlyCUDA + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_cuda(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e2419aab268b..f0349c2484b6 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,6 +1267,11 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") +- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" + result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) + result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) + - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? 
input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index abb94b109cc0..8e9796d2f7c1 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f93a0bf84fb4..832928ebf8ae 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm_backward.default) +def _fused_rms_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + rstd: Tensor, + weight: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + + grad_out_cast = grad_out.to( + computation_dtype, memory_format=torch.contiguous_format + ) + input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) + weight_cast = ( + weight.to(computation_dtype, memory_format=torch.contiguous_format) + if weight is not None + else None + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + ) + + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + + x_hat = input_cast * rstd + + if output_mask[0]: + sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) + d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd + + if output_mask[1] and weight_cast is not None: + d_weight_full_shape = grad_out_cast * x_hat + if len(outer_dim_indices) > 0: + d_weight = torch.sum( + d_weight_full_shape, dim=outer_dim_indices, keepdim=False + ) + else: + d_weight = d_weight_full_shape + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + ) + + def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 908a980cfee9..8e13d4267edb 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5023,6 +5023,103 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, 
ggO}; } +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + const int normalized_ndim = normalized_shape.size(); + const int axis = input_ndim - normalized_ndim; + + int64_t N_rms = 1; + for (int i = 0; i < normalized_ndim; ++i) { + N_rms *= input_shape[axis + i]; + } + + Tensor dX; + Tensor dgamma; + + std::vector rstd_view_shape = rstd.sizes().vec(); + for (int i = 0; + i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); + ++i) { + rstd_view_shape.push_back(1); + } + Tensor rstd_broadcast = rstd.view(rstd_view_shape); + Tensor rstd_pow3 = rstd_broadcast.pow(3); + Tensor grad_x_hat; + + if (dY.defined()) { + if (weight.defined()) { + grad_x_hat = dY * weight; + } else { + grad_x_hat = dY; + } + } + + if (grad_input_mask[0]) { + Tensor dX_from_dY_path; + Tensor dX_from_drstd_path; + + std::vector inner_sum_dims; + inner_sum_dims.reserve(normalized_ndim); + for (int i = 0; i < normalized_ndim; ++i) { + inner_sum_dims.push_back(axis + i); + } + + if (dY.defined() && grad_x_hat.defined()) { + Tensor sum_input_times_grad_x_hat = + sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); + dX_from_dY_path = rstd_broadcast * grad_x_hat - + (input * rstd_pow3 / static_cast(N_rms)) * + sum_input_times_grad_x_hat; + } + + if (drstd.defined()) { + Tensor drstd_broadcast = drstd.view(rstd_view_shape); + dX_from_drstd_path = + -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; + } + + if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { + dX = dX_from_dY_path + dX_from_drstd_path; + } else if (dX_from_dY_path.defined()) { + dX = dX_from_dY_path; + } else if (dX_from_drstd_path.defined()) { + dX = dX_from_drstd_path; + } + } + + if (grad_input_mask[1] && weight.defined()) { + if (dY.defined()) { + Tensor x_hat = input * rstd_broadcast; + Tensor dgamma_full_shape = dY * x_hat; + + if (axis > 0) { + std::vector outer_sum_dims; + outer_sum_dims.reserve(axis); + for (int i = 0; i < axis; ++i) { + outer_sum_dims.push_back(i); + } + dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); + } else { + dgamma = dgamma_full_shape; + } + } + } + + return std::make_tuple(dX, dgamma); +} + std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp( bias_t.defined() ? 
bias_t.view(view_size_affine) : bias_t); } +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + + Tensor result_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + result_t = (input_t)*rstd_p + (input_p)*rstd_t; + } else { + result_t = input_t * rstd_p; + auto temp = input_p * rstd_t; + result_t += temp; + } + + std::optional result_p = std::nullopt; + if (weight_p.defined()) { + result_p = std::optional(input_p * rstd_p); + } + + return _affine_jvp( + result_p, + result_t, + weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, + weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, + Tensor()); +} + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + return rstd_t; +} + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 0b659973ec34..96864e165a95 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -826,6 +826,15 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask); + std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -965,6 +974,20 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + 
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 0b659973ec34..96864e165a95 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -826,6 +826,15 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
     c10::SymIntArrayRef normalized_shape,
     std::array<bool, 3> output_mask);
 
+std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
+    const Tensor& dY,
+    const Tensor& drstd,
+    const Tensor& input,
+    IntArrayRef normalized_shape,
+    const Tensor& rstd,
+    const std::optional<Tensor>& weight_opt,
+    std::array<bool, 2> grad_input_mask);
+
 std::tuple<Tensor, Tensor> householder_product_backward(
     const Tensor& grad,
     const Tensor& result,
@@ -965,6 +974,20 @@ Tensor layer_norm_jvp(
     const Tensor& saved_invstd,
     c10::SymIntArrayRef normalized_shape);
 
+Tensor rms_norm_jvp(
+    const Tensor& input_p,
+    const Tensor& input_t,
+    const Tensor& weight_p,
+    const Tensor& weight_t,
+    const Tensor& saved_rstd,
+    IntArrayRef normalized_shape);
+
+Tensor rms_norm_rstd_jvp(
+    const Tensor& input_p,
+    const Tensor& input_t,
+    const Tensor& saved_rstd,
+    IntArrayRef normalized_shape);
+
 Tensor group_norm_jvp(
     const Tensor& input_p,
     const Tensor& input_t,
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
index 2aa09cb802ec..aced2b2f539d 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self,
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
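Review note: the generated shims for every backend now share one calling convention: the pointer arguments (AtenTensorHandle* weight, double* eps) stand in for the schema's optional weight and eps, and the two output handles match the op's two returns. Assuming the Python binding follows the torch.overrides entry added at the end of this patch (signature (input, normalized_shape, weight=None, eps=1e-05)) and that the second return is the saved rstd, an eager call would look like the sketch below; the shapes and eps value are illustrative, and the op is only present on builds that include this change.

import torch

x = torch.randn(2, 3, 8)
w = torch.randn(8)

out, rstd = torch._fused_rms_norm(x, [8], w, 1e-6)  # assumed (output, rstd) pair

# The output should agree with the plain composite form of RMSNorm:
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
torch.testing.assert_close(out, ref)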
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
index e0607f984b3d..92d30ded855f 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
index a5d654c51884..c76ee685c25d 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
@@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
index 243bfb5fc87a..6fc51bd0c8f8 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h
@@ -13,6 +13,7 @@ extern "C" {
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
diff --git a/torch/overrides.py b/torch/overrides.py
index f29ffe57e36a..2e696b2d96e4 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]:
         torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
         torch.native_dropout: lambda input, p, train: -1,
         torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
+        torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
         torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
         torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
         torch.native_channel_shuffle: lambda input, groups: -1,
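Review note: the get_testing_overrides() entry declares the Python-visible signature that __torch_function__ handlers should expect for the new op. A minimal illustration, under the assumptions that the binding is exposed as torch._fused_rms_norm exactly as the lambda declares and that it returns an (output, rstd) pair; the LoggingTensor subclass below is illustrative only.

import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is torch._fused_rms_norm:
            print("intercepted _fused_rms_norm")
        return super().__torch_function__(func, types, args, kwargs or {})

x = torch.randn(2, 8).as_subclass(LoggingTensor)
# Same argument order as the testing-override lambda above.
out, rstd = torch._fused_rms_norm(x, [8], None, 1e-5)  # assumed (output, rstd) pair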