mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ROCm] Integrate AITER Fav3 fwd kernels (#160105)"
This reverts commit d2393c2d7da03a1523a12e6f80edb6bd7b464ec5. Reverted https://github.com/pytorch/pytorch/pull/160105 on behalf of https://github.com/huydhn due to: "Sorry for reverting your change, but it is failing the internal ROCm build" ([comment](https://github.com/pytorch/pytorch/pull/160105#issuecomment-3273297183)).
This commit is contained in:
@ -1,22 +1,13 @@
|
||||
include(CMakePrintHelpers)
|
||||
|
||||
# Generate AITER/CK Asm code
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_fwd_kernel_generate.py --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Failed to generate FAv3 fwd CK Kernels")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py --receipt 1 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Failed to generate FAv3 bwd CK Kernels")
|
||||
message( FATAL_ERROR "Failed to generate FAv3 CK Kernels")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
@ -24,24 +15,6 @@ execute_process(
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Failed to generate FAv3 bwd api")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/cpp_itfs/mha_fwd_generate.py --receipt 6 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Failed to generate FAv3 fwd api")
|
||||
endif()
|
||||
|
||||
# Change file extensions to .hip
|
||||
execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done"
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Failed to modify aiter file extensions")
|
||||
endif()
|
||||
execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done")
|
||||
|
@ -3,7 +3,6 @@
|
||||
******************************************************************************/
|
||||
|
||||
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
|
||||
#include <mha_fwd.h>
|
||||
#include <fmha_fwd.hpp>
|
||||
#include <mask.hpp>
|
||||
|
||||
@ -142,7 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
0, // min_seqlen_q
|
||||
-1, // min_seqlen_q
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
drop_seed_offset};
|
||||
@ -351,14 +350,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
|
||||
softmax_scale,
|
||||
p_dropout,
|
||||
drop_seed_offset);
|
||||
float t = aiter::mha_fwd(args, // mha_fwd_args args
|
||||
stream_config, // stream_config
|
||||
q_dtype_str, // q_dtype_str
|
||||
false, // is_group_mode
|
||||
mask.type, // mask_type
|
||||
attn_bias_.has_value() ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse, // has_lse
|
||||
true); // use_ext_asm
|
||||
float t = fmha_fwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
|
||||
}
|
||||
else {
|
||||
|
@ -349,7 +349,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
|
||||
p_dropout,
|
||||
drop_seed_offset);
|
||||
float t = fmha_fwd(traits, args, stream_config);
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_varlen_fwd");
|
||||
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
|
||||
}
|
||||
else {
|
||||
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
|
2
third_party/aiter
vendored
2
third_party/aiter
vendored
Submodule third_party/aiter updated: 28918c0e68...01aae101b9
Reference in New Issue
Block a user