diff --git a/.ci/docker/aotriton_version.txt b/.ci/docker/aotriton_version.txt index 602b77d3b853..0bb9b7f4bbbf 100644 --- a/.ci/docker/aotriton_version.txt +++ b/.ci/docker/aotriton_version.txt @@ -1,5 +1,5 @@ -0.7b -manylinux_2_17 +0.8b +manylinux_2_28 rocm6.2 -9be04068c3c0857a4cfd17d7e39e71d0423ebac2 -3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291 +6f8cbcac8a92775291bb1ba8f514d4beb350baf4 +e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8 diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 7381984df0e9..004a107df948 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -253,11 +253,11 @@ make_wheel_record() { FPATH=$1 if echo $FPATH | grep RECORD >/dev/null 2>&1; then # if the RECORD file, then - echo "$FPATH,," + echo "\"$FPATH\",," else HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') FSIZE=$(ls -nl $FPATH | awk '{print $5}') - echo "$FPATH,sha256=$HASH,$FSIZE" + echo "\"$FPATH\",sha256=$HASH,$FSIZE" fi } diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index fd330f6435c8..41d8c4e15272 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -225,11 +225,11 @@ make_wheel_record() { FPATH=$1 if echo $FPATH | grep RECORD >/dev/null 2>&1; then # if the RECORD file, then - echo "$FPATH,," + echo "\"$FPATH\",," else HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') FSIZE=$(ls -nl $FPATH | awk '{print $5}') - echo "$FPATH,sha256=$HASH,$FSIZE" + echo "\"$FPATH\",sha256=$HASH,$FSIZE" fi } diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index b9bb67631927..32fd1435caf7 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -266,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC)) DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/}) DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) +# PyTorch 2.6+ (AOTriton 0.8b+) +# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly +if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then + LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n') + if [[ -z ${LIBAOTRITON_DIR} ]]; then + LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n') + fi + AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n')) + AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images" + AKS_DST="lib/aotriton.images" + DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/}) + DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/}) +fi + echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}" SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index e87c6271ef12..7fe7ee7a1ba1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1160,6 +1160,7 @@ std::tuple _efficient_ const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); using aotriton::v2::flash::attn_fwd; + using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -1172,22 +1173,46 @@ std::tuple _efficient_ auto seed_output = use_philox_state ? 
mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); hipError_t err; // TODO: Error handling - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - softmax_scale, - mk_aotensor<2>(softmax_lse, "M"), - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - stream); + if (seqstart_q.has_value()) { + // varlen aka nested tensor + err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } else { + err = attn_fwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, + softmax_scale, + mk_aotensor<2>(softmax_lse, "M"), + mk_aotensor(output_t, "Out"), + dropout_p, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } if (!compute_logsumexp) { // Set the tensor to empty when compute_logsumexp is false logsumexp = at::empty( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index ffe51b373265..09799ff125d1 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -441,29 +441,57 @@ _efficient_attention_backward( hipError_t err; using aotriton::v2::flash::attn_bwd; + using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); + if (cu_seqlens_q.has_value()) { + // varlen aka Nested tensor + err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + err = attn_bwd(mk_aotensor(q_t, "q"), + mk_aotensor(k_t, "k"), + mk_aotensor(v_t, "v"), + bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_t, "dq"), + mk_aotensor(dk_t, "dk"), + mk_aotensor(dv_t, "dv"), + bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, + mk_aotensor<2>(softmax_lse, "L"), + mk_aotensor<2>(delta, "delta"), + float(dropout_p), + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } #else at::Tensor workspace; cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index()); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index df584fe6ff24..a60cbe5ea061 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -103,6 +103,10 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { return matmul_alignment_mn; } +// On ROCM, ME and FA share the backend, and hence they share the checking +// function for fundamental limitations by the GPU kernel +// caller_is_meff is added to make the TORCH_WARN message showing the correct result +template bool check_head_dim_size_flash(sdp_params const& params, bool debug) { // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); @@ -114,7 +118,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { if (!(same_head_dim_size && (query_size_last <= max_size))) { if (debug) { TORCH_WARN( - "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.", + caller_is_meff ? "Efficient attention on ROCM" : "Flash attention", + " requires q,k,v to have the same last dimension and to be less than or equal to 256.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", @@ -128,6 +133,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { return true; } +// See check_head_dim_size_flash above for the purpose of caller_is_meff +template bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { const auto max_size = c10::SymInt(256); const auto query_size_last = params.query.sym_size(-1); @@ -139,7 +146,9 @@ bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { (query_size_last <= max_size))) { if (debug) { TORCH_WARN( - "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", + "For NestedTensor inputs,", + caller_is_meff ? 
" Efficient attention on ROCM " : " Flash attention", + " requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.", " Got Query.size(-1): ", query_size_last, ", Key.size(-1): ", @@ -208,7 +217,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug // Check that the gpu is capable of running flash attention using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; - auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -220,19 +228,11 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental." - " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); - return false; - } - } #else return false; #endif #else + auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -252,7 +252,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) // Mem Efficient attention supports hardware in the range [sm_50, sm_90] using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; - auto dprops = at::cuda::getCurrentDeviceProperties(); #if USE_ROCM #if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -264,19 +263,11 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { - static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true; - if (!enable_navi3x) { - TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental." 
- " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1."); - return false; - } - } #else return false; #endif #else + auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { if (debug) { TORCH_WARN( @@ -615,7 +606,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { check_all_tensors_on_device, check_tensor_shapes, check_for_attn_mask, - check_head_dim_size_flash, + check_head_dim_size_flash, check_flash_attention_hardware_support, check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89, check_flash_causal_non_square_seqlens, @@ -629,7 +620,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { if (has_for_nested_inputs(params)) { constexpr auto nested_constraints = array_of( check_batch_size_nested, - check_head_dim_size_flash_nested, + check_head_dim_size_flash_nested, check_for_seq_len_0_nested_tensor); for (auto& constraint : nested_constraints) { if (!constraint(params, debug)) { @@ -637,11 +628,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { } } } -#if USE_ROCM - constexpr bool backend_supports_grouped_query_attention = false; -#else constexpr bool backend_supports_grouped_query_attention = true; -#endif if (has_only_dense_inputs(params)) { constexpr auto dense_constraints = array_of( check_batch_size_and_num_heads_dense, @@ -679,7 +666,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { check_mem_efficient_hardware_support, check_tensor_shapes, #ifdef USE_ROCM - check_head_dim_size_flash + check_head_dim_size_flash #else check_head_dim_size_mem_efficient #endif @@ -691,12 +678,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { } if (has_for_nested_inputs(params)) { -#ifdef USE_ROCM - TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention."); - return false; -#endif constexpr auto nested_constraints = array_of( +#ifndef USE_ROCM // ME and FA shares backend on ROCM and thus supports training check_requires_grad_and_nested, +#else // Meanwhile ME on ROCM share the limits of FA about head dimensions + check_head_dim_size_flash_nested, +#endif check_batch_size_nested, check_for_seq_len_0_nested_tensor); for (auto& constraint : nested_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 7191a5f13331..dcbac79e317d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -77,6 +77,37 @@ void check_gpu_arch(hipStream_t stream) { } } +// We want to checkpoint and save the RNG state for backward if dropout +// We get the default generator and return the seed and offset which will +// be used in the backward function +std::tuple +prepare_philox_arguments(float p_dropout, int64_t counter_offset) { + at::Tensor seed_t, offset_t; + at::PhiloxCudaState philox_state; + bool use_philox_state = false; + if (p_dropout <= 0.0) { + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + return { seed_t, offset_t, philox_state, use_philox_state }; + } + auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + std::lock_guard lock(gen->mutex_); + philox_state = gen->philox_cuda_state(counter_offset); + if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { + auto 
[seed, offset] = at::cuda::philox::unpack(philox_state); + seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); + } else { + // See Note [CUDA Graph-safe RNG states] about the design + use_philox_state = true; + seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + } + + return { seed_t, offset_t, philox_state, use_philox_state }; +} + + } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") @@ -158,44 +189,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, 8); const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - // We want to checkpoint and save the RNG state for backward if dropout - // We get the default generator and return the seed and offset which will - // be used in the backward function - auto gen = at::get_generator_or_default(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); - at::Tensor seed_t, offset_t; - - at::PhiloxCudaState philox_state; - bool use_philox_state = false; - if (p_dropout > 0.0) { - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = batch_size * num_heads * 32; - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - philox_state = gen->philox_cuda_state(counter_offset); - if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { - auto [seed, offset] = at::cuda::philox::unpack(philox_state); - seed_t = at::scalar_tensor(at::Scalar(static_cast(seed)), at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(offset)), at::dtype(at::kLong).device(at::kCUDA)); - } else { - // See Note [CUDA Graph-safe RNG states] about the design - use_philox_state = true; - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } else { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - } - } + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); // Transpose tensors to meet AOTriton's Flash API at::Tensor q_t = q_padded.permute({0,2,1,3}); @@ -215,7 +211,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head hipError_t err; // TODO: Error handling using aotriton::v2::flash::attn_fwd; - using aotriton::TensorView; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -266,16 +261,150 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int window_size_right, const bool return_softmax, 
std::optional gen_) { + TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); + const bool paged_KV = block_table_.has_value(); + TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt"); + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); - TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm"); + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); - at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat)); - at::Tensor p = at::empty({}, at::dtype(at::kFloat)); - at::Tensor offset_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor seed_t = at::empty({}, at::dtype(at::kLong)); - at::Tensor out = at::empty({}, at::dtype(at::kFloat)); + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); - return {out, q, k, v, softmax_lse, seed_t, offset_t, p}; + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(cu_seqlens_q); + CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int batch_size = cu_seqlens_q.numel() - 1; + int num_heads = sizes[1]; + const int head_size_og = sizes[2]; + const int num_heads_k = paged_KV ? 
k.size(2) : k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { + is_causal = false; + } // causal=true is the same as causal=false in this case + + at::Tensor temp_q = q; + const int total_q = temp_q.sizes()[0]; + + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + + CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og); + const int total_k = k.size(0); + CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + // AOTriton's varlen API needs input shapes be + // (1, num_heads, total sequence lenght, head dimension) + at::Tensor q_padded, k_padded, v_padded; + at::Tensor out, out_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, total_q, num_heads, head_size_og); + } else { + out = at::empty_like(q); + } + out_padded = out.unsqueeze(0).transpose(1, 2); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = head_size_og; + + auto opts = q.options(); + + auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q}); + at::Tensor softmax_fa_t; + // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { + TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); + softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); + } else { + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); + } + + if (zero_tensors) { + out.zero_(); + softmax_lse.fill_(-std::numeric_limits::infinity()); + if (return_softmax) { + softmax_fa_t.zero_(); + } + } + + auto [seed_t, offset_t, philox_state, use_philox_state] = + prepare_philox_arguments(p_dropout, batch_size * num_heads * 32); + + if (max_seqlen_k > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_fwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); + auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t); + auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0; + auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : mk_philoxtensor(nullptr); + auto offset_output = use_philox_state ? 
mk_philoxtensor(offset_t.data_ptr()) : mk_philoxtensor(nullptr); + err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor<2>(M, "M"), + mk_aotensor(out_padded, "Out"), + p_dropout, + seed, + offset1, + offset2, + seed_output, + offset_output, + mk_aotensor(softmax_fa_t, "encoded_softmax"), + is_causal, + stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t}; } std::tuple @@ -297,6 +426,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); check_gpu_arch(stream); @@ -341,8 +473,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); - const int seqlen_q_rounded = round_multiple(seqlen_q, 128); - const int seqlen_k_rounded = round_multiple(seqlen_k, 128); TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); @@ -381,23 +511,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv = at::empty_like(k); } - // const at::Tensor& dout_padded = dout; - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - auto opts = q.options(); - auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); - - at::Tensor dk_expanded, dv_expanded; - if (num_heads_k != num_heads) { // MQA / GQA - dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts); - } else { - dk_expanded = dk; - dv_expanded = dv; - } + auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor q_t = q.permute({0,2,1,3}); at::Tensor k_t = k.permute({0,2,1,3}); @@ -440,14 +555,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si stream); } - // For MQA/GQA we need to sum dK and dV across the groups - if (num_heads_k != num_heads) { - at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3}); - } return { dq, dk, dv, softmax_d }; -#undef CALL_BWD_DROPOUT -#undef CALL_BWD } std::tuple @@ -473,13 +581,173 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size int window_size_right, const bool deterministic, const at::Tensor philox_seed, - const at::Tensor philox_offset) { - TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); + const 
at::Tensor philox_offset) +{ + TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); - at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat)); + if (is_causal) { + window_size_right = 0; + } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + check_gpu_arch(stream); - return { q, k, v, softmax_d }; + bool is_dropout = p_dropout > 0.0; + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32"); + TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + CHECK_CONTIGUOUS(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_k); + + const auto sizes = q.sizes(); + + const int total_q = sizes[0]; + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = sizes[1]; + const int head_size_og = dout.size(2); + const int head_size = sizes[2]; + const int total_k = k.size(0); + const int num_heads_k = k.size(1); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + + CHECK_SHAPE(q, total_q, num_heads, head_size); + CHECK_SHAPE(k, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(dout, total_q, num_heads, head_size_og); + CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); + + at::Tensor q_padded, k_padded, v_padded; + q_padded = q.unsqueeze(0).transpose(1, 2); + k_padded = k.unsqueeze(0).transpose(1, 2); + v_padded = v.unsqueeze(0).transpose(1, 2); + at::Tensor out_t, dout_t; + out_t = out.unsqueeze(0).transpose(1, 2); + dout_t = dout.unsqueeze(0).transpose(1, 2); + + at::Tensor dq, dk, dv; + at::Tensor dq_padded, dk_padded, dv_padded; + if 
(dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, total_q, num_heads, head_size); + } else { + dq = at::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, total_k, num_heads_k, head_size); + } else { + dk = at::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, total_k, num_heads_k, head_size); + } else { + dv = at::empty_like(v); + } + dq_padded = dq.unsqueeze(0).transpose(1, 2); + dk_padded = dk.unsqueeze(0).transpose(1, 2); + dv_padded = dv.unsqueeze(0).transpose(1, 2); + + auto opts = q.options(); + auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + + if( zero_tensors ) { + dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + at::PhiloxCudaState philox_args; + if (is_dropout) { + if (at::cuda::currentStreamCaptureStatus() == + at::cuda::CaptureStatus::None) + { + philox_args = at::PhiloxCudaState(*philox_seed.data_ptr(), *philox_offset.data_ptr()); + } else { // dropout + capture + philox_args = at::PhiloxCudaState( + philox_seed.data_ptr(), philox_offset.data_ptr(), 0); + } + } + if (max_seqlen_q > 0) { + hipError_t err; // TODO: Error handling + using aotriton::v2::flash::attn_bwd_compact_varlen; + using sdp::aotriton_adapter::mk_aotensor; + using sdp::aotriton_adapter::mk_aoscalartensor; + using sdp::aotriton_adapter::cast_dtype; + aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); + err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), + mk_aotensor(k_padded, "k"), + mk_aotensor(v_padded, "v"), + mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), + mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), + max_seqlen_q, + max_seqlen_k, + empty_bias, + softmax_scale, + mk_aotensor(out_t, "out"), + mk_aotensor(dout_t, "dout"), + mk_aotensor(dq_padded, "dq"), + mk_aotensor(dk_padded, "dk"), + mk_aotensor(dv_padded, "dv"), + empty_bias, + mk_aotensor<2>(softmax_lse_cont, "L"), + mk_aotensor<2>(delta, "delta"), + p_dropout, + mk_aoscalartensor(philox_seed), + mk_aoscalartensor(philox_offset), + 0, + is_causal, + stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
+ dq.zero_(); + dk.zero_(); + dv.zero_(); + softmax_d.zero_(); + } + + return { dq, dk, dv, softmax_d }; } + } // namespace pytorch_fmha #endif diff --git a/setup.py b/setup.py index 24c8c19b9bb8..5ed97c6df7b6 100644 --- a/setup.py +++ b/setup.py @@ -1373,6 +1373,13 @@ def main(): "lib/*.lib", ] ) + aotriton_image_path = os.path.join(lib_path, "aotriton.images") + aks2_files = [] + for root, dirs, files in os.walk(aotriton_image_path): + subpath = os.path.relpath(root, start=aotriton_image_path) + for fn in files: + aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn)) + torch_package_data += aks2_files if get_cmake_cache_vars()["USE_TENSORPIPE"]: torch_package_data.extend( [ diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 6290d3719d85..bc4588ad54c2 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -389,8 +389,6 @@ class TestFlexAttention(InductorTestCase): KV_S = Q_S if V_D is None: V_D = Q_D - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") q = torch.randn( (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True ) @@ -565,9 +563,6 @@ class TestFlexAttention(InductorTestCase): V_D: int = D, block_mask: Optional[BlockMask] = None, ): - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") - assert Q_H % KV_H == 0 q = torch.randn( diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 1e8c0ada855f..c5d10c069f37 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -22,7 +22,7 @@ from torch.nn.attention.flex_attention import ( from torch.testing import FileCheck from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16 -from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM +from torch.testing._internal.common_utils import skipIfRocm from torch.utils._triton import has_triton @@ -492,9 +492,6 @@ class TestFlexDecoding(InductorTestCase): V_D: int = D, block_mask: Optional[BlockMask] = None, ): - if TEST_WITH_ROCM and Q_H != KV_H: - self.skipTest("enable_gqa=True is unsupported on ROCM, for now") - assert Q_H % KV_H == 0 q = torch.randn( diff --git a/test/test_native_mha.py b/test/test_native_mha.py index 307115147852..0e4489ab135c 100644 --- a/test/test_native_mha.py +++ b/test/test_native_mha.py @@ -277,8 +277,8 @@ class TestMHADeviceType(TestCase): def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding, pad_all, fused): if TEST_WITH_ROCM: - if use_nt: - self.skipTest("ROCM does not support nested tensors for Flash Attention for now.") + if use_nt and use_padding and pad_all: + self.skipTest("Large numerical errors on ROCM to investigate.") if use_padding and not pad_all and fused: self.skipTest("Large numerical errors on ROCM to investigate.") for need_weights in (False, not pad_all): diff --git a/test/test_transformers.py b/test/test_transformers.py index 1bd8f7482f33..325f35c6f1a9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1630,7 +1630,6 @@ class TestSDPAFailureModes(NNTestCase): q, k, v, None, 0.0, False)) @onlyCUDA - @skipIfRocm(msg='enable_gqa=True unsupported') @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", 
[SDPBackend.EFFICIENT_ATTENTION]) def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel): @@ -1646,7 +1645,6 @@ class TestSDPAFailureModes(NNTestCase): is_causal=False, enable_gqa=True) @onlyCPU - @skipIfRocm(msg='enable_gqa=True unsupported') def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) @@ -1784,7 +1782,6 @@ class TestSDPAFailureModes(NNTestCase): self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value)) @onlyCUDA - @skipIfRocm # Missing EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_error_cases(self, device): # one of k,v needs to be broadcasted and other has non consistent seq_len dim @@ -2485,7 +2482,7 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention + @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2526,7 +2523,7 @@ class TestSDPACudaOnly(NNTestCase): with self.assertRaisesRegex(RuntimeError, "No available kernel."): o = torch.nn.functional.scaled_dot_product_attention(q, k, v) - @skipIfRocm # No cuDNN Attention + @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -2713,8 +2710,6 @@ class TestSDPACudaOnly(NNTestCase): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool): - if TEST_WITH_ROCM and type == 'nested': - self.skipTest("ROCM does not support efficient attention on nested tensors, for now") make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 @@ -2743,7 +2738,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2) - @skipIfRocm # Missing nested and EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @@ -2910,7 +2904,6 @@ class TestSDPACudaOnly(NNTestCase): atol = 9e-4 if dtype == torch.float16 else 9e-3 self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol) - @skipIfRocm # Missing nested and EFFICIENT_ATTENTION @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA") @parametrize("type", ["dense", "nested"]) def test_fused_sdp_choice(self, device, type: str): @@ -3116,8 +3109,6 @@ class TestSDPACudaOnly(NNTestCase): return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation - if TEST_WITH_ROCM and is_causal and seq_len_q != 
seq_len_k: - self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 scale = scale if scale is None else (1 / head_dim) n_heads = 4 @@ -3227,9 +3218,6 @@ class TestSDPACudaOnly(NNTestCase): if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return - if TEST_WITH_ROCM and dtype == torch.float32: - unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.") - return if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128: torch.cuda.empty_cache() # Prevent memory fragmentation seed = 42 @@ -3323,7 +3311,7 @@ class TestSDPACudaOnly(NNTestCase): @parametrize("dropout_p", [0.0, 0.22, 0.48]) @parametrize("dtype", [torch.float16, torch.bfloat16]) @parametrize("scale", [None, "l1"]) - @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False]) + @parametrize("enable_gqa", [True, False]) @parametrize("n_heads", [[16, 8], [10, 2]]) @tf32_enabled() def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, @@ -3488,8 +3476,6 @@ class TestSDPACudaOnly(NNTestCase): if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k: self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") - if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k: - self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k") seed = 42 n_heads = 4 @@ -3593,7 +3579,6 @@ class TestSDPACudaOnly(NNTestCase): } ) - @skipIfRocm # Nested Tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) @@ -3627,7 +3612,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2) - @skipIfRocm # Nested tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) @@ -3653,6 +3637,9 @@ class TestSDPACudaOnly(NNTestCase): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype) batch, num_heads, head_dim = 32, 8, 64 head_dim_v = 32 if is_efficient else head_dim + if TEST_WITH_ROCM and head_dim != head_dim_v: + self.skipTest("head_dim != head_dim_v unsupported on ROCm for now") + return seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item() if expand_q_batch else torch.randint(low=1, high=32, size=(batch,)).tolist()) @@ -3716,7 +3703,7 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2) - @skipIfRocm # Nested tensor + @skipIfRocm(msg="Efficient Attention on ROCM does not support head_dim != head_dim_v for now.") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_query_dense(self, device): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) @@ -3751,7 +3738,6 @@ class TestSDPACudaOnly(NNTestCase): self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2) - @skipIfRocm # Nested 
tensor @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("batch_size", [8, 32]) @parametrize("max_seq_len_q", [32, 256]) @@ -3919,7 +3905,6 @@ class TestAttnBias(NNTestCase): torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol) - @skipIfRocm # No support for the second variant for now @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", @@ -3929,6 +3914,9 @@ class TestAttnBias(NNTestCase): make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True ) + if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: + self.skipTest("No support for LOWER_RIGHT variant for now") + return bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) @@ -3952,7 +3940,6 @@ class TestAttnBias(NNTestCase): SDPBackend.CUDNN_ATTENTION]): self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None) - @skipIfRocm # CausalVariant @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize( "shape", @@ -3961,6 +3948,10 @@ class TestAttnBias(NNTestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): + if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: + self.skipTest("No support for LOWER_RIGHT variant for now") + return + cnts = CompileCounterWithBackend("aot_eager") make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True @@ -3988,7 +3979,6 @@ class TestAttnBias(NNTestCase): self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts) self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") - @skipIfRocm @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): make_tensor = partial( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a2c58d9c877b..54a7c1da8a89 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8762,11 +8762,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] samples = [] - gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support - if TEST_WITH_ROCM and dtype == torch.float32: - causal_options = [False] # FIXME: Large errors with causal+fp32 - else: - causal_options = [True, False] + gqa_options = [True, False] + causal_options = [True, False] for qkv_shape, is_causal, dropout_p, _enable_gqa in product( qkv_shapes, causal_options, [0.0, 0.5], gqa_options): shape_q, shape_kv = qkv_shape @@ -8799,16 +8796,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ dropout_p=0.0) ) - if not TEST_WITH_ROCM: - samples.append( - SampleInput( - make((batch, 
num_heads_q_gqa, seq_q, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - make((batch, num_heads_kv_gqa, seq_kv, head_dim)), - enable_gqa=True - ) - ) - yield from samples
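
Illustrative usage sketch (not part of this diff): the test changes above stop skipping the nested-tensor and GQA paths on ROCm because the AOTriton 0.8b backend now routes varlen (nested) inputs through attn_fwd_compact_varlen / attn_bwd_compact_varlen. The snippet below is a minimal, hedged example of the kind of call the un-skipped tests exercise, assuming a ROCm build where HIP devices are exposed through the CUDA device API; the helper, shapes, and sequence lengths are arbitrary and only for illustration.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

device = "cuda"            # ROCm builds surface HIP devices via the CUDA device API
num_heads, head_dim = 8, 64
seq_lens = [5, 17, 32, 9]  # ragged per-sequence lengths -> packed as a nested tensor

def make_nested():
    # Components are (seq_len, num_heads, head_dim); transpose to
    # (batch, num_heads, ragged_seq, head_dim), the layout SDPA expects.
    return torch.nested.nested_tensor(
        [torch.randn(n, num_heads, head_dim, dtype=torch.float16, device=device)
         for n in seq_lens]).transpose(1, 2)

q, k, v = make_nested(), make_nested(), make_nested()

# EFFICIENT_ATTENTION and FLASH_ATTENTION share the AOTriton backend on ROCm,
# so either backend can now be selected for nested (varlen) inputs.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)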