From 424156c26c5a80c9221197c09c2d1c12006f11d1 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 6 Dec 2024 21:45:18 +0000 Subject: [PATCH] [ROCm] Update to AOTriton 0.8b (#140172) Notable new features for SDPA operators on AMD systems from AOTriton 0.8b: 1. Nestedtensor support; 2. MQA/GQA support; 3. Restore Efficient attention support for causal=True and seqlen_q != seqlen_k cases; + The kernel should use top-left alignment, bottom right alignment will be added later 4. Move gfx1100 (RX7900/W7800/W7900) out of experimental support status. However, users are strongly recommended to update to ROCM 6.2.4, notably for its firmware updates. Related unit tests are enabled as well. Notable related changes from AOTriton 0.8b: 1. AOTriton 0.8b moves the GPU kernel out of libaotriton.so to a separate directory `aotriton.images`; 2. LZMA replaces ZSTD as GPU kernel compression algorithm for better compression ratio: aotriton0.8b (.so + aotriton.images take 350MB) compared to aotriton0.7b .so: 800MB 3. The compression cannot be disabled now, and `liblzma` is hard run-time dependency. + Should not be a problem, since `lzma` is part of Python Standard Library Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172 Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- .ci/docker/aotriton_version.txt | 8 +- .ci/manywheel/build_common.sh | 4 +- .ci/manywheel/build_libtorch.sh | 4 +- .ci/manywheel/build_rocm.sh | 14 + .../native/transformers/cuda/attention.cu | 57 ++- .../transformers/cuda/attention_backward.cu | 66 ++- .../native/transformers/cuda/sdp_utils.cpp | 53 +-- .../transformers/hip/flash_attn/flash_api.hip | 416 ++++++++++++++---- setup.py | 7 + test/inductor/test_flex_attention.py | 5 - test/inductor/test_flex_decoding.py | 5 +- test/test_native_mha.py | 4 +- test/test_transformers.py | 38 +- .../_internal/common_methods_invocations.py | 17 +- 14 files changed, 498 insertions(+), 200 deletions(-) diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt index 602b77d3b853..0bb9b7f4bbbf 100644 --- a/.ci/docker/aotriton_version.txt +++ b/.ci/docker/aotriton_version.txt @@ -1,5 +1,5 @@ -0.7b -manylinux_2_17 +0.8b +manylinux_2_28 rocm6.2 -9be04068c3c0857a4cfd17d7e39e71d0423ebac2 -3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291 +6f8cbcac8a92775291bb1ba8f514d4beb350baf4 +e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8 diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 7381984df0e9..004a107df948 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -253,11 +253,11 @@ make_wheel_record() { FPATH=$1 if echo $FPATH | grep RECORD >/dev/null 2>&1; then # if the RECORD file, then - echo "$FPATH,," + echo "\"$FPATH\",," else HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') FSIZE=$(ls -nl $FPATH | awk '{print $5}') - echo "$FPATH,sha256=$HASH,$FSIZE" + echo "\"$FPATH\",sha256=$HASH,$FSIZE" fi } diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index fd330f6435c8..41d8c4e15272 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -225,11 +225,11 @@ make_wheel_record() { FPATH=$1 if echo $FPATH | grep RECORD >/dev/null 2>&1; then # if the RECORD file, then - echo "$FPATH,," + echo "\"$FPATH\",," else HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') FSIZE=$(ls -nl $FPATH | awk '{print $5}') - echo "$FPATH,sha256=$HASH,$FSIZE" + echo "\"$FPATH\",sha256=$HASH,$FSIZE" fi } diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index b9bb67631927..32fd1435caf7 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -266,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC)) DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/}) DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) +# PyTorch 2.6+ (AOTriton 0.8b+) +# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly +if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then + LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n') + if [[ -z ${LIBAOTRITON_DIR} ]]; then + LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n') + fi + AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n')) + AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images" + AKS_DST="lib/aotriton.images" + DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/}) + DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/}) +fi + echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}" SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index e87c6271ef12..7fe7ee7a1ba1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1160,6 +1160,7 @@ std::tuple _efficient_ const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); using aotriton::v2::flash::attn_fwd; + using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -1172,22 +1173,46 @@ std::tuple _efficient_ auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); hipError_t err; // TODO: Error handling - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - softmax_scale, - mk_aotensor<2>(softmax_lse, "M"), - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); + if (seqstart_q.has_value()) { + // varlen aka nested tensor + err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } else { + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } if (!compute_logsumexp) { // Set the tensor to empty when compute_logsumexp is false logsumexp = at::empty( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index ffe51b373265..09799ff125d1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -441,29 +441,57 @@ _efficient_attention_backward( hipError_t err; using aotriton::v2::flash::attn_bwd; + using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); + if (cu_seqlens_q.has_value()) { + // varlen aka Nested tensor + err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } #else at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index df584fe6ff24..a60cbe5ea061 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -103,6 +103,10 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { return matmul_alignment_mn; } +// On ROCM, ME and FA share the backend, and hence they share the checking +// function for fundamental limitations by the GPU kernel +// caller_is_meff is added to make the TORCH_WARN message showing the correct result +template bool check_head_dim_size_flash(sdp_params const& params, bool debug) { // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); @@ -114,7 +118,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { if (!(same_head_dim_size && (query_size_last <= max_size))) { if (debug) { TORCH_WARN( - "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.", + caller_is_meff ? "Efficient attention on ROCM" : "Flash attention", + " requires q,k,v to have the same last dimension and to be less than or equal to 256.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", @@ -128,6 +133,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { return true; } +// See check_head_dim_size_flash above for the purpose of caller_is_meff +template bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { const auto max_size = c10::SymInt(256); const auto query_size_last = params.query.sym_size(-1); @@ -139,7 +146,9 @@ bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { (query_size_last <= max_size))) { if (debug) { TORCH_WARN( - "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", + "For NestedTensor inputs,", + caller_is_meff ? " Efficient attention on ROCM " : " Flash attention", + " requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", @@ -208,7 +217,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; - auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -220,19 +228,11 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental." - " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); - return false; - } - } #else return false; #endif #else + auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -252,7 +252,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; - auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -264,19 +263,11 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." - " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); - return false; - } - } #else return false; #endif #else + auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -615,7 +606,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { check_all_tensors_on_device, check_tensor_shapes, check_for_attn_mask, - check_head_dim_size_flash, + check_head_dim_size_flash, check_flash_attention_hardware_support, check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89, check_flash_causal_non_square_seqlens, @@ -629,7 +620,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { if (has_for_nested_inputs(params)) { constexpr auto nested_constraints = array_of( check_batch_size_nested, - check_head_dim_size_flash_nested, + check_head_dim_size_flash_nested, check_for_seq_len_0_nested_tensor); for (auto& constraint : nested_constraints) { if (!constraint(params, debug)) { @@ -637,11 +628,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { } } } -#if USE_ROCM - constexpr bool backend_supports_grouped_query_attention = false; -#else constexpr bool backend_supports_grouped_query_attention = true; -#endif if (has_only_dense_inputs(params)) { constexpr auto dense_constraints = array_of( check_batch_size_and_num_heads_dense, @@ -679,7 +666,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { check_mem_efficient_hardware_support, check_tensor_shapes, #ifdef USE_ROCM - check_head_dim_size_flash + check_head_dim_size_flash #else check_head_dim_size_mem_efficient #endif @@ -691,12 +678,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } if (has_for_nested_inputs(params)) { -#ifdef USE_ROCM - TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention."); - return false; -#endif constexpr auto nested_constraints = array_of( +#ifndef USE_ROCM // ME and FA shares backend on ROCM and thus supports training check_requires_grad_and_nested, +#else // Meanwhile ME on ROCM share the limits of FA about head dimensions + check_head_dim_size_flash_nested, +#endif check_batch_size_nested, check_for_seq_len_0_nested_tensor); for (auto& constraint : nested_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 7191a5f13331..dcbac79e317d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -77,6 +77,37 @@ void check_gpu_arch(hipStream_t stream) { } } +// We want to checkpoint and save the RNG state for backward if dropout +// We get the default generator and return the seed and offset which will +// be used in the backward function +std::tuple +prepare_philox_arguments(float p_dropout, int64_t counter_offset) { + at::Tensor seed_t, offset_t; + at::PhiloxCudaState philox_state; + bool use_philox_state = false; + if (p_dropout <= 0.0) { + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + return { seed_t, offset_t, philox_state, use_philox_state }; + } + auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + std::lock_guard lock(gen->mutex_); + philox_state = gen->philox_cuda_state(counter_offset); + if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { + auto [seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); + } else { + // See Note [CUDA Graph-safe RNG states] about the design + use_philox_state = true; + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } + + return { seed_t, offset_t, philox_state, use_philox_state }; +} + + } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") @@ -158,44 +189,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - // We want to checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; - - at::PhiloxCudaState philox_state; - bool use_philox_state = false; - if (p_dropout > 0.0) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = batch_size * num_heads * 32; - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); - } else { - // See Note [CUDA Graph-safe RNG states] about the design - use_philox_state = true; - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); // Transpose tensors to meet AOTriton's Flash API at::Tensor q_t = q_padded.permute({0,2,1,3}); @@ -215,7 +211,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; - using aotriton::TensorView; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -266,16 +261,150 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int window_size_right, const bool return_softmax, std::optional gen_) { + TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); + const bool paged_KV = block_table_.has_value(); + TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt"); + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); - TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm"); + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); - at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat)); - at::Tensor p = at::empty({}, at::dtype(at::kFloat)); - at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor out = at::empty({}, at::dtype(at::kFloat)); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); - return {out, q, k, v, softmax_lse, seed_t, offset_t, p}; + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + at::Tensor temp_q = q; + const int total_q = temp_q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + + CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + // AOTriton's varlen API needs input shapes be + // (1, num_heads, total sequence lenght, head dimension) + at::Tensor q_padded, k_padded, v_padded; + at::Tensor out, out_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + } else { + out = at::empty_like(q); + } + out_padded = out.unsqueeze(0).transpose(1, 2); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = head_size_og; + + auto opts = q.options(); + + auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q}); + at::Tensor softmax_fa_t; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { + softmax_fa_t.zero_(); + } + } + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + if (max_seqlen_k > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t}; } std::tuple @@ -297,6 +426,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); check_gpu_arch(stream); @@ -341,8 +473,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); @@ -381,23 +511,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv = at::empty_like(k); } - // const at::Tensor& dout_padded = dout; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor q_t = q.permute({0,2,1,3}); at::Tensor k_t = k.permute({0,2,1,3}); @@ -440,14 +555,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si stream); } - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } return { dq, dk, dv, softmax_d }; -#undef CALL_BWD_DROPOUT -#undef CALL_BWD } std::tuple @@ -473,13 +581,173 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size int window_size_right, const bool deterministic, const at::Tensor philox_seed, - const at::Tensor philox_offset) { - TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); + const at::Tensor philox_offset) +{ + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); - at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat)); + if (is_causal) { + window_size_right = 0; + } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); - return { q, k, v, softmax_d }; + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + at::Tensor out_t, dout_t; + out_t = out.unsqueeze(0).transpose(1, 2); + dout_t = dout.unsqueeze(0).transpose(1, 2); + + at::Tensor dq, dk, dv; + at::Tensor dq_padded, dk_padded, dv_padded; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = at::empty_like(v); + } + dq_padded = dq.unsqueeze(0).transpose(1, 2); + dk_padded = dk.unsqueeze(0).transpose(1, 2); + dv_padded = dv.unsqueeze(0).transpose(1, 2); + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + if( zero_tensors ) { + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + at::PhiloxCudaState philox_args; + if (is_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) + { + philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); + } else { // dropout + capture + philox_args = at::PhiloxCudaState( + philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + } + } + if (max_seqlen_q > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_bwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + return { dq, dk, dv, softmax_d }; } + } // namespace pytorch_fmha #endif diff --git a/setup.py b/setup.py index 24c8c19b9bb8..5ed97c6df7b6 100644 --- a/setup.py +++ b/setup.py @@ -1373,6 +1373,13 @@ def main(): "lib/*.lib", ] ) + aotriton_image_path = os.path.join(lib_path, "aotriton.images") + aks2_files = [] + for root, dirs, files in os.walk(aotriton_image_path): + subpath = os.path.relpath(root, start=aotriton_image_path) + for fn in files: + aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += aks2_files if get_cmake_cache_vars()["USE_TENSORPIPE"]: torch_package_data.extend( [ diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 6290d3719d85..bc4588ad54c2 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -389,8 +389,6 @@ class TestFlexAttention(InductorTestCase): KV_S = Q_S if V_D is None: V_D = Q_D - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) @@ -565,9 +563,6 @@ class TestFlexAttention(InductorTestCase): V_D: int = D, block_mask: Optional[BlockMask] = None, ): - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") - assert Q_H % KV_H == 0 q = torch.randn( diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 1e8c0ada855f..c5d10c069f37 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -22,7 +22,7 @@ from torch.nn.attention.flex_attention import ( from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 -from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import skipIfRocm from torch.utils._triton import has_triton @@ -492,9 +492,6 @@ class TestFlexDecoding(InductorTestCase): V_D: int = D, block_mask: Optional[BlockMask] = None, ): - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") - assert Q_H % KV_H == 0 q = torch.randn( diff --git a/test/test_native_mha.py b/test/test_native_mha.py index 307115147852..0e4489ab135c 100644 --- a/test/test_native_mha.py +++ b/test/test_native_mha.py @@ -277,8 +277,8 @@ class TestMHADeviceType(TestCase): def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding, pad_all, fused): if TEST_WITH_ROCM: - if use_nt: - self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if use_nt and use_padding and pad_all: + self.skipTest("Large numerical errors on ROCM to investigate.") if use_padding and not pad_all and fused: self.skipTest("Large numerical errors on ROCM to investigate.") for need_weights in (False, not pad_all): diff --git a/test/test_transformers.py b/test/test_transformers.py index 1bd8f7482f33..325f35c6f1a9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1630,7 +1630,6 @@ class TestSDPAFailureModes(NNTestCase): q, k, v, None, 0.0, False)) @onlyCUDA - @skipIfRocm(msg='enable_gqa=True unsupported') @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel): @@ -1646,7 +1645,6 @@ class TestSDPAFailureModes(NNTestCase): is_causal=False, enable_gqa=True) @onlyCPU - @skipIfRocm(msg='enable_gqa=True unsupported') def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) @@ -1784,7 +1782,6 @@ class TestSDPAFailureModes(NNTestCase): self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) @onlyCUDA - @skipIfRocm # Missing EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_error_cases(self, device): # one of k,v needs to be broadcasted and other has non consistent seq_len dim @@ -2485,7 +2482,7 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention + @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2526,7 +2523,7 @@ class TestSDPACudaOnly(NNTestCase): with self.assertRaisesRegex(RuntimeError, "No available kernel."): o = torch.nn.functional.scaled_dot_product_attention(q, k, v) - @skipIfRocm # No cuDNN Attention + @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -2713,8 +2710,6 @@ class TestSDPACudaOnly(NNTestCase): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): - if TEST_WITH_ROCM and type == 'nested': - self.skipTest("ROCM does not support efficient attention on nested tensors, for now") make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 @@ -2743,7 +2738,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @skipIfRocm # Missing nested and EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @@ -2910,7 +2904,6 @@ class TestSDPACudaOnly(NNTestCase): atol = 9e-4 if dtype == torch.float16 else 9e-3 self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol) - @skipIfRocm # Missing nested and EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA") @parametrize("type", ["dense", "nested"]) def test_fused_sdp_choice(self, device, type: str): @@ -3116,8 +3109,6 @@ class TestSDPACudaOnly(NNTestCase): return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation - if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: - self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -3227,9 +3218,6 @@ class TestSDPACudaOnly(NNTestCase): if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return - if TEST_WITH_ROCM and dtype == torch.float32: - unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.") - return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 @@ -3323,7 +3311,7 @@ class TestSDPACudaOnly(NNTestCase): @parametrize("dropout_p", [0.0, 0.22, 0.48]) @parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("scale", [None, "l1"]) - @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False]) + @parametrize("enable_gqa", [True, False]) @parametrize("n_heads", [[16, 8], [10, 2]]) @tf32_enabled() def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, @@ -3488,8 +3476,6 @@ class TestSDPACudaOnly(NNTestCase): if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") - if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: - self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 n_heads = 4 @@ -3593,7 +3579,6 @@ class TestSDPACudaOnly(NNTestCase): } ) - @skipIfRocm # Nested Tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) @@ -3627,7 +3612,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2) - @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) @@ -3653,6 +3637,9 @@ class TestSDPACudaOnly(NNTestCase): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype) batch, num_heads, head_dim = 32, 8, 64 head_dim_v = 32 if is_efficient else head_dim + if TEST_WITH_ROCM and head_dim != head_dim_v: + self.skipTest("head_dim != head_dim_v unsupported on ROCm for now") + return seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item() if expand_q_batch else torch.randint(low=1, high=32, size=(batch,)).tolist()) @@ -3716,7 +3703,7 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2) - @skipIfRocm # Nested tensor + @skipIfRocm(msg="Efficient Attention on ROCM does not support head_dim != head_dim_v for now.") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_query_dense(self, device): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) @@ -3751,7 +3738,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2) - @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("batch_size", [8, 32]) @parametrize("max_seq_len_q", [32, 256]) @@ -3919,7 +3905,6 @@ class TestAttnBias(NNTestCase): torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) - @skipIfRocm # No support for the second variant for now @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", @@ -3929,6 +3914,9 @@ class TestAttnBias(NNTestCase): make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True ) + if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: + self.skipTest("No support for LOWER_RIGHT variant for now") + return bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) @@ -3952,7 +3940,6 @@ class TestAttnBias(NNTestCase): SDPBackend.CUDNN_ATTENTION]): self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) - @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", @@ -3961,6 +3948,10 @@ class TestAttnBias(NNTestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): + if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: + self.skipTest("No support for LOWER_RIGHT variant for now") + return + cnts = CompileCounterWithBackend("aot_eager") make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True @@ -3988,7 +3979,6 @@ class TestAttnBias(NNTestCase): self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") - @skipIfRocm @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): make_tensor = partial( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a2c58d9c877b..54a7c1da8a89 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8762,11 +8762,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] samples = [] - gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support - if TEST_WITH_ROCM and dtype == torch.float32: - causal_options = [False] # FIXME: Large errors with causal+fp32 - else: - causal_options = [True, False] + gqa_options = [True, False] + causal_options = [True, False] for qkv_shape, is_causal, dropout_p, _enable_gqa in product( qkv_shapes, causal_options, [0.0, 0.5], gqa_options): shape_q, shape_kv = qkv_shape @@ -8799,16 +8796,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ dropout_p=0.0) ) - if not TEST_WITH_ROCM: - samples.append( - SampleInput( - make((batch, num_heads_q_gqa, seq_q, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - enable_gqa=True - ) - ) - yield from samples