From 12116aee6852df2b040255b8fcc7deb52b897792 Mon Sep 17 00:00:00 2001
From: Xinya Zhang
Date: Thu, 28 Mar 2024 00:27:38 +0000
Subject: [PATCH] Add Flash Attention support on ROCM (#121561)

This patch addresses the major limitations in our previous
[PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new
dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
  * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
  * The varlen APIs will be supported in a future release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
  * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
  * Kernels are now selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix for #112997.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
---
 CMakeLists.txt                                |  23 +-
 .../native/transformers/cuda/sdp_utils.cpp    |  36 +-
 .../transformers/hip/flash_attn/flash_api.hip | 453 ++++++++----------
 caffe2/CMakeLists.txt                         |   3 +-
 cmake/Dependencies.cmake                      |   4 +-
 cmake/External/aotriton.cmake                 |  28 ++
 cmake/External/oort.cmake                     |  25 -
 test/test_transformers.py                     |  13 +-
 torch/testing/_internal/common_cuda.py        |   9 +-
 9 files changed, 266 insertions(+), 328 deletions(-)
 create mode 100644 cmake/External/aotriton.cmake
 delete mode 100644 cmake/External/oort.cmake

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f95efeeffd8..7adeac323a91 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -742,13 +742,28 @@ if(MSVC)
   append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
 endif()
 
-# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
-# include(cmake/Dependencies.cmake)
+# Note for ROCM platform:
+# 1. USE_ROCM is always ON until include(cmake/Dependencies.cmake)
+# 2. USE_CUDA will become OFF during re-configuration
+# Truth Table:
+# CUDA 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
+# CUDA 2nd pass: USE_CUDA=True;USE_ROCM=False, FLASH evaluates to ON by default
+# ROCM 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
+# ROCM 2nd pass: USE_CUDA=False;USE_ROCM=True, FLASH evaluates to ON by default
+# CPU 1st pass: USE_CUDA=False(Cmd Option);USE_ROCM=True, FLASH evaluates to OFF by default
+# CPU 2nd pass: USE_CUDA=False(Cmd Option);USE_ROCM=False, FLASH evaluates to OFF by default
+# Thus we cannot tell the ROCM 2nd pass and the CPU 1st pass apart.
+#
+# The only solution is to include(cmake/Dependencies.cmake), and defer the
+# aotriton build decision until later.
+ +include(cmake/Dependencies.cmake) + cmake_dependent_option( USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA AND NOT MSVC" OFF) + "USE_CUDA OR USE_ROCM;NOT MSVC" OFF) # We are currenlty not using alibi attention for Flash # So we disable this feature by default @@ -764,8 +779,6 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA" OFF) -include(cmake/Dependencies.cmake) - if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index e2ea560b6afc..96b839820efd 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -21,6 +21,10 @@ #include #include +#if USE_ROCM +#include +#endif + /** * Note [SDPA Runtime Dispatch] * SDPA relies on a runtime dispatch mechanism to select the appropriate @@ -182,32 +186,18 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; - auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM - constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-"; - static const char *over_arch = [] { - auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"); - if (rc) { - TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. " - "Later changes to this environment variable with os.environ " - "(or other methods) will not affect SDPA function's behavior."); - } - return rc; - }(); - const char* real_arch = dprops->gcnArchName; - const char* arch = over_arch ? over_arch : real_arch; - if (mi200 != arch) { - if (debug) { - TORCH_WARN( - "Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ", - arch, - ".", - over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "", - over_arch ? real_arch : ""); - } - return false; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (debug) { + TORCH_WARN( + "Flash attention was not compiled for current AMD GPU architecture. 
Attempting to run on architecture ", dprops->gcnArchName); + } + return false; } #else + auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 24eebee7a75a..9a43404f5d33 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -59,42 +59,141 @@ #include #include -// OORT headers -#include -#include -#include -#include +// AOTriton headers +#include +#include +#include +#include namespace pytorch_flash { namespace { -c10::once_flag fa_gcn_arch_override_flag; -const char* fa_override_arch = nullptr; - -void init_fa_override_arch() { - fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"); - if (fa_override_arch) { - TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. " - "Later changes to this environment variable with os.environ " - "(or other methods) will not affect this backend's behavior."); +void check_gpu_arch(hipStream_t stream) { + auto ret = aotriton::v2::flash::check_gpu(stream); + if (hipSuccess != ret) { + TORCH_CHECK(false, + "FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)") } } -void check_gpu_arch() { - auto dprops = at::cuda::getCurrentDeviceProperties(); - - constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-"; - c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch); - if (fa_override_arch) { - TORCH_CHECK(mi200 == fa_override_arch, - "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch); - } else { - TORCH_CHECK(mi200 == dprops->gcnArchName, - "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName)); - } +aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype) +{ +#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname + CAST_TYPE(kByte, kUInt8); + CAST_TYPE(kUInt16, kUInt16); + CAST_TYPE(kUInt32, kUInt32); + CAST_TYPE(kUInt64, kUInt64); + CAST_TYPE(kChar, kInt8); + CAST_TYPE(kShort, kInt16); + CAST_TYPE(kInt, kInt32); + CAST_TYPE(kLong, kInt64); + CAST_TYPE(kHalf, kFloat16); + CAST_TYPE(kFloat, kFloat32); + CAST_TYPE(kBFloat16, kBFloat16); + return aotriton::DType::kUnknown; +#undef CAST_TYPE } +template +struct IntArrayRefCaster { + // std::array cast(IntArrayRef); +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ static_cast(ref.at(0)) }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)) + }}; + } +}; + +template +struct IntArrayRefCaster { + static auto cast(at::IntArrayRef ref) { + return std::array{{ + static_cast(ref.at(0)), + static_cast(ref.at(1)), + static_cast(ref.at(2)), + static_cast(ref.at(3)) + }}; + } +}; + + +template +aotriton::TensorView mk_aotensor(const at::Tensor& q, c10::string_view tensor_name) +{ + const auto strides = 
q.strides(); + int real_rank = strides.size(); + if (real_rank != Rank) { // Lazy convertion of tensor_name + TORCH_CHECK(false, + std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + + " but is " + std::to_string(real_rank)); + } + return aotriton::TensorView(reinterpret_cast(q.data_ptr()), + IntArrayRefCaster::cast(q.sizes()), + IntArrayRefCaster::cast(strides), + cast_dtype(q.dtype())); +} + +template // For Output Tensor +class TensorStorageSanitizer { +public: + TensorStorageSanitizer(const at::Tensor& ref, + at::Tensor& to_sanitize) + : ref_(ref), to_sanitize_(to_sanitize) + { + need_sanitize = ref_.strides() != to_sanitize_.strides(); + if (!need_sanitize) + return; + + temp_ = at::empty_like(ref_); + if (COPY_FROM_INPUT) { + temp_.copy_(to_sanitize_); + } + } + + ~TensorStorageSanitizer() + { + if (need_sanitize && COPY_BACK) + to_sanitize_.copy_(temp_); + } + + at::Tensor& sanitized_tensor() + { + if (need_sanitize) + return temp_; + return to_sanitize_; + } +private: + const at::Tensor& ref_; + at::Tensor& to_sanitize_; + at::Tensor temp_; + bool need_sanitize = false; +}; + } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") @@ -114,7 +213,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head int window_size_right, const bool return_softmax, c10::optional gen_) { - check_gpu_arch(); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, @@ -206,102 +306,51 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head seed_t = at::empty({}, at::dtype(at::kLong)); offset_t = at::empty({}, at::dtype(at::kLong)); } - } - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); - //reorder tensors and make contiguous - at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous(); - at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous(); - at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous(); - at::Tensor output_t = out.permute({0,2,1,3}).contiguous(); + at::PhiloxCudaState philox_args; + if (p_dropout > 0.0) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) + { + philox_args = at::PhiloxCudaState(*seed_t.data_ptr(), *offset_t.data_ptr()); + } else { // dropout + capture + philox_args = at::PhiloxCudaState(seed_t.data_ptr(), offset_t.data_ptr(), 0); + } + } - at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse + // Transpose tensors to meet AOTriton's Flash API + at::Tensor q_t = q_padded.permute({0,2,1,3}); + at::Tensor k_t = k_padded.permute({0,2,1,3}); + at::Tensor v_t = v_padded.permute({0,2,1,3}); + at::Tensor output_t = out.permute({0,2,1,3}); - constexpr int BLOCK_M = 16; - constexpr int BLOCK_N = 16; - dim3 grid; - grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M; - grid.y = q_t.sizes()[0] * q_t.sizes()[1]; - grid.z = 1; - dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic + at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse - at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); + at::Tensor softmax_fa_t; + if (return_softmax) { + softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, + at::dtype(q.dtype()).device(q.device())); + } else { + softmax_fa_t = at::empty({ 0, 0, 
0, 0 }, at::dtype(q.dtype()).device(q.device())); + } hipError_t err; // TODO: Error handling -#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \ - do { \ - oort::attn_fwd fwd_opt; \ - err = fwd_opt(grid, block, \ - (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ - softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \ - q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ - k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ - v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ - output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \ - q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ - *(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \ - (FP*)(softmax_fa_t.data_ptr()), \ - stream); \ - } while(0) - - // TODO: Ugly but works - constexpr int kFwdUseCausal = 3; - constexpr int kFwdNoCausal = 1; - int d_head = q_t.sizes()[3]; - constexpr int BM = BLOCK_M; - constexpr int BN = BLOCK_N; - if (q_dtype == at::kHalf) { - if (is_causal) { - if (d_head == 16) - CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true); - else if (d_head == 32) - CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true); - else if (d_head == 64) - CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true); - else if (d_head == 128) - CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true); - } else { - if (d_head == 16) - CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true); - else if (d_head == 32) - CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true); - else if (d_head == 64) - CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true); - else if (d_head == 128) - CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true); - } - } else if (q_dtype == at::kBFloat16) { - if (is_causal) { - if (d_head == 16) - CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true); - else if (d_head == 32) - CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true); - else if (d_head == 64) - CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true); - else if (d_head == 128) - CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true); - } else { - if (d_head == 16) - CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true); - else if (d_head == 32) - CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true); - else if (d_head == 64) - CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true); - else if (d_head == 128) - CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true); - } - } - - //undo reorder tensors - q_padded = q_t.permute({0,2,1,3}).contiguous(); - k_padded = k_t.permute({0,2,1,3}).contiguous(); - v_padded = v_t.permute({0,2,1,3}).contiguous(); - out = output_t.permute({0,2,1,3}).contiguous(); + using aotriton::v2::flash::attn_fwd; + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(output_t, "Out"), + p_dropout, + philox_args.seed_.val, + philox_args.offset_.val, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; -#undef CALL_FWD } std::tuple @@ -354,10 +403,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { - check_gpu_arch(); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); bool is_dropout = p_dropout > 
0.0; - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, @@ -440,23 +489,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si // const at::Tensor& dout_padded = dout; - // bool loop = seqlen_k > blocksize_c; - // TODO: change later, for now set to true for simplicity - bool loop = true; - // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; auto opts = q.options(); auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - at::Tensor dq_accum; - at::Tensor dk_accum, dv_accum; - if (loop) { - dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - // dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - // dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - } at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA @@ -468,149 +506,52 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } at::PhiloxCudaState philox_args; - if (is_dropout) { + if (p_dropout > 0.0) { if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); } else { // dropout + capture - philox_args = at::PhiloxCudaState( - philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + philox_args = at::PhiloxCudaState(philox_seed.data_ptr(), philox_offset.data_ptr(), 0); } } - //JCG TODO WE GO IN HERE TODO backwards - //reorder tensors and make contiguous - at::Tensor q_t = q.permute({0,2,1,3}).contiguous(); - at::Tensor k_t = k.permute({0,2,1,3}).contiguous(); - at::Tensor v_t = v.permute({0,2,1,3}).contiguous(); - at::Tensor out_t = out.permute({0,2,1,3}).contiguous(); + at::Tensor q_t = q.permute({0,2,1,3}); + at::Tensor k_t = k.permute({0,2,1,3}); + at::Tensor v_t = v.permute({0,2,1,3}); + at::Tensor out_t = out.permute({0,2,1,3}); + at::Tensor dq_t = dq.permute({0,2,1,3}); + at::Tensor dk_t = dk.permute({0,2,1,3}); + at::Tensor dv_t = dv.permute({0,2,1,3}); + at::Tensor dout_t = dout.permute({0,2,1,3}); - //reorder tensors and make contiguous - at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous(); - at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous(); - at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous(); - at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous(); - - dim3 block { 64 * 4, 1, 1 }; - - at::Tensor new_do = at::empty_like(dout_t).contiguous(); + at::Tensor softmax_lse_cont = softmax_lse.contiguous(); at::Tensor delta = at::empty_like(softmax_lse).contiguous(); int d_head = head_size_og; hipError_t err; // TODO: Error handling -#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \ - do { \ - dim3 pp_grid; \ - pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \ - pp_grid.y = 1; \ - pp_grid.z = 1; \ - oort::bwd_preprocess pre_opt; \ - err = pre_opt(pp_grid, block, \ - (FP*)(out_t.data_ptr()), \ - (FP*)(dout_t.data_ptr()), \ - (FP*)(new_do.data_ptr()), \ - (float*)(delta.data_ptr()), \ - stream); \ - } while (0) - -#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \ - do { \ - if (d_head == 16) \ - CALL_BWD_PP(FP, PP_BLOCK, 16); \ - else if (d_head == 32) \ 
- CALL_BWD_PP(FP, PP_BLOCK, 32); \ - else if (d_head == 64) \ - CALL_BWD_PP(FP, PP_BLOCK, 64); \ - else if (d_head == 128) \ - CALL_BWD_PP(FP, PP_BLOCK, 128); \ - } while (0) - - if(q_dtype == at::kHalf) { - if (seqlen_q >= 64) - CALL_BWD_PP_DMODEL(__fp16, 16); - else - CALL_BWD_PP_DMODEL(__fp16, 16); - } else if (q_dtype == at::kBFloat16) { - if (seqlen_q >= 64) - CALL_BWD_PP_DMODEL(__bf16, 16); - else - CALL_BWD_PP_DMODEL(__bf16, 16); + { + TensorStorageSanitizer dq_s(q_t, dq_t); + TensorStorageSanitizer dk_s(k_t, dk_t); + TensorStorageSanitizer dv_s(v_t, dv_t); + using aotriton::v2::flash::attn_bwd; + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_s.sanitized_tensor(), "dq"), + mk_aotensor(dk_s.sanitized_tensor(), "dk"), + mk_aotensor(dv_s.sanitized_tensor(), "dv"), + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + philox_args.seed_.val, + philox_args.offset_.val, + is_causal, + stream); } -#undef CALL_BWD_PP - -#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \ - do { \ - dim3 grid; \ - grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \ - grid.y = batch_size * num_heads; \ - grid.z = 1; \ - oort::bwd_kernel_dk_dv dk_dv_opt; \ - err = dk_dv_opt(grid, block, \ - (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ - softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \ - (FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \ - (float*)(softmax_lse.data_ptr()), \ - (float*)(delta.data_ptr()), \ - q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ - k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ - v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ - q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ - (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \ - grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \ - oort::bwd_kernel_dq dq_opt; \ - err = dq_opt(grid, block, \ - (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \ - softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \ - (FP*)dq_t.data_ptr(), \ - (float*)(softmax_lse.data_ptr()), \ - (float*)(delta.data_ptr()), \ - q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \ - k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \ - v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \ - q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \ - (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \ - } while(0) - -#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \ - do { \ - if (p_dropout > 0.0) { \ - CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \ - } else { \ - CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \ - } \ - } while (0) - -#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \ - do { \ - if (d_head == 16) \ - CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \ - else if (d_head == 32) \ - CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \ - else if (d_head == 64) \ - CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \ - else if (d_head == 128) \ - CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \ - } while (0) - - if (q_dtype == at::kHalf) { - if (is_causal) { - CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true); - } else { - CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false); - } - } else if 
(q_dtype == at::kBFloat16) { - if (is_causal) { - CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true); - } else { - CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false); - } - } - - //undo reorder tensors for returns - dq = dq_t.permute({0,2,1,3}).contiguous(); - dk = dk_t.permute({0,2,1,3}).contiguous(); - dv = dv_t.permute({0,2,1,3}).contiguous(); // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 5d3922fea6fc..5467ece5ed8d 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -985,7 +985,8 @@ if(USE_ROCM) list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA}) hip_add_library(torch_hip ${Caffe2_HIP_SRCS}) if(USE_FLASH_ATTENTION) - target_link_libraries(torch_hip PRIVATE __caffe2_oort) + target_link_libraries(torch_hip PRIVATE __caffe2_aotriton) + add_dependencies(torch_hip aotriton_external) endif() set(CUDA_LINK_LIBRARIES_KEYWORD) torch_compile_options(torch_hip) # see cmake/public/utils.cmake diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 892bad591887..a96075245aed 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1335,9 +1335,7 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - if(USE_FLASH_ATTENTION) - include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake) - endif() + include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) if(USE_CUDA) caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake new file mode 100644 index 000000000000..ca9725451049 --- /dev/null +++ b/cmake/External/aotriton.cmake @@ -0,0 +1,28 @@ +if(NOT __AOTRITON_INCLUDED) + set(__AOTRITON_INCLUDED TRUE) + + set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src") + set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build") + set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") + ExternalProject_Add(aotriton_external + GIT_REPOSITORY https://github.com/ROCm/aotriton.git + GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542 + SOURCE_DIR ${__AOTRITON_SOURCE_DIR} + BINARY_DIR ${__AOTRITON_BUILD_DIR} + PREFIX ${__AOTRITON_INSTALL_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR} + -DAOTRITON_COMPRESS_KERNEL=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DAOTRITON_NO_PYTHON=ON + -DAOTRITON_NO_SHARED=ON + # CONFIGURE_COMMAND "" + # BUILD_COMMAND ${MAKE_COMMAND} + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a" + # INSTALL_COMMAND ${MAKE_COMMAND} install + ) + set(AOTRITON_FOUND TRUE) + add_library(__caffe2_aotriton INTERFACE) + add_dependencies(__caffe2_aotriton aotriton_external) + target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a) + target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) +endif() # __AOTRITON_INCLUDED diff --git a/cmake/External/oort.cmake b/cmake/External/oort.cmake deleted file mode 100644 index 29c9a1005a7f..000000000000 --- a/cmake/External/oort.cmake +++ /dev/null @@ -1,25 +0,0 @@ -if(NOT __OORT_INCLUDED) - set(__OORT_INCLUDED TRUE) - - set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src") - set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build") - set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch") - ExternalProject_Add(oort_external - GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git - GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba - SOURCE_DIR ${__OORT_SOURCE_DIR} - 
SOURCE_SUBDIR mathaot - BINARY_DIR ${__OORT_BUILD_DIR} - PREFIX ${__OORT_INSTALL_DIR} - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR} - # CONFIGURE_COMMAND "" - # BUILD_COMMAND ${MAKE_COMMAND} - BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a" - # INSTALL_COMMAND ${MAKE_COMMAND} install - ) - set(OORT_FOUND TRUE) - add_library(__caffe2_oort INTERFACE) - add_dependencies(__caffe2_oort oort_external) - target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a) - target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include) -endif() # __OORT_INCLUDED diff --git a/test/test_transformers.py b/test/test_transformers.py index f0a6278bc2db..b716104d0d1e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2396,7 +2396,7 @@ class TestSDPACudaOnly(NNTestCase): # Cast up and compare self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) - @skipIfRocm # Small matrices + @skipIfRocm # TODO: Packed QKV @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) @parametrize("is_causal", [True, False]) @@ -2798,16 +2798,6 @@ class TestSDPACudaOnly(NNTestCase): def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, scale: str): - if TEST_WITH_ROCM: - def is_power_of_2(n): - return n & (n - 1) == 0 - if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim): - self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.") - if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16: - self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.") - if head_dim > 128: - self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.") - if isSM8XDevice and head_dim in range(193, 256 + 1): self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: @@ -3428,6 +3418,7 @@ class TestAttnBias(NNTestCase): self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) + @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 13abf02f1e60..2a9055597f73 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -31,18 +31,19 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)) SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)) -def evaluate_gfx90a_exact(): +def evaluate_gfx_arch_exact(matching_arch): if not torch.cuda.is_available(): return False gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name) - return arch == 'gfx90a:sramecc+:xnack-' + return arch == matching_arch -GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact()) +GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-')) +GFX942_Exact = 
LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
 
 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        return evaluate_gfx90a_exact()
+        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
     return False
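
The architecture gate in the patch (in sdp_utils.cpp and flash_api.hip) replaces the old `gcnArchName` string comparison with a runtime probe, `aotriton::v2::flash::check_gpu(stream)`. Below is a minimal sketch of the two ways that probe is used: a soft check that lets the SDPA dispatcher fall back to another backend, and a hard check that aborts when the flash kernels are about to run anyway. The header paths are assumptions (the include targets are not visible in the diff); only the `check_gpu` call and the `TORCH_WARN`/`TORCH_CHECK` macros come from the patch.

```cpp
#include <hip/hip_runtime.h>      // hipStream_t, hipSuccess
#include <c10/util/Exception.h>   // TORCH_CHECK, TORCH_WARN
#include <aotriton/flash.h>       // assumed header for aotriton::v2::flash::check_gpu

// Soft gate, as in check_flash_attention_hardware_support(): report failure and
// let the SDPA dispatcher pick another backend.
bool flash_supported_on_this_gpu(hipStream_t stream, bool debug) {
  if (aotriton::v2::flash::check_gpu(stream) != hipSuccess) {
    if (debug) {
      TORCH_WARN("Flash attention was not compiled for the current AMD GPU architecture.");
    }
    return false;
  }
  return true;
}

// Hard gate, as in flash_api.hip's check_gpu_arch(): abort if the kernels are
// actually invoked on an unsupported GPU.
void require_flash_capable_gpu(hipStream_t stream) {
  TORCH_CHECK(aotriton::v2::flash::check_gpu(stream) == hipSuccess,
              "FlashAttention only supports MI200/MI300X GPUs");
}
```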
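flash_api.hip now hands tensors to AOTriton as `aotriton::TensorView` objects built by `mk_aotensor`, which packages the data pointer, sizes, strides, and a dtype translated by `cast_dtype`. The sketch below condenses that path for the common rank-4 case. The `std::array` element type, the pointer cast, and the header names are assumptions (those details are template parameters or includes stripped from the diff); the overall shape of the call follows the patch.

```cpp
#include <array>
#include <cstdint>
#include <ATen/core/Tensor.h>     // at::Tensor, at::IntArrayRef (assumed include path)
#include <c10/util/Exception.h>   // TORCH_CHECK
#include <aotriton/dtypes.h>      // assumed header for aotriton::DType
#include <aotriton/util.h>        // assumed header for aotriton::TensorView

// Map an ATen dtype to AOTriton's enum; the patch's cast_dtype covers the
// integer types as well, omitted here for brevity.
aotriton::DType cast_dtype(caffe2::TypeMeta t) {
  if (t == at::kHalf)     return aotriton::DType::kFloat16;
  if (t == at::kBFloat16) return aotriton::DType::kBFloat16;
  if (t == at::kFloat)    return aotriton::DType::kFloat32;
  return aotriton::DType::kUnknown;
}

// Wrap a rank-4 at::Tensor in an aotriton::TensorView<4> without copying.
// The uint64_t extents and intptr_t base address are assumptions.
aotriton::TensorView<4> wrap_rank4(const at::Tensor& t, const char* name) {
  TORCH_CHECK(t.dim() == 4, name, "'s rank should be 4 but is ", t.dim());
  auto to_u64 = [](at::IntArrayRef r) {
    return std::array<uint64_t, 4>{{static_cast<uint64_t>(r[0]),
                                    static_cast<uint64_t>(r[1]),
                                    static_cast<uint64_t>(r[2]),
                                    static_cast<uint64_t>(r[3])}};
  };
  return aotriton::TensorView<4>(reinterpret_cast<intptr_t>(t.data_ptr()),
                                 to_u64(t.sizes()),
                                 to_u64(t.strides()),
                                 cast_dtype(t.dtype()));
}
```

Because the view carries strides explicitly, the forward and backward paths can pass permuted (BHSD) tensors directly instead of calling `.contiguous()` as the old OORT path did.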
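For dropout, the patched mha_fwd and mha_bwd build an `at::PhiloxCudaState` before calling into AOTriton: outside CUDA graph capture the seed/offset values are dereferenced immediately, while under capture only the tensor pointers are recorded so the values are resolved at replay. A self-contained sketch of that branch, assuming the usual ATen generator headers (the diff's include lines are not visible):

```cpp
#include <ATen/cuda/CUDAGeneratorImpl.h>   // at::PhiloxCudaState (assumed include path)
#include <ATen/cuda/CUDAGraphsUtils.cuh>   // at::cuda::currentStreamCaptureStatus (assumed)
#include <ATen/core/Tensor.h>

// Build the Philox state the way the patched mha_fwd/mha_bwd do: read the
// seed/offset eagerly when no CUDA graph is being captured, otherwise record
// the device pointers so the values are picked up at graph replay.
at::PhiloxCudaState make_philox_args(double p_dropout,
                                     const at::Tensor& seed_t,     // kLong scalar tensor
                                     const at::Tensor& offset_t) { // kLong scalar tensor
  at::PhiloxCudaState philox_args;
  if (p_dropout > 0.0) {
    if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
      philox_args = at::PhiloxCudaState(*seed_t.data_ptr<int64_t>(),
                                        *offset_t.data_ptr<int64_t>());
    } else {  // dropout + capture: defer to pointers
      philox_args = at::PhiloxCudaState(seed_t.data_ptr<int64_t>(),
                                        offset_t.data_ptr<int64_t>(), 0);
    }
  }
  return philox_args;
}
```

The resulting `philox_args.seed_.val` and `philox_args.offset_.val` are what the patch forwards to `attn_fwd` and `attn_bwd`.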
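The backward path expects dq/dk/dv to share the stride layout of q/k/v, so the patch wraps them in `TensorStorageSanitizer` guards before calling `aotriton::v2::flash::attn_bwd`. The class below is a condensed, non-template illustration of the same idea; the original parameterizes copy-in/copy-back behavior via template flags not visible in the diff, and the name `StrideMatchGuard` is mine, not the patch's.

```cpp
#include <ATen/ATen.h>   // assumed umbrella include for at::Tensor / at::empty_like

// If `grad` does not share `ref`'s strides, route the kernel's writes through a
// like-strided temporary and copy back into `grad` when the guard is destroyed.
class StrideMatchGuard {
 public:
  StrideMatchGuard(const at::Tensor& ref, at::Tensor& grad)
      : ref_(ref), grad_(grad), need_fixup_(ref.strides() != grad.strides()) {
    if (need_fixup_) {
      temp_ = at::empty_like(ref_);   // same sizes/dtype as ref; strides preserved for dense tensors
    }
  }

  ~StrideMatchGuard() {
    if (need_fixup_) {
      grad_.copy_(temp_);             // propagate the kernel's output back
    }
  }

  at::Tensor& tensor() { return need_fixup_ ? temp_ : grad_; }

 private:
  const at::Tensor& ref_;
  at::Tensor& grad_;
  at::Tensor temp_;
  bool need_fixup_;
};

// Usage on the backward path (names follow the diff):
//   StrideMatchGuard dq_s(q_t, dq_t), dk_s(k_t, dk_t), dv_s(v_t, dv_t);
//   attn_bwd(..., mk_aotensor(dq_s.tensor(), "dq"), ...);
```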