[ROCm][layer_norm] Use __builtin_amdgcn_rcpf(x) instead of 1.f/x (#165589)

Replace (more) exact calculation with hardware approximation.

Benefits:
Reduced code size.
Improved performance for certain scenarios.

Experiments show only a small reduction in precision.
Experiments show no significant performance regressions; bfloat16- and float16-related calculations may benefit most from this change.

Co-author: @mhalk @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165589
Approved by: https://github.com/jeffdaily
This commit is contained in:
Jerry Mannil
2025-10-17 09:12:27 +00:00
committed by PyTorch MergeBot
parent 9fe3b2afbe
commit 202f83dc4e
4 changed files with 29 additions and 5 deletions

View File

@@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
   if constexpr (!rms_norm){
     U delta = val - curr_sum.mean;
     U new_count = curr_sum.count + 1.f;
+#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
+    U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
+#else
     U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
+#endif
     return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
   } else{
     return {0.f, curr_sum.sigma2 + val * val, 0};
@@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
   U count = dataA.count + dataB.count;
   U mean, sigma2;
   if (count > decltype(dataB.count){0}) {
+#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
+    auto coef = __builtin_amdgcn_rcpf(count);
+#else
     auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
+#endif
     auto nA = dataA.count * coef;
     auto nB = dataB.count * coef;
     mean = nA*dataA.mean + nB*dataB.mean;

View File

@@ -1044,6 +1044,17 @@ if(USE_ROCM)
       list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
     endif(CMAKE_BUILD_TYPE MATCHES Debug)

+    # Get EnVar 'USE_LAYERNORM_FAST_RECIPROCAL' (or default to on).
+    if(DEFINED ENV{USE_LAYERNORM_FAST_RECIPROCAL})
+      set(USE_LAYERNORM_FAST_RECIPROCAL $ENV{USE_LAYERNORM_FAST_RECIPROCAL})
+    else()
+      set(USE_LAYERNORM_FAST_RECIPROCAL ON)
+    endif()
+
+    if(USE_LAYERNORM_FAST_RECIPROCAL)
+      add_definitions(-DUSE_LAYERNORM_FAST_RECIPROCAL)
+    endif()
+
     # needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
     list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)

View File

@@ -128,11 +128,12 @@ function(caffe2_print_configuration_summary)
   endif()
   message(STATUS "  USE_ROCM              : ${USE_ROCM}")
   if(${USE_ROCM})
     message(STATUS "    ROCM_VERSION        : ${ROCM_VERSION}")
     message(STATUS "    USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
     message(STATUS "    USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
     message(STATUS "    USE_ROCM_CK_SDPA    : ${USE_ROCM_CK_SDPA}")
     message(STATUS "    USE_ROCM_CK_GEMM    : ${USE_ROCM_CK_GEMM}")
+    message(STATUS "    USE_LAYERNORM_FAST_RECIPROCAL : ${USE_LAYERNORM_FAST_RECIPROCAL}")
   endif()
   message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}")
   message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}")

View File

@@ -156,6 +156,10 @@
 #   USE_ROCM_KERNEL_ASSERT=1
 #     Enable kernel assert in ROCm platform
 #
+#   USE_LAYERNORM_FAST_RECIPROCAL
+#     If set, enables the use of builtin functions for fast reciprocals (1/x) w.r.t.
+#     layer normalization. Default: enabled.
+#
 #   USE_ROCM_CK_GEMM=1
 #     Enable building CK GEMM backend in ROCm platform
 #