Add Flash Attention support on ROCM (#121561)

This patch addresses the major limitations of our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
    * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
    * The varlen APIs will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
    * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
    * Kernels are now selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix for #112997.
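
For illustration (not part of this patch), a minimal sketch of exercising the new backend through the public SDPA API on a supported ROCm GPU. The non-power-of-two sequence lengths and head dimension rely on the relaxed limits listed above; the tensor shapes are arbitrary, and the build is assumed to have `USE_FLASH_ATTENTION` enabled.

```python
# Hedged sketch: exercise the ROCm flash attention backend through the public
# scaled_dot_product_attention API on an MI200/MI300X GPU.
import torch
import torch.nn.functional as F

device = "cuda"  # ROCm builds expose HIP devices under the "cuda" device type
# Arbitrary (non-power-of-two) sequence lengths and head dim <= 256
q = torch.randn(2, 8, 129, 96, device=device, dtype=torch.float16)
k = torch.randn(2, 8, 257, 96, device=device, dtype=torch.float16)
v = torch.randn(2, 8, 257, 96, device=device, dtype=torch.float16)

# Restrict SDPA to the flash attention kernel so a fallback is not silently used.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

print(out.shape)  # torch.Size([2, 8, 129, 96])
```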

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
Author: Xinya Zhang
Date: 2024-03-12 01:16:51 +00:00
Committed by: PyTorch MergeBot
Parent: 3a5f48d55f
Commit: a37e22de70
9 changed files with 266 additions and 328 deletions


@@ -736,13 +736,28 @@ if(MSVC)
append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
endif()
# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
# include(cmake/Dependencies.cmake)
# Note for ROCM platform:
# 1. USE_ROCM is always ON until include(cmake/Dependencies.cmake)
# 2. USE_CUDA will become OFF during re-configuration
# Truth Table:
# CUDA 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# CUDA 2nd pass: USE_CUDA=True;USE_ROCM=False, FLASH evaluates to ON by default
# ROCM 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# ROCM 2nd pass: USE_CUDA=False;USE_ROCM=True, FLASH evaluates to ON by default
# CPU 1st pass: USE_CUDA=False(Cmd Option);USE_ROCM=True, FLASH evaluates to OFF by default
# CPU 2nd pass: USE_CUDA=False(Cmd Option);USE_ROCM=False, FLASH evaluates to OFF by default
# Thus we cannot tell the ROCM 2nd pass and the CPU 1st pass apart
#
# The only solution is to include(cmake/Dependencies.cmake), and defer the
# aotriton build decision later.
include(cmake/Dependencies.cmake)
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF)
"USE_CUDA OR USE_ROCM;NOT MSVC" OFF)
# We are currently not using alibi attention for Flash
# So we disable this feature by default
@@ -758,8 +773,6 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA" OFF)
include(cmake/Dependencies.cmake)
if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")


@@ -21,6 +21,10 @@
#include <cmath>
#include <functional>
#if USE_ROCM
#include <aotriton/flash.h>
#endif
/**
* Note [SDPA Runtime Dispatch]
* SDPA relies on a runtime dispatch mechanism to select the appropriate
@@ -182,32 +186,18 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
static const char *over_arch = [] {
auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (rc) {
TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect SDPA function's behavior.");
}
return rc;
}();
const char* real_arch = dprops->gcnArchName;
const char* arch = over_arch ? over_arch : real_arch;
if (mi200 != arch) {
if (debug) {
TORCH_WARN(
"Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
arch,
".",
over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
over_arch ? real_arch : "");
}
return false;
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
}
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(


@@ -59,42 +59,141 @@
#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>
// OORT headers
#include <oort/attn_fwd.h>
#include <oort/bwd_kernel_dk_dv.h>
#include <oort/bwd_kernel_dq.h>
#include <oort/bwd_preprocess.h>
// AOTriton headers
#include <aotriton/dtypes.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <aotriton/util.h>
namespace pytorch_flash {
namespace {
c10::once_flag fa_gcn_arch_override_flag;
const char* fa_override_arch = nullptr;
void init_fa_override_arch() {
fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (fa_override_arch) {
TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect this backend's behavior.");
void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
}
}
void check_gpu_arch() {
auto dprops = at::cuda::getCurrentDeviceProperties();
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch);
if (fa_override_arch) {
TORCH_CHECK(mi200 == fa_override_arch,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch);
} else {
TORCH_CHECK(mi200 == dprops->gcnArchName,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName));
}
aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
{
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
CAST_TYPE(kByte, kUInt8);
CAST_TYPE(kUInt16, kUInt16);
CAST_TYPE(kUInt32, kUInt32);
CAST_TYPE(kUInt64, kUInt64);
CAST_TYPE(kChar, kInt8);
CAST_TYPE(kShort, kInt16);
CAST_TYPE(kInt, kInt32);
CAST_TYPE(kLong, kInt64);
CAST_TYPE(kHalf, kFloat16);
CAST_TYPE(kFloat, kFloat32);
CAST_TYPE(kBFloat16, kBFloat16);
return aotriton::DType::kUnknown;
#undef CAST_TYPE
}
template<typename TargetType, int Rank>
struct IntArrayRefCaster {
// std::array<TargetType, Rank> cast(IntArrayRef);
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 1> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 2> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 2>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 3> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 3>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2))
}};
}
};
template<typename TargetType>
struct IntArrayRefCaster<TargetType, 4> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 4>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2)),
static_cast<TargetType>(ref.at(3))
}};
}
};
template<int Rank = 4>
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
{
const auto strides = q.strides();
int real_rank = strides.size();
if (real_rank != Rank) { // Lazy conversion of tensor_name
TORCH_CHECK(false,
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
+ " but is " + std::to_string(real_rank));
}
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
cast_dtype(q.dtype()));
}
template<bool COPY_FROM_INPUT, // For Input Tensor
bool COPY_BACK> // For Output Tensor
class TensorStorageSanitizer {
public:
TensorStorageSanitizer(const at::Tensor& ref,
at::Tensor& to_sanitize)
: ref_(ref), to_sanitize_(to_sanitize)
{
need_sanitize = ref_.strides() != to_sanitize_.strides();
if (!need_sanitize)
return;
temp_ = at::empty_like(ref_);
if (COPY_FROM_INPUT) {
temp_.copy_(to_sanitize_);
}
}
~TensorStorageSanitizer()
{
if (need_sanitize && COPY_BACK)
to_sanitize_.copy_(temp_);
}
at::Tensor& sanitized_tensor()
{
if (need_sanitize)
return temp_;
return to_sanitize_;
}
private:
const at::Tensor& ref_;
at::Tensor& to_sanitize_;
at::Tensor temp_;
bool need_sanitize = false;
};
}
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@@ -114,7 +213,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
check_gpu_arch();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -206,102 +306,51 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
}
}
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
//reorder tensors and make contiguous
at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous();
at::Tensor output_t = out.permute({0,2,1,3}).contiguous();
at::PhiloxCudaState philox_args;
if (p_dropout > 0.0) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*seed_t.data_ptr<int64_t>(), *offset_t.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(seed_t.data_ptr<int64_t>(), offset_t.data_ptr<int64_t>(), 0);
}
}
at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse
// Transpose tensors to meet AOTriton's Flash API
at::Tensor q_t = q_padded.permute({0,2,1,3});
at::Tensor k_t = k_padded.permute({0,2,1,3});
at::Tensor v_t = v_padded.permute({0,2,1,3});
at::Tensor output_t = out.permute({0,2,1,3});
constexpr int BLOCK_M = 16;
constexpr int BLOCK_N = 16;
dim3 grid;
grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M;
grid.y = q_t.sizes()[0] * q_t.sizes()[1];
grid.z = 1;
dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic
at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse
at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
at::dtype(q.dtype()).device(q.device()));
at::Tensor softmax_fa_t;
if (return_softmax) {
softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
at::dtype(q.dtype()).device(q.device()));
} else {
softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device()));
}
hipError_t err; // TODO: Error handling
#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \
do { \
oort::attn_fwd<STAGE,BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX> fwd_opt; \
err = fwd_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
*(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \
(FP*)(softmax_fa_t.data_ptr()), \
stream); \
} while(0)
// TODO: Ugly but works
constexpr int kFwdUseCausal = 3;
constexpr int kFwdNoCausal = 1;
int d_head = q_t.sizes()[3];
constexpr int BM = BLOCK_M;
constexpr int BN = BLOCK_N;
if (q_dtype == at::kHalf) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true);
}
}
//undo reorder tensors
q_padded = q_t.permute({0,2,1,3}).contiguous();
k_padded = k_t.permute({0,2,1,3}).contiguous();
v_padded = v_t.permute({0,2,1,3}).contiguous();
out = output_t.permute({0,2,1,3}).contiguous();
using aotriton::v2::flash::attn_fwd;
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
#undef CALL_FWD
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -354,10 +403,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
check_gpu_arch();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -440,23 +489,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
// const at::Tensor& dout_padded = dout;
// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -468,149 +506,52 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (p_dropout > 0.0) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
philox_args = at::PhiloxCudaState(philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
}
//JCG TODO WE GO IN HERE TODO backwards
//reorder tensors and make contiguous
at::Tensor q_t = q.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v.permute({0,2,1,3}).contiguous();
at::Tensor out_t = out.permute({0,2,1,3}).contiguous();
at::Tensor q_t = q.permute({0,2,1,3});
at::Tensor k_t = k.permute({0,2,1,3});
at::Tensor v_t = v.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = dq.permute({0,2,1,3});
at::Tensor dk_t = dk.permute({0,2,1,3});
at::Tensor dv_t = dv.permute({0,2,1,3});
at::Tensor dout_t = dout.permute({0,2,1,3});
//reorder tensors and make contiguous
at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous();
at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous();
at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous();
at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous();
dim3 block { 64 * 4, 1, 1 };
at::Tensor new_do = at::empty_like(dout_t).contiguous();
at::Tensor softmax_lse_cont = softmax_lse.contiguous();
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
int d_head = head_size_og;
hipError_t err; // TODO: Error handling
#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \
do { \
dim3 pp_grid; \
pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \
pp_grid.y = 1; \
pp_grid.z = 1; \
oort::bwd_preprocess<PP_BLOCK, PP_DMODEL> pre_opt; \
err = pre_opt(pp_grid, block, \
(FP*)(out_t.data_ptr()), \
(FP*)(dout_t.data_ptr()), \
(FP*)(new_do.data_ptr()), \
(float*)(delta.data_ptr()), \
stream); \
} while (0)
#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \
do { \
if (d_head == 16) \
CALL_BWD_PP(FP, PP_BLOCK, 16); \
else if (d_head == 32) \
CALL_BWD_PP(FP, PP_BLOCK, 32); \
else if (d_head == 64) \
CALL_BWD_PP(FP, PP_BLOCK, 64); \
else if (d_head == 128) \
CALL_BWD_PP(FP, PP_BLOCK, 128); \
} while (0)
if(q_dtype == at::kHalf) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__fp16, 16);
else
CALL_BWD_PP_DMODEL(__fp16, 16);
} else if (q_dtype == at::kBFloat16) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__bf16, 16);
else
CALL_BWD_PP_DMODEL(__bf16, 16);
{
TensorStorageSanitizer<true, false> dq_s(q_t, dq_t);
TensorStorageSanitizer<true, false> dk_s(k_t, dk_t);
TensorStorageSanitizer<true, false> dv_s(v_t, dv_t);
using aotriton::v2::flash::attn_bwd;
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_s.sanitized_tensor(), "dq"),
mk_aotensor(dk_s.sanitized_tensor(), "dk"),
mk_aotensor(dv_s.sanitized_tensor(), "dv"),
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
is_causal,
stream);
}
#undef CALL_BWD_PP
#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \
do { \
dim3 grid; \
grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \
grid.y = batch_size * num_heads; \
grid.z = 1; \
oort::bwd_kernel_dk_dv<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dk_dv_opt; \
err = dk_dv_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \
oort::bwd_kernel_dq<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dq_opt; \
err = dq_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dq_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
} while(0)
#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \
do { \
if (p_dropout > 0.0) { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \
} else { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \
} \
} while (0)
#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \
do { \
if (d_head == 16) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \
else if (d_head == 32) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \
else if (d_head == 64) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \
else if (d_head == 128) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \
} while (0)
if (q_dtype == at::kHalf) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false);
}
}
//undo reorder tensors for returns
dq = dq_t.permute({0,2,1,3}).contiguous();
dk = dk_t.permute({0,2,1,3}).contiguous();
dv = dv_t.permute({0,2,1,3}).contiguous();
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {


@@ -979,7 +979,8 @@ if(USE_ROCM)
list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_oort)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
add_dependencies(torch_hip aotriton_external)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake


@@ -1335,9 +1335,7 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm")
endif()
if(USE_FLASH_ATTENTION)
include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake)
endif()
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
if(USE_CUDA)
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
endif()

cmake/External/aotriton.cmake (vendored, new file, 28 lines)

@@ -0,0 +1,28 @@
if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INCLUDED TRUE)
set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src")
set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build")
set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542
SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
BINARY_DIR ${__AOTRITON_BUILD_DIR}
PREFIX ${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_COMPRESS_KERNEL=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NO_SHARED=ON
# CONFIGURE_COMMAND ""
# BUILD_COMMAND ${MAKE_COMMAND}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a"
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
set(AOTRITON_FOUND TRUE)
add_library(__caffe2_aotriton INTERFACE)
add_dependencies(__caffe2_aotriton aotriton_external)
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a)
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
endif() # __AOTRITON_INCLUDED


@@ -1,25 +0,0 @@
if(NOT __OORT_INCLUDED)
set(__OORT_INCLUDED TRUE)
set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src")
set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build")
set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
ExternalProject_Add(oort_external
GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git
GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba
SOURCE_DIR ${__OORT_SOURCE_DIR}
SOURCE_SUBDIR mathaot
BINARY_DIR ${__OORT_BUILD_DIR}
PREFIX ${__OORT_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR}
# CONFIGURE_COMMAND ""
# BUILD_COMMAND ${MAKE_COMMAND}
BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a"
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
set(OORT_FOUND TRUE)
add_library(__caffe2_oort INTERFACE)
add_dependencies(__caffe2_oort oort_external)
target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a)
target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include)
endif() # __OORT_INCLUDED


@@ -2384,7 +2384,7 @@ class TestSDPACudaOnly(NNTestCase):
# Cast up and compare
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
@skipIfRocm # Small matrices
@skipIfRocm # TODO: Packed QKV
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
@parametrize("contiguous_inputs", [True, False])
@parametrize("is_causal", [True, False])
@@ -2786,16 +2786,6 @@ class TestSDPACudaOnly(NNTestCase):
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str):
if TEST_WITH_ROCM:
def is_power_of_2(n):
return n & (n - 1) == 0
if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
if head_dim > 128:
self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")
if isSM8XDevice and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
@@ -3416,6 +3406,7 @@ class TestAttnBias(NNTestCase):
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
@skipIfRocm # CausalVariant
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize(
"shape",


@@ -31,18 +31,19 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
def evaluate_gfx90a_exact():
def evaluate_gfx_arch_exact(matching_arch):
if not torch.cuda.is_available():
return False
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
return arch == 'gfx90a:sramecc+:xnack-'
return arch == matching_arch
GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact())
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
return evaluate_gfx90a_exact()
return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
return False
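
For illustration (not part of this diff), a hedged sketch of how the `PLATFORM_SUPPORTS_FLASH_ATTENTION` gate built on `evaluate_platform_supports_flash_attention()` above is typically used to skip flash-attention tests on unsupported hardware; the test class name, method, and tensor shapes below are hypothetical.

```python
# Hedged sketch: gate a test on flash attention availability, mirroring the
# skipIf usage shown in test_transformers.py earlier in this commit.
import unittest

import torch
import torch.nn.functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class FlashSDPASmokeTest(unittest.TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION,
                     "Flash Attention was not built for this system")
    def test_flash_sdpa_smoke(self):
        # Self-attention with equal query/key lengths, fp16, on the GPU.
        q = torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16)
        out = F.scaled_dot_product_attention(q, q, q, is_causal=True)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()
```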