Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
  * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
  * The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
  * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
  * Kernels are selected according to autotune information from Triton.

Other improvements from AOTriton include:
* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix to #112997.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
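As a quick end-to-end check, the new backend can be exercised through the public SDPA entry point. The snippet below is a minimal sketch, not part of this PR: it assumes a ROCm build of PyTorch with USE_FLASH_ATTENTION=ON on a supported MI200/MI300X GPU, and the shapes are arbitrary illustrations of the "non power-of-two" claims above.

```python
# Minimal sketch (not part of this PR): exercising the new flash attention
# backend through the public SDPA API. Assumes a ROCm build of PyTorch with
# USE_FLASH_ATTENTION=ON on a supported MI200/MI300X GPU; shapes are arbitrary
# (non power-of-two sequence length, head_dim 96) to match the claims above.
import torch
import torch.nn.functional as F

device = "cuda"  # ROCm GPUs are exposed through the CUDA device API
q = torch.randn(2, 8, 77, 96, device=device, dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 8, 77, 96, device=device, dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 8, 77, 96, device=device, dtype=torch.float16, requires_grad=True)

# Restrict SDPA to the flash kernel so the AOTriton-backed implementation is used.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out.sum().backward()  # the backward pass also dispatches to the AOTriton kernels
print(out.shape, q.grad.shape)
```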
Committed by: PyTorch MergeBot
Parent: 3a5f48d55f
Commit: a37e22de70
@@ -736,13 +736,28 @@ if(MSVC)
append_cxx_flag_if_supported("/utf-8" CMAKE_CXX_FLAGS)
endif()

# CAVEAT: do NOT check USE_ROCM here, because USE_ROCM is always True until
# include(cmake/Dependencies.cmake)
# Note for ROCM platform:
# 1. USE_ROCM is always ON until include(cmake/Dependencies.cmake)
# 2. USE_CUDA will become OFF during re-configuration
# Truth Table:
# CUDA 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# CUDA 2nd pass: USE_CUDA=True;USE_ROCM=False, FLASH evaluates to ON by default
# ROCM 1st pass: USE_CUDA=True;USE_ROCM=True, FLASH evaluates to ON by default
# ROCM 2nd pass: USE_CUDA=False;USE_ROCM=True, FLASH evaluates to ON by default
# CPU 1st pass: USE_CUDA=False(Cmd Option);USE_ROCM=True, FLASH evaluates to OFF by default
# CPU 2nd pass: USE_CUDA=False(Cmd Option);USE_ROCM=False, FLASH evaluates to OFF by default
# Thus we cannot tell ROCM 2nd pass and CPU 1st pass
#
# The only solution is to include(cmake/Dependencies.cmake), and defer the
# aotriton build decision later.

include(cmake/Dependencies.cmake)

cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF)
"USE_CUDA OR USE_ROCM;NOT MSVC" OFF)

# We are currenlty not using alibi attention for Flash
# So we disable this feature by default
@@ -758,8 +773,6 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA" OFF)

include(cmake/Dependencies.cmake)

if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
@@ -21,6 +21,10 @@
#include <cmath>
#include <functional>

#if USE_ROCM
#include <aotriton/flash.h>
#endif

/**
 * Note [SDPA Runtime Dispatch]
 * SDPA relies on a runtime dispatch mechanism to select the appropriate
@@ -182,32 +186,18 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
static const char *over_arch = [] {
auto rc = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (rc) {
TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect SDPA function's behavior.");
}
return rc;
}();
const char* real_arch = dprops->gcnArchName;
const char* arch = over_arch ? over_arch : real_arch;
if (mi200 != arch) {
if (debug) {
TORCH_WARN(
"Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
arch,
".",
over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
over_arch ? real_arch : "");
}
return false;
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
}
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
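The warning strings above describe a read-once behavior: PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE is consulted only the first time the check runs, so later changes via os.environ are ignored. A minimal sketch of the usage pattern the warning is guarding against (assuming a ROCm build; the override value is just an example):

```python
# Sketch (not part of the diff): the override is read a single time per process,
# so it must be in the environment before the first SDPA/flash attention call.
import os
os.environ["PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"] = "gfx90a:sramecc+:xnack-"  # example value

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
F.scaled_dot_product_attention(q, q, q)  # the arch check (and any override) is consulted here
# Changing os.environ afterwards has no effect, exactly as the TORCH_WARN above states.
```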
@@ -59,42 +59,141 @@
#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>

// OORT headers
#include <oort/attn_fwd.h>
#include <oort/bwd_kernel_dk_dv.h>
#include <oort/bwd_kernel_dq.h>
#include <oort/bwd_preprocess.h>
// AOTriton headers
#include <aotriton/dtypes.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#include <aotriton/util.h>

namespace pytorch_flash {

namespace {

c10::once_flag fa_gcn_arch_override_flag;
const char* fa_override_arch = nullptr;

void init_fa_override_arch() {
fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (fa_override_arch) {
TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect this backend's behavior.");
void check_gpu_arch(hipStream_t stream) {
auto ret = aotriton::v2::flash::check_gpu(stream);
if (hipSuccess != ret) {
TORCH_CHECK(false,
"FlashAttention only supports MI200/MI300X GPUs (gfx90a:sramecc+:xnack- or gfx94a:sramecc+:xnack-)")
}
}

void check_gpu_arch() {
auto dprops = at::cuda::getCurrentDeviceProperties();

constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch);
if (fa_override_arch) {
TORCH_CHECK(mi200 == fa_override_arch,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch);
} else {
TORCH_CHECK(mi200 == dprops->gcnArchName,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName));
}
aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
{
#define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
CAST_TYPE(kByte, kUInt8);
CAST_TYPE(kUInt16, kUInt16);
CAST_TYPE(kUInt32, kUInt32);
CAST_TYPE(kUInt64, kUInt64);
CAST_TYPE(kChar, kInt8);
CAST_TYPE(kShort, kInt16);
CAST_TYPE(kInt, kInt32);
CAST_TYPE(kLong, kInt64);
CAST_TYPE(kHalf, kFloat16);
CAST_TYPE(kFloat, kFloat32);
CAST_TYPE(kBFloat16, kBFloat16);
return aotriton::DType::kUnknown;
#undef CAST_TYPE
}

template<typename TargetType, int Rank>
struct IntArrayRefCaster {
// std::array<TargetType, Rank> cast(IntArrayRef);
};

template<typename TargetType>
struct IntArrayRefCaster<TargetType, 1> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
}
};

template<typename TargetType>
struct IntArrayRefCaster<TargetType, 2> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 2>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1))
}};
}
};

template<typename TargetType>
struct IntArrayRefCaster<TargetType, 3> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 3>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2))
}};
}
};

template<typename TargetType>
struct IntArrayRefCaster<TargetType, 4> {
static auto cast(at::IntArrayRef ref) {
return std::array<TargetType, 4>{{
static_cast<TargetType>(ref.at(0)),
static_cast<TargetType>(ref.at(1)),
static_cast<TargetType>(ref.at(2)),
static_cast<TargetType>(ref.at(3))
}};
}
};


template<int Rank = 4>
aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
{
const auto strides = q.strides();
int real_rank = strides.size();
if (real_rank != Rank) { // Lazy convertion of tensor_name
TORCH_CHECK(false,
std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
+ " but is " + std::to_string(real_rank));
}
return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
IntArrayRefCaster<uint64_t, Rank>::cast(strides),
cast_dtype(q.dtype()));
}

template<bool COPY_FROM_INPUT, // For Input Tensor
bool COPY_BACK> // For Output Tensor
class TensorStorageSanitizer {
public:
TensorStorageSanitizer(const at::Tensor& ref,
at::Tensor& to_sanitize)
: ref_(ref), to_sanitize_(to_sanitize)
{
need_sanitize = ref_.strides() != to_sanitize_.strides();
if (!need_sanitize)
return;

temp_ = at::empty_like(ref_);
if (COPY_FROM_INPUT) {
temp_.copy_(to_sanitize_);
}
}

~TensorStorageSanitizer()
{
if (need_sanitize && COPY_BACK)
to_sanitize_.copy_(temp_);
}

at::Tensor& sanitized_tensor()
{
if (need_sanitize)
return temp_;
return to_sanitize_;
}
private:
const at::Tensor& ref_;
at::Tensor& to_sanitize_;
at::Tensor temp_;
bool need_sanitize = false;
};

}

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@@ -114,7 +213,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
check_gpu_arch();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -206,102 +306,51 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
}

}

auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
//reorder tensors and make contiguous
at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous();
at::Tensor output_t = out.permute({0,2,1,3}).contiguous();
at::PhiloxCudaState philox_args;
if (p_dropout > 0.0) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*seed_t.data_ptr<int64_t>(), *offset_t.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(seed_t.data_ptr<int64_t>(), offset_t.data_ptr<int64_t>(), 0);
}
}

at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse
// Transpose tensors to meet AOTriton's Flash API
at::Tensor q_t = q_padded.permute({0,2,1,3});
at::Tensor k_t = k_padded.permute({0,2,1,3});
at::Tensor v_t = v_padded.permute({0,2,1,3});
at::Tensor output_t = out.permute({0,2,1,3});

constexpr int BLOCK_M = 16;
constexpr int BLOCK_N = 16;
dim3 grid;
grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M;
grid.y = q_t.sizes()[0] * q_t.sizes()[1];
grid.z = 1;
dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic
at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse

at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
at::dtype(q.dtype()).device(q.device()));
at::Tensor softmax_fa_t;
if (return_softmax) {
softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
at::dtype(q.dtype()).device(q.device()));
} else {
softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device()));
}

hipError_t err; // TODO: Error handling
#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \
do { \
oort::attn_fwd<STAGE,BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX> fwd_opt; \
err = fwd_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
*(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \
(FP*)(softmax_fa_t.data_ptr()), \
stream); \
} while(0)

// TODO: Ugly but works
constexpr int kFwdUseCausal = 3;
constexpr int kFwdNoCausal = 1;
int d_head = q_t.sizes()[3];
constexpr int BM = BLOCK_M;
constexpr int BN = BLOCK_N;
if (q_dtype == at::kHalf) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true);
}
}

//undo reorder tensors
q_padded = q_t.permute({0,2,1,3}).contiguous();
k_padded = k_t.permute({0,2,1,3}).contiguous();
v_padded = v_t.permute({0,2,1,3}).contiguous();
out = output_t.permute({0,2,1,3}).contiguous();
using aotriton::v2::flash::attn_fwd;
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);

return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
#undef CALL_FWD
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -354,10 +403,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
check_gpu_arch();
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);

bool is_dropout = p_dropout > 0.0;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
@@ -440,23 +489,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si

// const at::Tensor& dout_padded = dout;

// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -468,149 +506,52 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}

at::PhiloxCudaState philox_args;
if (is_dropout) {
if (p_dropout > 0.0) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
philox_args = at::PhiloxCudaState(philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
}

//JCG TODO WE GO IN HERE TODO backwards
//reorder tensors and make contiguous
at::Tensor q_t = q.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v.permute({0,2,1,3}).contiguous();
at::Tensor out_t = out.permute({0,2,1,3}).contiguous();
at::Tensor q_t = q.permute({0,2,1,3});
at::Tensor k_t = k.permute({0,2,1,3});
at::Tensor v_t = v.permute({0,2,1,3});
at::Tensor out_t = out.permute({0,2,1,3});
at::Tensor dq_t = dq.permute({0,2,1,3});
at::Tensor dk_t = dk.permute({0,2,1,3});
at::Tensor dv_t = dv.permute({0,2,1,3});
at::Tensor dout_t = dout.permute({0,2,1,3});

//reorder tensors and make contiguous
at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous();
at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous();
at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous();
at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous();

dim3 block { 64 * 4, 1, 1 };

at::Tensor new_do = at::empty_like(dout_t).contiguous();
at::Tensor softmax_lse_cont = softmax_lse.contiguous();
at::Tensor delta = at::empty_like(softmax_lse).contiguous();

int d_head = head_size_og;
hipError_t err; // TODO: Error handling
#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \
do { \
dim3 pp_grid; \
pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \
pp_grid.y = 1; \
pp_grid.z = 1; \
oort::bwd_preprocess<PP_BLOCK, PP_DMODEL> pre_opt; \
err = pre_opt(pp_grid, block, \
(FP*)(out_t.data_ptr()), \
(FP*)(dout_t.data_ptr()), \
(FP*)(new_do.data_ptr()), \
(float*)(delta.data_ptr()), \
stream); \
} while (0)

#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \
do { \
if (d_head == 16) \
CALL_BWD_PP(FP, PP_BLOCK, 16); \
else if (d_head == 32) \
CALL_BWD_PP(FP, PP_BLOCK, 32); \
else if (d_head == 64) \
CALL_BWD_PP(FP, PP_BLOCK, 64); \
else if (d_head == 128) \
CALL_BWD_PP(FP, PP_BLOCK, 128); \
} while (0)

if(q_dtype == at::kHalf) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__fp16, 16);
else
CALL_BWD_PP_DMODEL(__fp16, 16);
} else if (q_dtype == at::kBFloat16) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__bf16, 16);
else
CALL_BWD_PP_DMODEL(__bf16, 16);
{
TensorStorageSanitizer<true, false> dq_s(q_t, dq_t);
TensorStorageSanitizer<true, false> dk_s(k_t, dk_t);
TensorStorageSanitizer<true, false> dv_s(v_t, dv_t);
using aotriton::v2::flash::attn_bwd;
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_s.sanitized_tensor(), "dq"),
mk_aotensor(dk_s.sanitized_tensor(), "dk"),
mk_aotensor(dv_s.sanitized_tensor(), "dv"),
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
philox_args.seed_.val,
philox_args.offset_.val,
is_causal,
stream);
}
#undef CALL_BWD_PP

#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \
do { \
dim3 grid; \
grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \
grid.y = batch_size * num_heads; \
grid.z = 1; \
oort::bwd_kernel_dk_dv<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dk_dv_opt; \
err = dk_dv_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \
oort::bwd_kernel_dq<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dq_opt; \
err = dq_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dq_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
} while(0)

#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \
do { \
if (p_dropout > 0.0) { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \
} else { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \
} \
} while (0)

#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \
do { \
if (d_head == 16) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \
else if (d_head == 32) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \
else if (d_head == 64) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \
else if (d_head == 128) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \
} while (0)

if (q_dtype == at::kHalf) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false);
}
}

//undo reorder tensors for returns
dq = dq_t.permute({0,2,1,3}).contiguous();
dk = dk_t.permute({0,2,1,3}).contiguous();
dv = dv_t.permute({0,2,1,3}).contiguous();

// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
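One of the improvements listed in the commit message is the more flexible tensor storage layout; on the backward path the TensorStorageSanitizer helper above copies gradient tensors only when their strides differ from the reference layout. A rough sketch of the kind of permuted, non-contiguous inputs this is meant to tolerate (an illustration under assumptions, not code from this PR):

```python
# Sketch under assumptions (not code from this PR): AOTriton tolerates more
# flexible storage layouts, so permuted (non-contiguous) views can be fed to
# SDPA directly and the backward pass above handles the mismatched strides.
import torch
import torch.nn.functional as F

# Created in (batch, seq, 3, heads, head_dim) layout, then sliced into views.
qkv = torch.randn(2, 77, 3, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
q, k, v = (t.squeeze(2).transpose(1, 2) for t in qkv.split(1, dim=2))  # views, not copies

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)  # (batch, heads, seq, head_dim)
out.mean().backward()  # routes through the ROCm backward path shown above
print(qkv.grad.shape)
```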
@@ -979,7 +979,8 @@ if(USE_ROCM)
list(APPEND Caffe2_HIP_SRCS ${GENERATED_CXX_TORCH_CUDA})
hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
if(USE_FLASH_ATTENTION)
target_link_libraries(torch_hip PRIVATE __caffe2_oort)
target_link_libraries(torch_hip PRIVATE __caffe2_aotriton)
add_dependencies(torch_hip aotriton_external)
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake

@@ -1335,9 +1335,7 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm")
endif()

if(USE_FLASH_ATTENTION)
include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake)
endif()
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
if(USE_CUDA)
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
endif()
cmake/External/aotriton.cmake (vendored, new file, 28 lines)
@@ -0,0 +1,28 @@
if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INCLUDED TRUE)

set(__AOTRITON_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/src")
set(__AOTRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/aotriton/build")
set(__AOTRITON_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_TAG 9044fe5eb16130e49a0a1f781ea15037353ad542
SOURCE_DIR ${__AOTRITON_SOURCE_DIR}
BINARY_DIR ${__AOTRITON_BUILD_DIR}
PREFIX ${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_COMPRESS_KERNEL=OFF
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NO_SHARED=ON
# CONFIGURE_COMMAND ""
# BUILD_COMMAND ${MAKE_COMMAND}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a"
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
set(AOTRITON_FOUND TRUE)
add_library(__caffe2_aotriton INTERFACE)
add_dependencies(__caffe2_aotriton aotriton_external)
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.a)
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
endif() # __AOTRITON_INCLUDED
cmake/External/oort.cmake (vendored, deleted, 25 lines)
@@ -1,25 +0,0 @@
if(NOT __OORT_INCLUDED)
set(__OORT_INCLUDED TRUE)

set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src")
set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build")
set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
ExternalProject_Add(oort_external
GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git
GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba
SOURCE_DIR ${__OORT_SOURCE_DIR}
SOURCE_SUBDIR mathaot
BINARY_DIR ${__OORT_BUILD_DIR}
PREFIX ${__OORT_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR}
# CONFIGURE_COMMAND ""
# BUILD_COMMAND ${MAKE_COMMAND}
BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a"
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
set(OORT_FOUND TRUE)
add_library(__caffe2_oort INTERFACE)
add_dependencies(__caffe2_oort oort_external)
target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a)
target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include)
endif() # __OORT_INCLUDED
@@ -2384,7 +2384,7 @@ class TestSDPACudaOnly(NNTestCase):
        # Cast up and compare
        self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)

    @skipIfRocm  # Small matrices
    @skipIfRocm  # TODO: Packed QKV
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
    @parametrize("contiguous_inputs", [True, False])
    @parametrize("is_causal", [True, False])
@@ -2786,16 +2786,6 @@ class TestSDPACudaOnly(NNTestCase):
    def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                               head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                               scale: str):
        if TEST_WITH_ROCM:
            def is_power_of_2(n):
                return n & (n - 1) == 0
            if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
            if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
            if head_dim > 128:
                self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")

        if isSM8XDevice and head_dim in range(193, 256 + 1):
            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
        if is_causal and seq_len_q != seq_len_k:
@@ -3416,6 +3406,7 @@ class TestAttnBias(NNTestCase):

        self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)

    @skipIfRocm  # CausalVariant
    @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
    @parametrize(
        "shape",
@@ -31,18 +31,19 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

def evaluate_gfx90a_exact():
def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == 'gfx90a:sramecc+:xnack-'
    return arch == matching_arch

GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact())
GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))

def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx90a_exact()
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False
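For reference, the same gfx90a/gfx942 gate implemented by the helpers above can be reproduced from an interactive session to see which architecture string the flash attention check will use. This is a hedged sketch, not part of the diff, and the override variable is only honored where the code above reads it.

```python
# Hedged sketch (not from the diff): querying the architecture string the
# helpers above compare against, including the debug override, on a ROCm build.
import os
import torch

if torch.cuda.is_available() and torch.version.hip is not None:
    real_arch = torch.cuda.get_device_properties(0).gcnArchName
    arch = os.environ.get("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE", real_arch)
    supported = arch in ("gfx90a:sramecc+:xnack-", "gfx942:sramecc+:xnack-")
    print(f"gcnArchName={real_arch!r}, effective={arch!r}, flash attention expected: {supported}")
```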