[ROCm] Bump AOTriton to 0.11b (#161754)
Notable new features and optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

* Invokes AITER assembly kernels on gfx942/gfx950 when inputs meet the requirements. The AITER ASM kernels deliver over 500 TFLOPS of training performance; see the [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for details.
* Now returns a natural-base `logsumexp` tensor, matching CUDA's behavior. PR #156903 is reverted in this PR as well, since it is no longer needed.
* Enables `CausalVariant.LOWER_RIGHT`.

The build system changes drastically along with the new packaging scheme of AOTriton 0.11:

* AOTriton 0.11 packs GPU images separately from the AOTriton runtime.
* `aotriton.cmake` now selectively downloads image packs according to `PYTORCH_ROCM_ARCH`.
* `aotriton.cmake` now only uses a pre-compiled runtime library that exactly matches the ROCm version in the build environment. For PyTorch builds with ROCm versions not listed in the file, the build process builds the AOTriton runtime (without GPU images) from source. This avoids any further ABI breaks like ROCm 6.4 -> 7.0.
* Recursive git clone is disabled, since building the AOTriton runtime does not require submodules.

Bug fixes:

* Fix a kernel bug introduced when implementing SWA (sliding-window attention).

Known problems:

* The gfx1100 target (Radeon RX 7000 series) is moved back to experimental status due to accuracy issues. Triton compiler fixes are needed to restore the support status.
* Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCm 7.0. This issue is under investigation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161754
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
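For context on the `logsumexp` change: Triton-based kernels typically compute the softmax statistics with base-2 exponentials, so the returned logsumexp is in log2 units unless converted. The conversion is a single multiply by ln(2) = 0.6931471805599453, the same constant the now-removed ring-attention workaround applied (see the `_need_scaling` deletion further below). A minimal sketch of the conversion, not code from this PR:

```cpp
// Minimal sketch: converting a logsumexp computed with base-2 exponentials
// (exp2/log2 style kernels) into the natural base that CUDA returns.
// Illustrative only; the function name is not from the PyTorch sources.
#include <vector>

void lse_log2_to_natural(std::vector<float>& lse) {
  constexpr float kLn2 = 0.6931471805599453f; // ln(2)
  for (float& v : lse) {
    v *= kLn2; // log_e(x) = log_2(x) * ln(2)
  }
}
```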
Committed by PyTorch MergeBot · parent 994f2a5dbc · commit 98efc9e93d
@@ -1396,13 +1396,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
  at::Tensor v_t = value.transpose(1, 2);
  at::Tensor output_t = res.transpose(1, 2);
  bool is_causal;
  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
    is_causal = true;
  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
    is_causal = false;
  } else {
    is_causal = true;
#if AOTRITON_V3_API == 0
    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
    }
#endif
  }

  at::Tensor atomic_counter;
  if (is_causal) {
@@ -1426,7 +1429,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
  auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
  hipError_t err; // TODO: Error handling
  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    using aotriton::v3::flash::WindowValue;
    aotriton::v3::flash::attn_fwd_params params;
    params.Q = mk_aotensor(q_t, "q");
    params.K = mk_aotensor(k_t, "k");
    params.V = mk_aotensor(v_t, "v");
    params.Sm_scale = softmax_scale;
    params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
    params.Out = mk_aotensor(output_t, "Out");
    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
    params.dropout_p = dropout_p;
    params.philox_seed_ptr = seed;
    params.philox_offset1 = offset1;
    params.philox_offset2 = offset2;
    params.philox_seed_output = seed_output;
    params.philox_offset_output = offset_output;
    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
    params.persistent_atomic_counter = persistent_counter;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
      params.window_left = WindowValue::TopLeftAligned;
      params.window_right = WindowValue::TopLeftAligned;
    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
      params.window_left = WindowValue::BottomRightAligned;
      params.window_right = WindowValue::BottomRightAligned;
    }
    if (bias.has_value()) {
      params.B = mk_aotensor(bias.value(), "bias");
    }
    if (seqstart_q.has_value()) {
      params.varlen_type = VarlenType::CompactVarlen;
      params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
      params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
    } else {
      params.varlen_type = VarlenType::None;
    }
    err = aotriton::v3::flash::attn_fwd(params,
                                        aotriton::v3::flash::attn_fwd_params::kVersion,
                                        stream);
#endif // AOTRITON_V3_API
  } else if (seqstart_q.has_value()) {
    // varlen aka nested tensor
    err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
                                  mk_aotensor(k_t, "k"),
@@ -24,6 +24,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_cudnn_attention_backward.h>
@@ -47,6 +48,7 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#else
#include <ATen/native/transformers/hip/gemm_kernel_utils.h>
// MemoryEfficient Attention Specific Imports for ROCM
#ifndef DISABLE_AOTRITON
#include <ATen/native/transformers/hip/aotriton_adapter.h>
@@ -544,12 +546,15 @@ _efficient_attention_backward(
  }
  const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
  bool is_causal;
  if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
    is_causal = true;
  } else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
  if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
    is_causal = false;
  } else {
    TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
    is_causal = true;
#if AOTRITON_V3_API == 0
    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
      TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
    }
#endif
  }
  at::Tensor q_t = query.permute({0,2,1,3});
  at::Tensor k_t = key.permute({0,2,1,3});
@@ -568,7 +573,62 @@ _efficient_attention_backward(
  using sdp::aotriton_adapter::mk_aoscalartensor;
  using sdp::aotriton_adapter::cast_dtype;
  aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
  if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    using aotriton::v3::flash::WindowValue;
    aotriton::v3::flash::attn_bwd_params params;
    params.Q = mk_aotensor(q_t, "q");
    params.K = mk_aotensor(k_t, "k");
    params.V = mk_aotensor(v_t, "v");
    params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4;
    params.Sm_scale = softmax_scale;
    params.Out = mk_aotensor(out_t, "out");
    params.DO = mk_aotensor(dout_t, "dout");
    params.DK = mk_aotensor(dk_t, "dk");
    params.DV = mk_aotensor(dv_t, "dv");
    params.DQ = mk_aotensor(dq_t, "dq");
    params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4;
    params.L = mk_aotensor<2>(softmax_lse, "L");
    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
    params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
    params.dropout_p = float(dropout_p);
    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
    params.philox_offset1 = mk_aoscalartensor(philox_offset);
    params.philox_offset2 = 0;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
      params.window_left = WindowValue::TopLeftAligned;
      params.window_right = WindowValue::TopLeftAligned;
    } else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
      params.window_left = WindowValue::BottomRightAligned;
      params.window_right = WindowValue::BottomRightAligned;
    }
#if AOTRITON_ALWAYS_V3_API
    using sdp::aotriton_adapter::mklazy_empty_like;
    using sdp::aotriton_adapter::mklazy_fp32zeros;
    using sdp::aotriton_adapter::LazyTensorContext;
    LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" };
    LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
    params.D = mklazy_empty_like<2>(&lazy_delta);
    params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
    at::Tensor delta = at::empty_like(softmax_lse).contiguous();
    params.D = mk_aotensor<2>(delta, "delta");
#endif
    if (cu_seqlens_q.has_value()) {
      params.varlen_type = VarlenType::CompactVarlen;
      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q");
      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k");
    } else {
      params.varlen_type = VarlenType::None;
    }
    err = aotriton::v3::flash::attn_bwd(params,
                                        aotriton::v3::flash::attn_bwd_params::kVersion,
                                        stream);
#endif // AOTRITON_V3_API
  } else if (cu_seqlens_q.has_value()) {
    at::Tensor delta = at::empty_like(softmax_lse).contiguous();
    // varlen aka Nested tensor
    err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
@@ -16,6 +16,7 @@
#include <c10/util/irange.h>
#include <c10/util/Array.h>
#include <c10/util/Exception.h>
#include <c10/util/string_view.h>

#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
@@ -25,9 +26,12 @@

#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <ATen/native/transformers/hip/aotriton_versions.h>
#include <aotriton/flash.h>
#define USE_ROCM_ATTENTION 1
#endif
#else
#define USE_ROCM_ATTENTION 0
#endif

// Avoid potential compiler -Wall -Werror complains undefined macro
@@ -129,9 +133,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
// caller_is_meff is added to make the TORCH_WARN message showing the correct result
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
#if USE_ROCM_ATTENTION
  // AOTriton 0.9+ supports head_dim up to 512
  const auto max_size = c10::SymInt(512);
  const static auto max_hdim = []() {
#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
    // gfx11xx only support hdim <= 256 on AOTriton 0.11
    auto dprops = at::cuda::getCurrentDeviceProperties();
    const c10::basic_string_view<char> arch(dprops->gcnArchName);
    if (arch.starts_with("gfx11")) {
      return 256;
    }
#endif // AOTriton 0.11
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
    return 512;
#else
    return 256;
#endif
  }();
  const auto max_size = c10::SymInt(max_hdim);
#else
  // All head_dim sizes must be equal and less than 256
  const auto max_size = c10::SymInt(256);
@@ -2,8 +2,12 @@

#ifdef USE_ROCM

// Expect to be included after headers of at::zeros_like and at::empty_like

#include <aotriton/dtypes.h>
#include <aotriton/util.h>
#include <aotriton/config.h>
#include <ATen/native/transformers/hip/aotriton_versions.h>

////////////////////////////////////////////////////////////////////////////////
// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
@@ -111,6 +115,61 @@ inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr)
                                     aotriton::DType::kInt32);
}

#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)

struct LazyTensorContext {
  at::Tensor like_tensor;
  std::string_view tensor_name;
  at::Tensor tensor;
};

template<int kRank, bool kRequireZeros>
struct LazyTensorFunctions : public LazyTensorContext {
  static aotriton::TensorView<kRank> acquire(void* cookie) {
    auto ctx = (LazyTensorContext*)cookie;
    if (!ctx->tensor.defined()) {
      auto q = ctx->like_tensor;
      if constexpr (kRequireZeros) {
        ctx->tensor = at::zeros(q.sizes(),
                                q.options().dtype(at::kFloat));
      } else {
        ctx->tensor = at::empty_like(q);
      }
    }
    return mk_aotensor<kRank>(ctx->tensor, ctx->tensor_name);
  }

  static void dispose(void* cookie) {
  }
};

template<int kRank, bool kRequireZeros>
aotriton::LazyTensor<kRank> mklazy_common(LazyTensorContext* cookie)
{
  using LTF = LazyTensorFunctions<kRank, kRequireZeros>;
  return aotriton::LazyTensor<kRank> {
    .cookie = cookie,
    .acquire = &LTF::acquire,
    .dispose = &LTF::dispose
  };
}

template<int kRank>
auto mklazy_empty_like(LazyTensorContext* cookie)
{
  return mklazy_common<kRank, false>(cookie);
}


// Note: this will not keep the original strides
template<int kRank>
auto mklazy_fp32zeros(LazyTensorContext* cookie)
{
  return mklazy_common<kRank, true>(cookie);
}

#endif // >= 0.11

} // namespace aotriton_adapter

} // namespace sdp

aten/src/ATen/native/transformers/hip/aotriton_versions.h (new file, 20 lines)
@@ -0,0 +1,20 @@
#pragma once

#ifdef USE_ROCM

#define AOTRITON_VERSION_INT(x, y) (x * 100 + y)
#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR)

#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)
#define AOTRITON_ALWAYS_V3_API 1
#else
#define AOTRITON_ALWAYS_V3_API 0
#endif

#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10)
#define AOTRITON_V3_API 1
#else
#define AOTRITON_V3_API 0
#endif

#endif
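The encoding packs (major, minor) into a single integer (it assumes minor < 100), so plain integer comparison orders versions. A standalone check of how the macros evaluate for AOTriton 0.11; in real builds `AOTRITON_VERSION_MAJOR`/`MINOR` come from the AOTriton headers, not the hard-coded values below:

```cpp
// Illustration only: pin MAJOR/MINOR to 0.11 and verify the gating.
#define AOTRITON_VERSION_MAJOR 0
#define AOTRITON_VERSION_MINOR 11
#define AOTRITON_VERSION_INT(x, y) (x * 100 + y)
#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR)

static_assert(AOTRITON_VERSION_CURRENT == 11, "0.11 encodes as 0*100 + 11");
static_assert(AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10),
              "0.11 enables AOTRITON_V3_API");
static_assert(AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11),
              "0.11 also enables AOTRITON_ALWAYS_V3_API");
int main() { return 0; }
```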
@@ -60,20 +60,13 @@
#include <c10/util/Exception.h>

// AOTriton headers
#include <aotriton/config.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>

#if AOTRITON_VERSION_MINOR < 9
#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 9)
#error "This adaptor code is only tested with AOTriton >= 0.9"
#endif

#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
#define V3_API 1
#else
#define V3_API 0
#endif

namespace pytorch_flash {

namespace {
@@ -93,15 +86,15 @@ calculate_swa(std::optional<int64_t> window_size_left,
              int max_seqlen_q,
              int max_seqlen_k,
              bool is_causal) {
#if V3_API // SWA is exposed through V3 API
#if AOTRITON_V3_API // SWA is exposed through V3 API
  bool needs_swa = false;
  using aotriton::v3::flash::WindowValue;
  // Default values when std::optional window_size_left/right have no value
  int window_left = max_seqlen_q;
  int window_right = max_seqlen_k;
  if (is_causal) {
    window_left = WindowValue::TopLeftAligned;
    window_right = WindowValue::TopLeftAligned;
    window_left = WindowValue::BottomRightAligned;
    window_right = WindowValue::BottomRightAligned;
  }
  if (window_size_left.has_value() || window_size_right.has_value()) {
    needs_swa = true;
@@ -248,10 +241,10 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
                            seqlen_q,
                            seqlen_k,
                            is_causal);
#if V3_API
#if AOTRITON_V3_API
  const bool uses_swa = needs_swa;
#else
  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // optimized out (hopefully).
  constexpr bool uses_swa = false;
#endif
@@ -278,8 +271,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
  auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
  auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
  if (uses_swa) {
#if V3_API
  if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    aotriton::v3::flash::attn_fwd_params params;
@@ -299,7 +292,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
    params.philox_offset_output = offset_output;
    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
    params.persistent_atomic_counter = persistent_counter;
    params.causal_type = CausalType::WindowedAttention;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    params.varlen_type = VarlenType::None;
    params.window_left = window_left;
    params.window_right = window_right;
@@ -449,10 +442,10 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
                            max_seqlen_q,
                            max_seqlen_k,
                            is_causal);
#if V3_API
#if AOTRITON_V3_API
  const bool uses_swa = needs_swa;
#else
  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // optimized out (hopefully).
  constexpr bool uses_swa = false;
#endif
@@ -482,8 +475,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
  auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
  auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
  auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
  if (uses_swa) {
#if V3_API
  if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    aotriton::v3::flash::attn_fwd_params params;
@@ -505,7 +498,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
    params.philox_offset_output = offset_output;
    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
    params.persistent_atomic_counter = persistent_counter;
    params.causal_type = CausalType::WindowedAttention;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    params.varlen_type = VarlenType::CompactVarlen;
    params.window_left = window_left;
    params.window_right = window_right;
@@ -599,10 +592,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);

  if (is_causal){
    TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
  }

  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
@@ -654,10 +643,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
                            seqlen_q,
                            seqlen_k,
                            is_causal);
#if V3_API
#if AOTRITON_V3_API
  const bool uses_swa = needs_swa;
#else
  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // optimized out (hopefully).
  constexpr bool uses_swa = false;
#endif
@@ -681,10 +670,9 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
  hipError_t err; // TODO: Error handling
  using sdp::aotriton_adapter::mk_aotensor;
  using sdp::aotriton_adapter::mk_aoscalartensor;
  if (uses_swa) {
#if V3_API
  if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
    // Fused BWD does not support SWA
    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    aotriton::v3::flash::attn_bwd_params params;
@@ -694,21 +682,32 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
    params.Sm_scale = softmax_scale;
    params.Out = mk_aotensor(out_t, "out");
    params.DO = mk_aotensor(dout_t, "dout");
    params.DK = mk_aotensor(dq_t, "dq");
    params.DV = mk_aotensor(dk_t, "dk");
    params.DQ = mk_aotensor(dv_t, "dv");
    params.DQ = mk_aotensor(dq_t, "dq");
    params.DK = mk_aotensor(dk_t, "dk");
    params.DV = mk_aotensor(dv_t, "dv");
    params.L = mk_aotensor<2>(softmax_lse_cont, "L");
    params.D = mk_aotensor<2>(delta, "delta");
    params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
    params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
    params.dropout_p = p_dropout;
    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
    params.philox_offset1 = mk_aoscalartensor(philox_offset);
    params.philox_offset2 = 0;
    params.causal_type = CausalType::WindowedAttention;
    params.varlen_type = VarlenType::None;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    params.window_left = window_left;
    params.window_right = window_right;
    params.varlen_type = VarlenType::None;
#if AOTRITON_ALWAYS_V3_API
    using sdp::aotriton_adapter::mklazy_empty_like;
    using sdp::aotriton_adapter::mklazy_fp32zeros;
    using sdp::aotriton_adapter::LazyTensorContext;
    LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" };
    LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
    params.D = mklazy_empty_like<2>(&lazy_delta);
    params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
    params.D = mk_aotensor<2>(delta, "delta");
#endif
    err = aotriton::v3::flash::attn_bwd(params,
                                        aotriton::v3::flash::attn_bwd_params::kVersion,
                                        stream);
@@ -843,7 +842,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
  CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

  at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous();
  at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();

  at::Tensor q_padded, k_padded, v_padded;
  q_padded = q.unsqueeze(0).transpose(1, 2);
@@ -901,10 +899,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
                            max_seqlen_q,
                            max_seqlen_k,
                            is_causal);
#if V3_API
#if AOTRITON_V3_API
  const bool uses_swa = needs_swa;
#else
  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
  // optimized out (hopefully).
  constexpr bool uses_swa = false;
#endif
@@ -924,8 +922,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
  hipError_t err; // TODO: Error handling
  using sdp::aotriton_adapter::mk_aotensor;
  using sdp::aotriton_adapter::mk_aoscalartensor;
  if (uses_swa) {
#if V3_API
  if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
    using aotriton::v3::flash::CausalType;
    using aotriton::v3::flash::VarlenType;
    aotriton::v3::flash::attn_bwd_params params;
@@ -935,11 +933,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
    params.Sm_scale = softmax_scale;
    params.Out = mk_aotensor(out_t, "out");
    params.DO = mk_aotensor(dout_t, "dout");
    params.DK = mk_aotensor(dq_padded, "dq");
    params.DV = mk_aotensor(dk_padded, "dk");
    params.DQ = mk_aotensor(dv_padded, "dv");
    params.DK = mk_aotensor(dk_padded, "dk");
    params.DV = mk_aotensor(dv_padded, "dv");
    params.DQ = mk_aotensor(dq_padded, "dq");
    params.L = mk_aotensor<2>(softmax_lse_cont, "L");
    params.D = mk_aotensor<2>(delta, "delta");
    params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
    params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
    params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
@@ -948,17 +945,30 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
    params.philox_offset1 = mk_aoscalartensor(philox_offset);
    params.philox_offset2 = 0;
    params.causal_type = CausalType::WindowedAttention;
    params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
    params.varlen_type = VarlenType::CompactVarlen;
    params.window_left = window_left;
    params.window_right = window_right;
#if AOTRITON_ALWAYS_V3_API
    using sdp::aotriton_adapter::mklazy_empty_like;
    using sdp::aotriton_adapter::mklazy_fp32zeros;
    using sdp::aotriton_adapter::LazyTensorContext;
    LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" };
    LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" };
    params.D = mklazy_empty_like<2>(&lazy_delta);
    params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
    params.D = mk_aotensor<2>(delta, "delta");
#endif
    err = aotriton::v3::flash::attn_bwd(params,
                                        aotriton::v3::flash::attn_bwd_params::kVersion,
                                        stream);
#endif
#endif // AOTRITON_ALWAYS_V3_API
  } else {
    using aotriton::v2::flash::attn_bwd_compact_varlen;
    using sdp::aotriton_adapter::cast_dtype;
    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
    err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
                                  mk_aotensor(k_padded, "k"),
aten/src/ATen/native/transformers/hip/gemm_kernel_utils.h (new file, 32 lines)
@@ -0,0 +1,32 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// This file is a trimmed version of cuda/mem_eff_attention/gemm_kernel_utils.h
#pragma once

#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                         \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  TORCH_CHECK(TENSOR.is_contiguous());

#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                     \
  TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
  TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  TORCH_CHECK(                                                         \
      TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");

#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(                            \
      uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")

#define ASSIGN_CHECK_OVERFLOW(A, B)                                    \
  {                                                                    \
    A = B;                                                             \
    TORCH_CHECK(                                                       \
        B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
  }
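As a quick orientation, a hypothetical call site for these macros; the function and argument names are illustrative, not taken from this diff:

```cpp
// Sketch of input validation using the trimmed macros above.
#include <cstdint>
#include <ATen/ATen.h>
#include <ATen/native/transformers/hip/gemm_kernel_utils.h>

void check_attention_input(const at::Tensor& query) {
  // Dense CUDA tensor with a unit stride in the last dimension.
  CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query);
  // Narrow an int64 size to int32, failing loudly on overflow.
  int32_t num_batches = 0;
  ASSIGN_CHECK_OVERFLOW(num_batches, query.size(0));
}
```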
cmake/External/aotriton.cmake (vendored, 185 lines)
@@ -9,27 +9,122 @@ if(NOT __AOTRITON_INCLUDED)
  # Replaces .ci/docker/aotriton_version.txt
  # Note packages information may have versions skipped (due to no ABI breaks)
  # But they must be listed from lower version to higher version
  set(__AOTRITON_VER "0.10b")
  set(__AOTRITON_VER "0.11b")
  set(__AOTRITON_MANYLINUX_LIST
      "manylinux_2_28" # rocm6.2
      "manylinux_2_28" # rocm6.3
      "manylinux_2_28" # rocm6.4
      "manylinux_2_28" # rocm6.5
      "manylinux_2_28" # rocm7.0
     )
  set(__AOTRITON_ROCM_LIST
      "rocm6.2"
      "rocm6.3"
      "rocm6.4"
      "rocm6.5"
      "rocm7.0"
     )
  set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
  set(__AOTRITON_CI_COMMIT "972223c501ffc22068bb035ac5d64cf54318d895")
  set(__AOTRITON_SHA256_LIST
      "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
      "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
      "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm6.5
      "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0
      "6cae3d5de75ee205d22e088f7dfaab1227056d02ea67f29ccdbc09f2be4e8c8f" # rocm6.2
      "72a153549ea20707331e8a1f1e3d1b8de2913f9d5af2b900c56235d578b57efe" # rocm6.3
      "c7f319dd7448cbbbab81889dd8a37d47dbc25ebcbd89760f09e6a0904e556393" # rocm6.4
      "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
     )
  set(__AOTRITON_IMAGE_LIST
      "amd-gfx90a"
      "amd-gfx942"
      "amd-gfx950"
      "amd-gfx11xx"
      "amd-gfx120x"
     )
  set(__AOTRITON_IMAGE_SHA256_LIST
      "c19a41c9480510ab32e6fb05e6ed0a3832d6b07634f050b836b760200befa735" # amd-gfx90a
      "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
      "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
      "ec134032087344176695505db659387374d1916adfee16f0db47dee38d9c8603" # amd-gfx11xx
      "fec05205747ff51649b1e151545267d5aa2037ba9d0338cad286882915b941b0" # amd-gfx120x
     )
  set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
  set(__AOTRITON_Z "gz")
  function(aotriton_build_from_source noimage project)
    if(noimage)
      SET(RECURSIVE "OFF")
    else()
      SET(RECURSIVE "ON")
    endif()
    message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
    ExternalProject_Add(${project}
      GIT_REPOSITORY https://github.com/ROCm/aotriton.git
      GIT_SUBMODULES_RECURSE ${RECURSIVE}
      GIT_TAG ${__AOTRITON_CI_COMMIT}
      PREFIX ${__AOTRITON_EXTERN_PREFIX}
      CMAKE_CACHE_ARGS
        -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
        -DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR}
      CMAKE_ARGS
        -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
        -DAOTRITON_GPU_BUILD_TIMEOUT=0
        -DAOTRITON_NO_PYTHON=ON
        -DAOTRITON_NOIMAGE_MODE=${noimage}
      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
      USES_TERMINAL_DOWNLOAD TRUE
      USES_TERMINAL_CONFIGURE TRUE
      USES_TERMINAL_BUILD TRUE
      USES_TERMINAL_INSTALL TRUE
    )
  endfunction()

  set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
  function(aotriton_download_runtime index project)
    list(GET __AOTRITON_ROCM_LIST ${index} __AOTRITON_ROCM)
    list(GET __AOTRITON_MANYLINUX_LIST ${index} __AOTRITON_MANYLINUX)
    list(GET __AOTRITON_SHA256_LIST ${index} __AOTRITON_SHA256)

    string(CONCAT __AOTRITON_FILE "aotriton-"
           "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}"
           "_${__AOTRITON_ARCH}-${__AOTRITON_ROCM}"
           "-shared.tar.${__AOTRITON_Z}")
    string(CONCAT __AOTRITON_URL
           "${__AOTRITON_BASE_URL}"
           "${__AOTRITON_VER}/${__AOTRITON_FILE}")
    ExternalProject_Add(${project}
      URL "${__AOTRITON_URL}"
      URL_HASH SHA256=${__AOTRITON_SHA256}
      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime
      CONFIGURE_COMMAND ""
      BUILD_COMMAND ""
      INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
        "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
        "${__AOTRITON_INSTALL_DIR}"
      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
    )
    message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
    Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
  endfunction()

  function(aotriton_download_image image project)
    list(FIND __AOTRITON_IMAGE_LIST ${image} index)
    list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_SHA256)

    string(CONCAT __AOTRITON_FILE
           "aotriton-${__AOTRITON_VER}-images-"
           "${image}.tar.${__AOTRITON_Z}")
    string(CONCAT __AOTRITON_URL
           "${__AOTRITON_BASE_URL}"
           "${__AOTRITON_VER}/${__AOTRITON_FILE}")
    ExternalProject_Add(${project}
      URL "${__AOTRITON_URL}"
      URL_HASH SHA256=${__AOTRITON_SHA256}
      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
      CONFIGURE_COMMAND ""
      BUILD_COMMAND ""
      INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
        "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
        "${__AOTRITON_INSTALL_DIR}"
      BUILD_BYPRODUCTS
        "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
    )
    message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
  endfunction()
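The image packs downloaded by this helper are matched against `PYTORCH_ROCM_ARCH` entries by the selection loop further below, where each `x` in an image name acts as a one-character wildcard (cmake's `string(REPLACE "x" ".")` plus `MATCHES`). A C++ analogue of that rule, illustrative only; the real logic lives in the CMake `string()` commands:

```cpp
// The pattern is the tail of the image name after "amd-gfx", e.g. "11xx"
// from "amd-gfx11xx"; each 'x' becomes the regex wildcard '.'.
#include <regex>
#include <string>

bool image_covers_arch(std::string pattern, const std::string& arch) {
  for (char& c : pattern) {
    if (c == 'x') c = '.'; // mirror cmake's string(REPLACE "x" ".")
  }
  return std::regex_search(arch, std::regex(pattern));
}
// image_covers_arch("11xx", "gfx1100") -> true
// image_covers_arch("90a",  "gfx942")  -> false
```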

  # Note it is INSTALL"ED"
  if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX})
@@ -40,66 +135,34 @@ if(NOT __AOTRITON_INCLUDED)
    set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
    message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
  elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
    ExternalProject_Add(aotriton_external
      GIT_REPOSITORY https://github.com/ROCm/aotriton.git
      GIT_TAG ${__AOTRITON_CI_COMMIT}
      PREFIX ${__AOTRITON_EXTERN_PREFIX}
      INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
      CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
        -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
        -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
        -DAOTRITON_NO_PYTHON=ON
        -DAOTRITON_NO_SHARED=OFF
      # CONFIGURE_COMMAND ""
      BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system.
      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
      USES_TERMINAL_DOWNLOAD TRUE
      USES_TERMINAL_CONFIGURE TRUE
      USES_TERMINAL_BUILD TRUE
      USES_TERMINAL_INSTALL TRUE
      # INSTALL_COMMAND ${MAKE_COMMAND} install
    )
    aotriton_build_from_source(OFF aotriton_external)
    add_dependencies(__caffe2_aotriton aotriton_external)
    message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}")
  else()
    set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}")
    list(GET __AOTRITON_ROCM_LIST 0 __AOTRITON_ROCM_DEFAULT_STR)
    # Initialize __AOTRITON_ROCM to lowest version, in case all builds > system's ROCM
    string(SUBSTRING ${__AOTRITON_ROCM_DEFAULT_STR} 4 -1 __AOTRITON_ROCM)
    foreach(AOTRITON_ROCM_BUILD_STR IN LISTS __AOTRITON_ROCM_LIST)
      # len("rocm") == 4
      string(SUBSTRING ${AOTRITON_ROCM_BUILD_STR} 4 -1 AOTRITON_ROCM_BUILD)
      # Find the last build that <= system's ROCM
      # Assume the list is from lower to higher
      if(AOTRITON_ROCM_BUILD VERSION_GREATER __AOTRITON_SYSTEM_ROCM)
    list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX)
    if(${__AOTRITON_RUNTIME_INDEX} LESS 0)
      message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \
      Build runtime from source")
      aotriton_build_from_source(ON aotriton_runtime)
    else()
      aotriton_download_runtime(${__AOTRITON_RUNTIME_INDEX} aotriton_runtime)
    endif()
    add_dependencies(__caffe2_aotriton aotriton_runtime)
    set(__AOTRITON_CHAINED_IMAGE "aotriton_runtime")
    foreach(image ${__AOTRITON_IMAGE_LIST})
      string(SUBSTRING ${image} 7 -1 gfx_pattern)
      string(REPLACE "x" "." gfx_regex ${gfx_pattern})
      foreach(target ${PYTORCH_ROCM_ARCH})
        if(target MATCHES ${gfx_regex})
          set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
          aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
          add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET})
          set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET})
          break()
        endif()
      set(__AOTRITON_ROCM ${AOTRITON_ROCM_BUILD})
      endforeach()
    list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX)
    list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256)
    list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX)
    set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
    string(CONCAT __AOTRITON_FILE "aotriton-"
           "${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}"
           "_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}"
           "-shared.tar.${__AOTRITON_Z}")
    string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" # @lint-ignore
           "${__AOTRITON_VER}/${__AOTRITON_FILE}")
    ExternalProject_Add(aotriton_external
      URL "${__AOTRITON_URL}"
      URL_HASH SHA256=${__AOTRITON_SHA256}
      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
      CONFIGURE_COMMAND ""
      BUILD_COMMAND ""
      INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
        "${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball"
        "${__AOTRITON_INSTALL_DIR}"
      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
    )
    add_dependencies(__caffe2_aotriton aotriton_external)
    message(STATUS "Using AOTriton from pre-compiled binary ${__AOTRITON_URL}.\
    Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
    endforeach()
  endif()
  target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
  target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
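The runtime policy the commit message describes: use a prebuilt AOTriton runtime only when it exactly matches the system's ROCm version, otherwise fall back to building the runtime from source (without GPU images), which is what avoids ABI breaks across ROCm releases. A sketch of that decision, illustrative only; the real implementation is the `list(FIND ...)` / `aotriton_build_from_source(ON ...)` logic above:

```cpp
// Exact-match-or-fallback runtime selection.
#include <optional>
#include <string>
#include <vector>

std::optional<std::string> find_exact_runtime(
    const std::vector<std::string>& prebuilt, // e.g. {"rocm6.2", ..., "rocm7.0"}
    const std::string& system_rocm) {         // e.g. "rocm6.4"
  for (const auto& r : prebuilt) {
    if (r == system_rocm) {
      return r; // download this runtime tarball
    }
  }
  return std::nullopt; // caller builds the AOTriton runtime from source
}
```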
@@ -51,6 +51,7 @@ from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    tf32_on_and_off,
    tf32_enabled,
    ROCM_VERSION,
)

if TEST_FAIRSEQ:
@@ -339,7 +340,7 @@ class TestTransformers(NNTestCase):
        l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
        self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")

    @tf32_on_and_off(0.001)
    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
    @parametrize("attn_mask_dim", [2, 3, None])
    @parametrize("key_padding_mask_dim", [2, None])
    @parametrize("mask_dtype", [torch.bool, torch.float32])
@@ -523,7 +524,7 @@ class TestTransformers(NNTestCase):
        slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
        self.assertEqual(fastpath_output_expanded, slowpath_output)

    @tf32_on_and_off(0.001)
    @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
    @parametrize("with_no_grad", [True, False])
    @parametrize("training", [True, False])
    @parametrize("enable_nested_tensor", [False])
@@ -1109,7 +1110,7 @@ class TestTransformers(NNTestCase):
            return_all_hiddens=False,
        )[0]

    @tf32_on_and_off(0.003)
    @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
    @parametrize("input_dim,attn_mask_dim,is_causal",
                 [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
                  (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@@ -3425,6 +3426,7 @@ class TestSDPACudaOnly(NNTestCase):
            'grad_value': 8.5,
        }
        if TEST_WITH_ROCM:
            fudge_factors['out'] = 5.0
            fudge_factors['grad_key'] = 45.0
            fudge_factors['grad_query'] = 360.0
            if seq_len_k >= 1024:
@@ -3434,6 +3436,8 @@ class TestSDPACudaOnly(NNTestCase):
                fudge_factors['grad_query'] = 670.0
            if dtype == torch.float32:
                fudge_factors['grad_key'] = 90.0
            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
                fudge_factors['grad_value'] = 16.0

        check_out_and_grad(
            (out_ref, out_lp_ref, out),
@@ -3546,6 +3550,7 @@ class TestSDPACudaOnly(NNTestCase):
            "grad_attn_mask": 45.0,
        }
        if TEST_WITH_ROCM:
            fudge_factors['out'] = 6.0
            fudge_factors['grad_key'] = 45.0
            fudge_factors['grad_query'] = 360.0
            if seq_len_k >= 1024:
@@ -3556,7 +3561,7 @@ class TestSDPACudaOnly(NNTestCase):
            if dtype == torch.float32:
                fudge_factors['grad_key'] = 90.0
            if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
                fudge_factors['grad_value'] = 12.0
                fudge_factors['grad_value'] = 16.0

        check_out_and_grad(
            (out_ref, out_lp_ref, out),
@@ -3677,6 +3682,22 @@ class TestSDPACudaOnly(NNTestCase):
            'grad_value': 4,
        }
        if TEST_WITH_ROCM:
            fudge_factors['grad_value'] = 6.0
            if TEST_WITH_CK:
                fudge_factors['out'] = 5.0
                fudge_factors['grad_key'] = 145.0
                fudge_factors['grad_query'] = 855.0  # ck min = 855.0
                if seq_len_k >= 1024:
                    fudge_factors['grad_key'] = 70.0
                if seq_len_k >= 2048:
                    fudge_factors['grad_key'] = 190.0
                    fudge_factors['grad_query'] = 1550.0  # NEW CK MIN
                    if seq_len_q >= 2048:
                        fudge_factors['grad_query'] = 1100.0
                if dtype == torch.float32:
                    fudge_factors['grad_key'] = 90.0
            else:
                fudge_factors['out'] = 6.0
                fudge_factors['grad_key'] = 45.0
                fudge_factors['grad_query'] = 360.0
                if seq_len_k >= 1024:
@@ -3840,15 +3861,19 @@ class TestSDPACudaOnly(NNTestCase):
        grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad)
        grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad)

        check_out_and_grad(
            (out_ref, out_lp_ref, out),
            *zip(grads_ref, grads_ref_lp, grads),
            fudge_factors={
        fudge_factors = {
            'out': 3.0,
            'grad_query': 110.0,
            'grad_key': 8.0,
            'grad_value': 3.0,
        }
        if TEST_WITH_ROCM:
            fudge_factors['out'] = 6.0
            fudge_factors['grad_value'] = 6.0
        check_out_and_grad(
            (out_ref, out_lp_ref, out),
            *zip(grads_ref, grads_ref_lp, grads),
            fudge_factors=fudge_factors
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@@ -4484,10 +4509,6 @@ class TestAttnBias(NNTestCase):
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
        )
        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
            self.skipTest("No support for LOWER_RIGHT variant for now")
            return

        bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
        make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
        make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
@@ -4518,10 +4539,6 @@ class TestAttnBias(NNTestCase):
    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
    @skipIfTorchDynamo("This function already calls torch.compile.")
    def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
            self.skipTest("No support for LOWER_RIGHT variant for now")
            return

        cnts = CompileCounterWithBackend("aot_eager")
        make_tensor = partial(
            torch.rand, device=device, dtype=torch.float16, requires_grad=True
@@ -49,25 +49,6 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)


def _need_scaling() -> bool:
    if hasattr(torch.version, "hip") and torch.version.hip is not None:
        gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
        _is_ck_supported = False
        for arch in ["gfx942", "gfx950"]:
            if arch in gcn_arch_name:
                _is_ck_supported = True
        # Check the function exists
        _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
        _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
        # Note: it is possible that CK is selected but not compiled in the binary.
        if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
            # Unsure about CK's behavior, keep logsumexp untouched
            return False
        return True
    else:
        return False


class _DispatchMode(Enum):
    MONKEY_PATCH = auto()
    TORCH_FUNCTION = auto()
@@ -489,8 +470,6 @@ def _templated_ring_attention(
            is_causal=is_causal_behavior.value,
            **kwargs,
        )
        if _need_scaling():
            logsumexp *= 0.6931471805599453
        sdpa_merger.step(out, logsumexp, partial)

    return *sdpa_merger.results(), *rest
@@ -24,6 +24,7 @@ else:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0))

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
@@ -94,7 +95,6 @@ PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
def evaluate_platform_supports_fp8():
    if torch.cuda.is_available():
        if torch.version.hip:
            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
            archs = ['gfx94']
            if ROCM_VERSION >= (6, 3):
                archs.extend(['gfx120'])
@@ -123,7 +123,6 @@ def evaluate_platform_supports_fp8_grouped_gemm():
def evaluate_platform_supports_mx_gemm():
    if torch.cuda.is_available():
        if torch.version.hip:
            ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
            if ROCM_VERSION >= (7, 0):
                return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
        else:
@@ -238,7 +237,7 @@ def tf32_enabled():
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
def tf32_on_and_off(tf32_precision=1e-5):
def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()
@@ -254,7 +253,7 @@ def tf32_on_and_off(tf32_precision=1e-5):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            kwargs.update(zip(arg_names, args))
            cond = torch.cuda.is_tf32_supported()
            cond = torch.cuda.is_tf32_supported() and only_if
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
@@ -268,7 +267,6 @@ def tf32_on_and_off(tf32_precision=1e-5):
        return wrapped
    return wrapper


# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.
@@ -13,6 +13,17 @@ def has_triton_package() -> bool:
        return False


@functools.cache
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
    try:
        import triton  # noqa: F401

        major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
        return (major, minor)
    except ImportError:
        return fallback


@functools.cache
def _device_supports_tma() -> bool:
    import torch