[ROCm] Bump AOTriton to 0.11b (#161754)

Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.11b:

* Invoke AITER Assembly kernels on gfx942/gfx950 when inputs meet requirements
  - AITER ASM kernels deliver over 500 TFLOPS of training performance. See the
    [AOTriton 0.11b Release Page](https://github.com/ROCm/aotriton/releases/tag/0.11b) for more
    details.
* Now returns a natural-log-based `logsumexp` tensor, matching CUDA's behavior
  (see the sketch after this list)
  - PR #156903 is reverted in this PR as well, since it is no longer needed.
* Enables `CausalVariant.LOWER_RIGHT`
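
A minimal sketch of the rescaling that downstream code needed before this
change (the helper name `to_natural_lse` is hypothetical; the constant comes
from the `_templated_ring_attention` workaround removed later in this diff):

```cpp
#include <ATen/ATen.h>

// log_e(x) = log_2(x) * ln(2). AOTriton releases before 0.11 returned a
// log2-based logsumexp on ROCm, so callers rescaled it to natural log.
// 0.11b returns natural-log values directly, matching CUDA.
at::Tensor to_natural_lse(const at::Tensor& lse_log2) {
  constexpr double kLn2 = 0.6931471805599453;  // ln(2)
  return lse_log2 * kLn2;
}
```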

The build system changes drastically along with the new packaging scheme of
AOTriton 0.11:

* AOTriton 0.11 packs GPU images separately from the AOTriton runtime
* `aotriton.cmake` now selectively downloads image packs according to
  `PYTORCH_ROCM_ARCH`
* `aotriton.cmake` now only uses a pre-compiled runtime library that exactly
  matches the ROCm version in the build environment. For PyTorch builds with
  ROCm versions not listed in the file, the build process builds the AOTriton
  runtime from source, without GPU images
  - This avoids further ABI breaks like the one between ROCm 6.4 and 7.0
  - Recursive git clone is disabled, since building the AOTriton runtime does
    not require submodules.

Bug fixes:

* Fix a kernel bug introduced when implementing SWA (sliding-window attention)

Known Problems:

* The gfx1100 target (Radeon RX 7000 series) is moved back to experimental
  status due to accuracy issues. Triton compiler fixes are needed to restore
  full support.
* Enabling TF32 tests affects accuracy for later non-TF32 tests on ROCm 7.0.
  This issue is under investigation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161754
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
Author: Xinya Zhang
Date: 2025-09-03 20:45:39 +00:00
Committed by: PyTorch MergeBot
Parent: 994f2a5dbc
Commit: 98efc9e93d
12 changed files with 492 additions and 177 deletions

View File

@ -1396,13 +1396,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
at::Tensor v_t = value.transpose(1, 2);
at::Tensor output_t = res.transpose(1, 2);
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
is_causal = true;
#if AOTRITON_V3_API == 0
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
#endif
}
at::Tensor atomic_counter;
if (is_causal) {
@ -1426,7 +1429,51 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
hipError_t err; // TODO: Error handling
if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
using aotriton::v3::flash::WindowValue;
aotriton::v3::flash::attn_fwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.Sm_scale = softmax_scale;
params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2;
params.Out = mk_aotensor(output_t, "Out");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = dropout_p;
params.philox_seed_ptr = seed;
params.philox_offset1 = offset1;
params.philox_offset2 = offset2;
params.philox_seed_output = seed_output;
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
params.window_left = WindowValue::TopLeftAligned;
params.window_right = WindowValue::TopLeftAligned;
} else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
params.window_left = WindowValue::BottomRightAligned;
params.window_right = WindowValue::BottomRightAligned;
}
if (bias.has_value()) {
params.B = mk_aotensor(bias.value(), "bias");
}
if (seqstart_q.has_value()) {
params.varlen_type = VarlenType::CompactVarlen;
params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k");
} else {
params.varlen_type = VarlenType::None;
}
err = aotriton::v3::flash::attn_fwd(params,
aotriton::v3::flash::attn_fwd_params::kVersion,
stream);
#endif // AOTRITON_V3_API
} else if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),

View File

@ -24,6 +24,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/_cudnn_attention_backward.h>
@ -47,6 +48,7 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#else
#include <ATen/native/transformers/hip/gemm_kernel_utils.h>
// MemoryEfficient Attention Specific Imports for ROCM
#ifndef DISABLE_AOTRITON
#include <ATen/native/transformers/hip/aotriton_adapter.h>
@ -544,12 +546,15 @@ _efficient_attention_backward(
}
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
bool is_causal;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
is_causal = true;
} else if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
if (static_cast<int64_t>(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) {
is_causal = false;
} else {
TORCH_CHECK(false, "[_efficient_attention_backward] Unsupported mask type in AOTriton, for now");
is_causal = true;
#if AOTRITON_V3_API == 0
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) != custom_mask_type) {
TORCH_CHECK(false, "[_efficient_attention_forward] Unsupported mask type on ROCM, for now");
}
#endif
}
at::Tensor q_t = query.permute({0,2,1,3});
at::Tensor k_t = key.permute({0,2,1,3});
@ -568,7 +573,62 @@ _efficient_attention_backward(
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef
#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
using aotriton::v3::flash::WindowValue;
aotriton::v3::flash::attn_bwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4;
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dk_t, "dk");
params.DV = mk_aotensor(dv_t, "dv");
params.DQ = mk_aotensor(dq_t, "dq");
params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4;
params.L = mk_aotensor<2>(softmax_lse, "L");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = float(dropout_p);
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) {
params.window_left = WindowValue::TopLeftAligned;
params.window_right = WindowValue::TopLeftAligned;
} else if (static_cast<int64_t>(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) {
params.window_left = WindowValue::BottomRightAligned;
params.window_right = WindowValue::BottomRightAligned;
}
#if AOTRITON_ALWAYS_V3_API
using sdp::aotriton_adapter::mklazy_empty_like;
using sdp::aotriton_adapter::mklazy_fp32zeros;
using sdp::aotriton_adapter::LazyTensorContext;
LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" };
LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
params.D = mklazy_empty_like<2>(&lazy_delta);
params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
params.D = mk_aotensor<2>(delta, "delta");
#endif
if (cu_seqlens_q.has_value()) {
params.varlen_type = VarlenType::CompactVarlen;
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k");
} else {
params.varlen_type = VarlenType::None;
}
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
#endif // AOTRITON_V3_API
} else if (cu_seqlens_q.has_value()) {
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),

View File

@ -16,6 +16,7 @@
#include <c10/util/irange.h>
#include <c10/util/Array.h>
#include <c10/util/Exception.h>
#include <c10/util/string_view.h>
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
@ -25,9 +26,12 @@
#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <ATen/native/transformers/hip/aotriton_versions.h>
#include <aotriton/flash.h>
#define USE_ROCM_ATTENTION 1
#endif
#else
#define USE_ROCM_ATTENTION 0
#endif
// Avoid potential compiler -Wall -Werror complaints about an undefined macro
@ -129,9 +133,24 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
// caller_is_meff is added to make the TORCH_WARN message showing the correct result
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
#if USE_ROCM_ATTENTION && AOTRITON_VERSION_MINOR >= 9
#if USE_ROCM_ATTENTION
// AOTriton 0.9+ supports head_dim up to 512
const auto max_size = c10::SymInt(512);
const static auto max_hdim = []() {
#if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11)
// gfx11xx only supports hdim <= 256 on AOTriton 0.11
auto dprops = at::cuda::getCurrentDeviceProperties();
const c10::basic_string_view<char> arch(dprops->gcnArchName);
if (arch.starts_with("gfx11")) {
return 256;
}
#endif // AOTriton 0.11
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 9)
return 512;
#else
return 256;
#endif
}();
const auto max_size = c10::SymInt(max_hdim);
#else
// All head_dim sizes must be equal and less than 256
const auto max_size = c10::SymInt(256);

View File

@ -2,8 +2,12 @@
#ifdef USE_ROCM
// Expect to be included after headers of at::zeros_like and at::empty_like
#include <aotriton/dtypes.h>
#include <aotriton/util.h>
#include <aotriton/config.h>
#include <ATen/native/transformers/hip/aotriton_versions.h>
////////////////////////////////////////////////////////////////////////////////
// Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
@ -111,6 +115,61 @@ inline aotriton::TensorView<0> mk_atomictensor(const int32_t* ptr)
aotriton::DType::kInt32);
}
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)
struct LazyTensorContext {
at::Tensor like_tensor;
std::string_view tensor_name;
at::Tensor tensor;
};
template<int kRank, bool kRequireZeros>
struct LazyTensorFunctions : public LazyTensorContext {
static aotriton::TensorView<kRank> acquire(void* cookie) {
auto ctx = (LazyTensorContext*)cookie;
if (!ctx->tensor.defined()) {
auto q = ctx->like_tensor;
if constexpr (kRequireZeros) {
ctx->tensor = at::zeros(q.sizes(),
q.options().dtype(at::kFloat));
} else {
ctx->tensor = at::empty_like(q);
}
}
return mk_aotensor<kRank>(ctx->tensor, ctx->tensor_name);
}
static void dispose(void* cookie) {
}
};
template<int kRank, bool kRequireZeros>
aotriton::LazyTensor<kRank> mklazy_common(LazyTensorContext* cookie)
{
using LTF = LazyTensorFunctions<kRank, kRequireZeros>;
return aotriton::LazyTensor<kRank> {
.cookie = cookie,
.acquire = &LTF::acquire,
.dispose = &LTF::dispose
};
}
template<int kRank>
auto mklazy_empty_like(LazyTensorContext* cookie)
{
return mklazy_common<kRank, false>(cookie);
}
// Note: this will not keep the original strides
template<int kRank>
auto mklazy_fp32zeros(LazyTensorContext* cookie)
{
return mklazy_common<kRank, true>(cookie);
}
#endif // >= 0.11
} // namespace aotriton_adapter
} // namespace sdp
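
A hedged usage sketch of the lazy-tensor helpers above, assuming AOTriton
0.11's `attn_bwd_params`, where `D` and `DQ_ACC` accept `LazyTensor` handles
(the wrapper function is hypothetical; the contexts must outlive the kernel
launch):

```cpp
#include <aotriton/flash.h>
#include <ATen/native/transformers/hip/aotriton_adapter.h>

// Attach lazily-allocated delta/dq_acc outputs to the V3 backward params.
// The tensors are only materialized if the selected kernel calls acquire().
void attach_lazy_outputs(aotriton::v3::flash::attn_bwd_params& params,
                         sdp::aotriton_adapter::LazyTensorContext& lazy_delta,
                         sdp::aotriton_adapter::LazyTensorContext& lazy_dq_acc) {
  using sdp::aotriton_adapter::mklazy_empty_like;
  using sdp::aotriton_adapter::mklazy_fp32zeros;
  params.D = mklazy_empty_like<2>(&lazy_delta);       // empty_like: keeps dtype
  params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);  // fp32 zeros; strides not preserved
}
```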

View File

@ -0,0 +1,20 @@
#pragma once
#ifdef USE_ROCM
#define AOTRITON_VERSION_INT(x, y) (x * 100 + y)
#define AOTRITON_VERSION_CURRENT (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR)
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 11)
#define AOTRITON_ALWAYS_V3_API 1
#else
#define AOTRITON_ALWAYS_V3_API 0
#endif
#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 10)
#define AOTRITON_V3_API 1
#else
#define AOTRITON_V3_API 0
#endif
#endif
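
A minimal usage sketch of the new version gates, assuming `<aotriton/config.h>`
defines `AOTRITON_VERSION_MAJOR`/`AOTRITON_VERSION_MINOR` (the `#error` guard
mirrors the one in the flash-attention adapter below):

```cpp
#include <aotriton/config.h>
#include <ATen/native/transformers/hip/aotriton_versions.h>

#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 9)
#error "This adaptor code is only tested with AOTriton >= 0.9"
#endif

#if AOTRITON_V3_API
// 0.10+: the V3 attn_fwd_params/attn_bwd_params API can be compiled in.
#endif

#if AOTRITON_ALWAYS_V3_API
// 0.11+: V3 is unconditionally available and V2 fallbacks can be dropped.
#endif
```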

View File

@ -60,20 +60,13 @@
#include <c10/util/Exception.h>
// AOTriton headers
#include <aotriton/config.h>
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#if AOTRITON_VERSION_MINOR < 9
#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 9)
#error "This adaptor code is only tested with AOTriton >= 0.9"
#endif
#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
#define V3_API 1
#else
#define V3_API 0
#endif
namespace pytorch_flash {
namespace {
@ -93,15 +86,15 @@ calculate_swa(std::optional<int64_t> window_size_left,
int max_seqlen_q,
int max_seqlen_k,
bool is_causal) {
#if V3_API // SWA is exposed through V3 API
#if AOTRITON_V3_API // SWA is exposed through V3 API
bool needs_swa = false;
using aotriton::v3::flash::WindowValue;
// Default values when std::optional window_size_left/right have no value
int window_left = max_seqlen_q;
int window_right = max_seqlen_k;
if (is_causal) {
window_left = WindowValue::TopLeftAligned;
window_right = WindowValue::TopLeftAligned;
window_left = WindowValue::BottomRightAligned;
window_right = WindowValue::BottomRightAligned;
}
if (window_size_left.has_value() || window_size_right.has_value()) {
needs_swa = true;
@ -248,10 +241,10 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
seqlen_q,
seqlen_k,
is_causal);
#if V3_API
#if AOTRITON_V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
@ -278,8 +271,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
if (uses_swa) {
#if V3_API
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_fwd_params params;
@ -299,7 +292,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = CausalType::WindowedAttention;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
params.varlen_type = VarlenType::None;
params.window_left = window_left;
params.window_right = window_right;
@ -449,10 +442,10 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
max_seqlen_q,
max_seqlen_k,
is_causal);
#if V3_API
#if AOTRITON_V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
@ -482,8 +475,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
if (uses_swa) {
#if V3_API
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_fwd_params params;
@ -505,7 +498,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = CausalType::WindowedAttention;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
params.varlen_type = VarlenType::CompactVarlen;
params.window_left = window_left;
params.window_right = window_right;
@ -599,10 +592,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (is_causal){
TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
}
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
@ -654,10 +643,10 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
seqlen_q,
seqlen_k,
is_causal);
#if V3_API
#if AOTRITON_V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
@ -681,10 +670,9 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
hipError_t err; // TODO: Error handling
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
if (uses_swa) {
#if V3_API
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
// Fused BWD does not support SWA
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_bwd_params params;
@ -694,21 +682,32 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dq_t, "dq");
params.DV = mk_aotensor(dk_t, "dk");
params.DQ = mk_aotensor(dv_t, "dv");
params.DQ = mk_aotensor(dq_t, "dq");
params.DK = mk_aotensor(dk_t, "dk");
params.DV = mk_aotensor(dv_t, "dv");
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
params.D = mk_aotensor<2>(delta, "delta");
params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = p_dropout;
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = CausalType::WindowedAttention;
params.varlen_type = VarlenType::None;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
params.window_left = window_left;
params.window_right = window_right;
params.varlen_type = VarlenType::None;
#if AOTRITON_ALWAYS_V3_API
using sdp::aotriton_adapter::mklazy_empty_like;
using sdp::aotriton_adapter::mklazy_fp32zeros;
using sdp::aotriton_adapter::LazyTensorContext;
LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" };
LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" };
params.D = mklazy_empty_like<2>(&lazy_delta);
params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
params.D = mk_aotensor<2>(delta, "delta");
#endif
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
@ -843,7 +842,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous();
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
at::Tensor q_padded, k_padded, v_padded;
q_padded = q.unsqueeze(0).transpose(1, 2);
@ -901,10 +899,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
max_seqlen_q,
max_seqlen_k,
is_causal);
#if V3_API
#if AOTRITON_V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// When AOTRITON_V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
@ -924,8 +922,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
hipError_t err; // TODO: Error handling
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
if (uses_swa) {
#if V3_API
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_bwd_params params;
@ -935,11 +933,10 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dq_padded, "dq");
params.DV = mk_aotensor(dk_padded, "dk");
params.DQ = mk_aotensor(dv_padded, "dv");
params.DK = mk_aotensor(dk_padded, "dk");
params.DV = mk_aotensor(dv_padded, "dv");
params.DQ = mk_aotensor(dq_padded, "dq");
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
params.D = mk_aotensor<2>(delta, "delta");
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
@ -948,17 +945,30 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = CausalType::WindowedAttention;
params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
params.varlen_type = VarlenType::CompactVarlen;
params.window_left = window_left;
params.window_right = window_right;
#if AOTRITON_ALWAYS_V3_API
using sdp::aotriton_adapter::mklazy_empty_like;
using sdp::aotriton_adapter::mklazy_fp32zeros;
using sdp::aotriton_adapter::LazyTensorContext;
LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" };
LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" };
params.D = mklazy_empty_like<2>(&lazy_delta);
params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc);
#else
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
params.D = mk_aotensor<2>(delta, "delta");
#endif
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
#endif
#endif // AOTRITON_ALWAYS_V3_API
} else {
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::cast_dtype;
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// This file is a trimmed version of cuda/mem_eff_attention/gemm_kernel_utils.h
#pragma once
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK(TENSOR.is_contiguous());
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
TORCH_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
TORCH_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
TORCH_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK( \
uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK( \
B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
}

View File

@ -9,27 +9,122 @@ if(NOT __AOTRITON_INCLUDED)
# Replaces .ci/docker/aotriton_version.txt
# Note: package information may have versions skipped (due to no ABI breaks),
# but they must be listed from lower version to higher version
set(__AOTRITON_VER "0.10b")
set(__AOTRITON_VER "0.11b")
set(__AOTRITON_MANYLINUX_LIST
"manylinux_2_28" # rocm6.2
"manylinux_2_28" # rocm6.3
"manylinux_2_28" # rocm6.4
"manylinux_2_28" # rocm6.5
"manylinux_2_28" # rocm7.0
)
set(__AOTRITON_ROCM_LIST
"rocm6.2"
"rocm6.3"
"rocm6.4"
"rocm6.5"
"rocm7.0"
)
set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
set(__AOTRITON_CI_COMMIT "972223c501ffc22068bb035ac5d64cf54318d895")
set(__AOTRITON_SHA256_LIST
"861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
"acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
"7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm6.5
"1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0
"6cae3d5de75ee205d22e088f7dfaab1227056d02ea67f29ccdbc09f2be4e8c8f" # rocm6.2
"72a153549ea20707331e8a1f1e3d1b8de2913f9d5af2b900c56235d578b57efe" # rocm6.3
"c7f319dd7448cbbbab81889dd8a37d47dbc25ebcbd89760f09e6a0904e556393" # rocm6.4
"a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
)
set(__AOTRITON_IMAGE_LIST
"amd-gfx90a"
"amd-gfx942"
"amd-gfx950"
"amd-gfx11xx"
"amd-gfx120x"
)
set(__AOTRITON_IMAGE_SHA256_LIST
"c19a41c9480510ab32e6fb05e6ed0a3832d6b07634f050b836b760200befa735" # amd-gfx90a
"3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
"27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
"ec134032087344176695505db659387374d1916adfee16f0db47dee38d9c8603" # amd-gfx11xx
"fec05205747ff51649b1e151545267d5aa2037ba9d0338cad286882915b941b0" # amd-gfx120x
)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
function(aotriton_build_from_source noimage project)
if(noimage)
SET(RECURSIVE "OFF")
else()
SET(RECURSIVE "ON")
endif()
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
ExternalProject_Add(${project}
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_SUBMODULES_RECURSE ${RECURSIVE}
GIT_TAG ${__AOTRITON_CI_COMMIT}
PREFIX ${__AOTRITON_EXTERN_PREFIX}
CMAKE_CACHE_ARGS
-DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
-DCMAKE_INSTALL_PREFIX:FILEPATH=${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_GPU_BUILD_TIMEOUT=0
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NOIMAGE_MODE=${noimage}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
USES_TERMINAL_DOWNLOAD TRUE
USES_TERMINAL_CONFIGURE TRUE
USES_TERMINAL_BUILD TRUE
USES_TERMINAL_INSTALL TRUE
)
endfunction()
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
function(aotriton_download_runtime index project)
list(GET __AOTRITON_ROCM_LIST ${index} __AOTRITON_ROCM)
list(GET __AOTRITON_MANYLINUX_LIST ${index} __AOTRITON_MANYLINUX)
list(GET __AOTRITON_SHA256_LIST ${index} __AOTRITON_SHA256)
string(CONCAT __AOTRITON_FILE "aotriton-"
"${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}"
"_${__AOTRITON_ARCH}-${__AOTRITON_ROCM}"
"-shared.tar.${__AOTRITON_Z}")
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
)
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
endfunction()
function(aotriton_download_image image project)
list(FIND __AOTRITON_IMAGE_LIST ${image} index)
list(GET __AOTRITON_IMAGE_SHA256_LIST ${index} __AOTRITON_SHA256)
string(CONCAT __AOTRITON_FILE
"aotriton-${__AOTRITON_VER}-images-"
"${image}.tar.${__AOTRITON_Z}")
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
)
message(STATUS "Download AOTriton pre-compiled GPU images from ${__AOTRITON_URL}.")
endfunction()
# Note it is INSTALL"ED"
if(DEFINED ENV{AOTRITON_INSTALLED_PREFIX})
@ -40,66 +135,34 @@ if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_TAG ${__AOTRITON_CI_COMMIT}
PREFIX ${__AOTRITON_EXTERN_PREFIX}
INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NO_SHARED=OFF
# CONFIGURE_COMMAND ""
BUILD_COMMAND "" # No build, install command will repeat the build process due to problems in the build system.
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
USES_TERMINAL_DOWNLOAD TRUE
USES_TERMINAL_CONFIGURE TRUE
USES_TERMINAL_BUILD TRUE
USES_TERMINAL_INSTALL TRUE
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
aotriton_build_from_source(OFF aotriton_external)
add_dependencies(__caffe2_aotriton aotriton_external)
message(STATUS "Using AOTriton compiled from source directory ${__AOTRITON_EXTERN_PREFIX}")
else()
set(__AOTRITON_SYSTEM_ROCM "${HIP_VERSION_MAJOR}.${HIP_VERSION_MINOR}")
list(GET __AOTRITON_ROCM_LIST 0 __AOTRITON_ROCM_DEFAULT_STR)
# Initialize __AOTRITON_ROCM to lowest version, in case all builds > system's ROCM
string(SUBSTRING ${__AOTRITON_ROCM_DEFAULT_STR} 4 -1 __AOTRITON_ROCM)
foreach(AOTRITON_ROCM_BUILD_STR IN LISTS __AOTRITON_ROCM_LIST)
# len("rocm") == 4
string(SUBSTRING ${AOTRITON_ROCM_BUILD_STR} 4 -1 AOTRITON_ROCM_BUILD)
# Find the last build that <= system's ROCM
# Assume the list is from lower to higher
if(AOTRITON_ROCM_BUILD VERSION_GREATER __AOTRITON_SYSTEM_ROCM)
list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_SYSTEM_ROCM}" __AOTRITON_RUNTIME_INDEX)
if(${__AOTRITON_RUNTIME_INDEX} LESS 0)
message(STATUS "Cannot find AOTriton runtime for ROCM ${__AOTRITON_SYSTEM_ROCM}. \
Build runtime from source")
aotriton_build_from_source(ON aotriton_runtime)
else()
aotriton_download_runtime(${__AOTRITON_RUNTIME_INDEX} aotriton_runtime)
endif()
add_dependencies(__caffe2_aotriton aotriton_runtime)
set(__AOTRITON_CHAINED_IMAGE "aotriton_runtime")
foreach(image ${__AOTRITON_IMAGE_LIST})
string(SUBSTRING ${image} 7 -1 gfx_pattern)
string(REPLACE "x" "." gfx_regex ${gfx_pattern})
foreach(target ${PYTORCH_ROCM_ARCH})
if(target MATCHES ${gfx_regex})
set(__AOTRITON_DOWNLOAD_TARGET aotriton_image_${gfx_pattern})
aotriton_download_image(${image} ${__AOTRITON_DOWNLOAD_TARGET})
add_dependencies(${__AOTRITON_CHAINED_IMAGE} ${__AOTRITON_DOWNLOAD_TARGET})
set(__AOTRITON_CHAINED_IMAGE ${__AOTRITON_DOWNLOAD_TARGET})
break()
endif()
set(__AOTRITON_ROCM ${AOTRITON_ROCM_BUILD})
endforeach()
list(FIND __AOTRITON_ROCM_LIST "rocm${__AOTRITON_ROCM}" __AOTRITON_ROCM_INDEX)
list(GET __AOTRITON_SHA256_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_SHA256)
list(GET __AOTRITON_MANYLINUX_LIST ${__AOTRITON_ROCM_INDEX} __AOTRITON_MANYLINUX)
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
string(CONCAT __AOTRITON_FILE "aotriton-"
"${__AOTRITON_VER}-${__AOTRITON_MANYLINUX}"
"_${__AOTRITON_ARCH}-rocm${__AOTRITON_ROCM}"
"-shared.tar.${__AOTRITON_Z}")
string(CONCAT __AOTRITON_URL "https://github.com/ROCm/aotriton/releases/download/" # @lint-ignore
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
ExternalProject_Add(aotriton_external
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_tarball"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
)
add_dependencies(__caffe2_aotriton aotriton_external)
message(STATUS "Using AOTriton from pre-compiled binary ${__AOTRITON_URL}.\
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
endforeach()
endif()
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)

View File

@ -51,6 +51,7 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
tf32_on_and_off,
tf32_enabled,
ROCM_VERSION,
)
if TEST_FAIRSEQ:
@ -339,7 +340,7 @@ class TestTransformers(NNTestCase):
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
@tf32_on_and_off(0.001)
@tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@parametrize("attn_mask_dim", [2, 3, None])
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
@ -523,7 +524,7 @@ class TestTransformers(NNTestCase):
slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
self.assertEqual(fastpath_output_expanded, slowpath_output)
@tf32_on_and_off(0.001)
@tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [False])
@ -1109,7 +1110,7 @@ class TestTransformers(NNTestCase):
return_all_hiddens=False,
)[0]
@tf32_on_and_off(0.003)
@tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@ -3425,6 +3426,7 @@ class TestSDPACudaOnly(NNTestCase):
'grad_value': 8.5,
}
if TEST_WITH_ROCM:
fudge_factors['out'] = 5.0
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
@ -3434,6 +3436,8 @@ class TestSDPACudaOnly(NNTestCase):
fudge_factors['grad_query'] = 670.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
fudge_factors['grad_value'] = 16.0
check_out_and_grad(
(out_ref, out_lp_ref, out),
@ -3546,6 +3550,7 @@ class TestSDPACudaOnly(NNTestCase):
"grad_attn_mask": 45.0,
}
if TEST_WITH_ROCM:
fudge_factors['out'] = 6.0
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
@ -3556,7 +3561,7 @@ class TestSDPACudaOnly(NNTestCase):
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
if "gfx95" in torch.cuda.get_device_properties(0).gcnArchName:
fudge_factors['grad_value'] = 12.0
fudge_factors['grad_value'] = 16.0
check_out_and_grad(
(out_ref, out_lp_ref, out),
@ -3677,6 +3682,22 @@ class TestSDPACudaOnly(NNTestCase):
'grad_value': 4,
}
if TEST_WITH_ROCM:
fudge_factors['grad_value'] = 6.0
if TEST_WITH_CK:
fudge_factors['out'] = 5.0
fudge_factors['grad_key'] = 145.0
fudge_factors['grad_query'] = 855.0 # ck min = 855.0
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 1550.0 # NEW CK MIN
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
else:
fudge_factors['out'] = 6.0
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
@ -3840,15 +3861,19 @@ class TestSDPACudaOnly(NNTestCase):
grads_ref_lp = torch.autograd.grad(out_lp_ref, (query, key, value), upstream_grad)
grads_ref = torch.autograd.grad(out_ref, (query_ref, key_ref, value_ref), upstream_grad)
check_out_and_grad(
(out_ref, out_lp_ref, out),
*zip(grads_ref, grads_ref_lp, grads),
fudge_factors = {
'out': 3.0,
'grad_query': 110.0,
'grad_key': 8.0,
'grad_value': 3.0,
}
if TEST_WITH_ROCM:
fudge_factors['out'] = 6.0
fudge_factors['grad_value'] = 6.0
check_out_and_grad(
(out_ref, out_lp_ref, out),
*zip(grads_ref, grads_ref_lp, grads),
fudge_factors=fudge_factors
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@ -4484,10 +4509,6 @@ class TestAttnBias(NNTestCase):
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
)
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
make_kv_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim))
@ -4518,10 +4539,6 @@ class TestAttnBias(NNTestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
@skipIfTorchDynamo("This function already calls torch.compile.")
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]):
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
cnts = CompileCounterWithBackend("aot_eager")
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True

View File

@ -49,25 +49,6 @@ aten = torch.ops.aten
logger = logging.getLogger(__name__)
def _need_scaling() -> bool:
if hasattr(torch.version, "hip") and torch.version.hip is not None:
gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName
_is_ck_supported = False
for arch in ["gfx942", "gfx950"]:
if arch in gcn_arch_name:
_is_ck_supported = True
# Check the function exists
_preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library
_CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"]
# Note: it is possible that CK is selected but not compiled in the binary.
if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND:
# Unsure about CK's behavior, keep logsumexp untouched
return False
return True
else:
return False
class _DispatchMode(Enum):
MONKEY_PATCH = auto()
TORCH_FUNCTION = auto()
@ -489,8 +470,6 @@ def _templated_ring_attention(
is_causal=is_causal_behavior.value,
**kwargs,
)
if _need_scaling():
logsumexp *= 0.6931471805599453
sdpa_merger.step(out, logsumexp, partial)
return *sdpa_merger.results(), *rest

View File

@ -24,6 +24,7 @@ else:
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)
ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0))
SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
@ -94,7 +95,6 @@ PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip:
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
archs = ['gfx94']
if ROCM_VERSION >= (6, 3):
archs.extend(['gfx120'])
@ -123,7 +123,6 @@ def evaluate_platform_supports_fp8_grouped_gemm():
def evaluate_platform_supports_mx_gemm():
if torch.cuda.is_available():
if torch.version.hip:
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
if ROCM_VERSION >= (7, 0):
return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName
else:
@ -238,7 +237,7 @@ def tf32_enabled():
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
def tf32_on_and_off(tf32_precision=1e-5):
def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
@ -254,7 +253,7 @@ def tf32_on_and_off(tf32_precision=1e-5):
@functools.wraps(f)
def wrapped(*args, **kwargs):
kwargs.update(zip(arg_names, args))
cond = torch.cuda.is_tf32_supported()
cond = torch.cuda.is_tf32_supported() and only_if
if 'device' in kwargs:
cond = cond and (torch.device(kwargs['device']).type == 'cuda')
if 'dtype' in kwargs:
@ -268,7 +267,6 @@ def tf32_on_and_off(tf32_precision=1e-5):
return wrapped
return wrapper
# This is a wrapper that wraps a test to run it with TF32 turned off.
# This wrapper is designed to be used when a test uses matmul or convolutions
# but the purpose of that test is not testing matmul or convolutions.

View File

@ -13,6 +13,17 @@ def has_triton_package() -> bool:
return False
@functools.cache
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
try:
import triton # noqa: F401
major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
return (major, minor)
except ImportError:
return fallback
@functools.cache
def _device_supports_tma() -> bool:
import torch