Initial Flash Attention support on ROCM (#114309)

This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency that provides the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations (a short usage sketch that respects these constraints follows the list):

- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.
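
For illustration, a minimal usage sketch that stays within the limits above. This is not part of the PR; it assumes a ROCm build where the HIP device is exposed as `cuda`, and uses the `torch.backends.cuda.sdp_kernel` context manager that the updated tests also rely on:

```python
# Minimal sketch: fp16, power-of-two shapes, head_dim <= 128, MI200-class GPU assumed.
import torch
import torch.nn.functional as F

device = "cuda"  # on ROCm builds the HIP device masquerades as CUDA
batch, heads, seq_len, head_dim = 2, 8, 128, 64  # power-of-two sizes per the limitations
q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the flash backend; this raises if flash attention is unavailable on the platform.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```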

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309

Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
Author: Xinya Zhang
Date: 2023-12-14 10:52:57 -06:00
Committed by: GitHub
Commit: 5bddbed399 (parent: ac60a70e06)
14 changed files with 854 additions and 23 deletions

View File

@ -735,10 +735,21 @@ endif()
include(cmake/Dependencies.cmake)
# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not avaialble until now
# TODO: Merge this into cmake_dependent_option as "NOT MSVC AND (USE_CUDA OR USE_ROCM)"
# once cmake_minimum_required is bumped to 3.22
# See https://cmake.org/cmake/help/latest/policy/CMP0127.html for the feature required here.
if(MSVC)
set(CONFIG_FA OFF)
elseif(USE_ROCM OR USE_CUDA)
set(CONFIG_FA ON)
else()
set(CONFIG_FA OFF)
endif()
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention" ON
"CONFIG_FA" OFF)
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
cmake_dependent_option(

View File

@ -164,6 +164,10 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# flash_attention sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu")
@ -175,6 +179,9 @@ if(USE_FLASH_ATTENTION)
list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu})
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})
list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
endif()
if(USE_MEM_EFF_ATTENTION)
@ -284,10 +291,34 @@ endif()
if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
${native_hip_hip}
${native_nested_hip_hip}
${native_sparse_hip_hip}
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
list(APPEND all_hip_cpp
${native_nested_hip_cpp}
${native_sparse_hip_cpp}
${native_quantized_hip_cpp}
${native_transformers_hip_cpp}
${native_quantized_cudnn_hip_cpp}
${hip_cpp}
${native_hip_cpp}
${native_hip_linalg_cpp}
${cuda_generated_sources}
${ATen_HIP_SRCS}
${native_miopen_cpp}
${native_cudnn_hip_cpp}
${miopen_cpp}
${all_hip_cpp}
)
endif()
list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)

View File

@ -445,6 +445,13 @@ int64_t _fused_sdp_choice_meta(
bool is_causal,
c10::optional<double> scale) {
auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) {
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
return choice_int;
}
#else
bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
if (has_cuda) {
auto choice_int = _fused_sdp_choice_stub(
@ -458,6 +465,7 @@ int64_t _fused_sdp_choice_meta(
scale);
return choice_int;
}
#endif
return static_cast<int64_t>(sdp::SDPBackend::math);
}
namespace {
@ -625,7 +633,8 @@ Tensor scaled_dot_product_attention(
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (query_.device().type() == DeviceType::CUDA
|| query_.device().type() == DeviceType::CPU
|| query_.device().type() == DeviceType::HIP){
choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, is_causal, scale);
}

View File

@ -14,6 +14,7 @@
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>
#include <c10/util/CallOnce.h>
#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>
@ -176,11 +177,42 @@ bool check_sm_version(cudaDeviceProp * dprops) {
return is_gte_lower_bound && is_lte_upper_bound;
}
#if USE_ROCM
c10::once_flag gcn_arch_override_flag;
const char* over_arch = nullptr;
void init_gcn_arch_override() {
over_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (over_arch) {
TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect SDPA function's behavior.");
}
}
#endif
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
const char* real_arch = dprops->gcnArchName;
c10::call_once(gcn_arch_override_flag, init_gcn_arch_override);
const char* arch = over_arch ? over_arch : real_arch;
if (mi200 != arch) {
if (debug) {
TORCH_WARN(
"Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
arch,
".",
over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
over_arch ? real_arch : "");
}
return false;
}
#else
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@ -192,6 +224,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
return false;
}
#endif
return true;
}
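
The hunk above reads `PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE` exactly once via `c10::call_once`, which is why the warning says later `os.environ` changes are ignored. A hedged, debug-only sketch of how that override might be set; it only spoofs the architecture string for the support check and does not make the kernels run on unsupported GPUs:

```python
# Debug-only sketch: the variable is consumed once per process, so set it before
# the first SDPA support query; later os.environ edits have no effect.
import os
os.environ["PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"] = "gfx90a:sramecc+:xnack-"

import torch  # import after setting the variable

# Subsequent scaled_dot_product_attention support checks will treat the GPU as an
# MI200 (gfx90a:sramecc+:xnack-) in check_flash_attention_hardware_support().
```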

View File

@ -0,0 +1,651 @@
/******************************************************************************
* Copyright (c) 2023, Advanced Micro Devices, Inc.
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint>
#include <tuple>
#include <ATen/ops/zeros.h>
#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h>
#endif
#include <ATen/native/transformers/hip/flash_attn/flash_api.h>
#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>
// OORT headers
#include <oort/attn_fwd.h>
#include <oort/bwd_kernel_dk_dv.h>
#include <oort/bwd_kernel_dq.h>
#include <oort/bwd_preprocess.h>
namespace pytorch_flash {
namespace {
c10::once_flag fa_gcn_arch_override_flag;
const char* fa_override_arch = nullptr;
void init_fa_override_arch() {
fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
if (fa_override_arch) {
TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
"Later changes to this environment variable with os.environ "
"(or other methods) will not affect this backend's behavior.");
}
}
void check_gpu_arch() {
auto dprops = at::cuda::getCurrentDeviceProperties();
constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch);
if (fa_override_arch) {
TORCH_CHECK(mi200 == fa_override_arch,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch);
} else {
TORCH_CHECK(mi200 == dprops->gcnArchName,
"FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName));
}
}
}
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float p_dropout,
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
check_gpu_arch();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
// FIXME: ROCM probably does not need this
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
at::Tensor q_padded, k_padded, v_padded;
q_padded = q;
k_padded = k;
v_padded = v;
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
} else {
out = at::empty_like(q_padded);
}
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;
if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = batch_size * num_heads * 32;
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
} else {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
} else {
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
}
}
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
//reorder tensors and make contiguous
at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous();
at::Tensor output_t = out.permute({0,2,1,3}).contiguous();
at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse
constexpr int BLOCK_M = 16;
constexpr int BLOCK_N = 16;
dim3 grid;
grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M;
grid.y = q_t.sizes()[0] * q_t.sizes()[1];
grid.z = 1;
dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic
at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
at::dtype(q.dtype()).device(q.device()));
hipError_t err; // TODO: Error handling
#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \
do { \
oort::attn_fwd<STAGE,BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX> fwd_opt; \
err = fwd_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
*(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \
(FP*)(softmax_fa_t.data_ptr()), \
stream); \
} while(0)
// TODO: Ugly but works
constexpr int kFwdUseCausal = 3;
constexpr int kFwdNoCausal = 1;
int d_head = q_t.sizes()[3];
constexpr int BM = BLOCK_M;
constexpr int BN = BLOCK_N;
if (q_dtype == at::kHalf) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
if (d_head == 16)
CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true);
} else {
if (d_head == 16)
CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true);
else if (d_head == 32)
CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true);
else if (d_head == 64)
CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true);
else if (d_head == 128)
CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true);
}
}
//undo reorder tensors
q_padded = q_t.permute({0,2,1,3}).contiguous();
k_padded = k_t.permute({0,2,1,3}).contiguous();
v_padded = v_t.permute({0,2,1,3}).contiguous();
out = output_t.permute({0,2,1,3}).contiguous();
return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
#undef CALL_FWD
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");
at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
at::Tensor p = at::empty({}, at::dtype(at::kFloat));
at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
at::Tensor out = at::empty({}, at::dtype(at::kFloat));
return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
const int window_size_left,
int window_size_right,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
check_gpu_arch();
bool is_dropout = p_dropout > 0.0;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q = sizes[1];
const int num_heads = sizes[2];
const int head_size_og = dout.size(3);
const int head_size = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
if (is_causal){
TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
}
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
at::Tensor dq, dk, dv;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
} else {
dq = at::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dk = at::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
} else {
dv = at::empty_like(k);
}
// const at::Tensor& dout_padded = dout;
// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
}
//JCG TODO WE GO IN HERE TODO backwards
//reorder tensors and make contiguous
at::Tensor q_t = q.permute({0,2,1,3}).contiguous();
at::Tensor k_t = k.permute({0,2,1,3}).contiguous();
at::Tensor v_t = v.permute({0,2,1,3}).contiguous();
at::Tensor out_t = out.permute({0,2,1,3}).contiguous();
//reorder tensors and make contiguous
at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous();
at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous();
at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous();
at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous();
dim3 block { 64 * 4, 1, 1 };
at::Tensor new_do = at::empty_like(dout_t).contiguous();
at::Tensor delta = at::empty_like(softmax_lse).contiguous();
int d_head = head_size_og;
hipError_t err; // TODO: Error handling
#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \
do { \
dim3 pp_grid; \
pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \
pp_grid.y = 1; \
pp_grid.z = 1; \
oort::bwd_preprocess<PP_BLOCK, PP_DMODEL> pre_opt; \
err = pre_opt(pp_grid, block, \
(FP*)(out_t.data_ptr()), \
(FP*)(dout_t.data_ptr()), \
(FP*)(new_do.data_ptr()), \
(float*)(delta.data_ptr()), \
stream); \
} while (0)
#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \
do { \
if (d_head == 16) \
CALL_BWD_PP(FP, PP_BLOCK, 16); \
else if (d_head == 32) \
CALL_BWD_PP(FP, PP_BLOCK, 32); \
else if (d_head == 64) \
CALL_BWD_PP(FP, PP_BLOCK, 64); \
else if (d_head == 128) \
CALL_BWD_PP(FP, PP_BLOCK, 128); \
} while (0)
if(q_dtype == at::kHalf) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__fp16, 16);
else
CALL_BWD_PP_DMODEL(__fp16, 16);
} else if (q_dtype == at::kBFloat16) {
if (seqlen_q >= 64)
CALL_BWD_PP_DMODEL(__bf16, 16);
else
CALL_BWD_PP_DMODEL(__bf16, 16);
}
#undef CALL_BWD_PP
#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \
do { \
dim3 grid; \
grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \
grid.y = batch_size * num_heads; \
grid.z = 1; \
oort::bwd_kernel_dk_dv<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dk_dv_opt; \
err = dk_dv_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \
oort::bwd_kernel_dq<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dq_opt; \
err = dq_opt(grid, block, \
(FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
(FP*)dq_t.data_ptr(), \
(float*)(softmax_lse.data_ptr()), \
(float*)(delta.data_ptr()), \
q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
(uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
} while(0)
#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \
do { \
if (p_dropout > 0.0) { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \
} else { \
CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \
} \
} while (0)
#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \
do { \
if (d_head == 16) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \
else if (d_head == 32) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \
else if (d_head == 64) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \
else if (d_head == 128) \
CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \
} while (0)
if (q_dtype == at::kHalf) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false);
}
} else if (q_dtype == at::kBFloat16) {
if (is_causal) {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true);
} else {
CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false);
}
}
//undo reorder tensors for returns
dq = dq_t.permute({0,2,1,3}).contiguous();
dk = dk_t.permute({0,2,1,3}).contiguous();
dv = dv_t.permute({0,2,1,3}).contiguous();
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
}
return { dq, dk, dv, softmax_d };
#undef CALL_BWD_DROPOUT
#undef CALL_BWD
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_right,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");
at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));
return { q, k, v, softmax_d };
}
} // namespace pytorch_fmha
#endif
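
A note on layout: `mha_fwd` above receives tensors in the public BSHD layout (batch, seqlen, heads, head_dim), permutes them with `permute({0,2,1,3}).contiguous()` into the BHSD layout the AOT Triton (`oort`) kernels expect, and permutes the outputs back before returning. A small Python illustration of that round trip (shapes chosen only to satisfy the power-of-two limitation; not part of the PR):

```python
# Illustration of the permute(0, 2, 1, 3).contiguous() round trip used by mha_fwd.
import torch

batch, seqlen, heads, head_dim = 2, 128, 8, 64  # power-of-two sizes, per the stated limits
q_bshd = torch.randn(batch, seqlen, heads, head_dim, dtype=torch.float16)

q_bhsd = q_bshd.permute(0, 2, 1, 3).contiguous()   # layout handed to oort::attn_fwd
assert q_bhsd.shape == (batch, heads, seqlen, head_dim)

q_back = q_bhsd.permute(0, 2, 1, 3).contiguous()   # the inverse, applied to the outputs
assert torch.equal(q_back, q_bshd)
```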

View File

@ -955,6 +955,7 @@ endif()
if(USE_ROCM)
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
target_link_libraries(torch_hip PRIVATE __caffe2_oort)
set(CUDA_LINK_LIBRARIES_KEYWORD)
torch_compile_options(torch_hip) # see cmake/public/utils.cmake
# TODO: Not totally sure if this is live or not
@ -1305,6 +1306,9 @@ if(USE_ROCM)
/opt/rocm/rocblas/include
/opt/rocm/hipsparse/include
)
if(USE_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
endif()
endif()
if(BUILD_LITE_INTERPRETER)

View File

@ -1291,6 +1291,7 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm") message(STATUS "Disabling Kernel Assert for ROCm")
endif() endif()
include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake)
else()
caffe2_update_option(USE_ROCM OFF)
endif()

cmake/External/oort.cmake (new vendored file, 25 lines)
View File

@ -0,0 +1,25 @@
if(NOT __OORT_INCLUDED)
set(__OORT_INCLUDED TRUE)
set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src")
set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build")
set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
ExternalProject_Add(oort_external
GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git
GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba
SOURCE_DIR ${__OORT_SOURCE_DIR}
SOURCE_SUBDIR mathaot
BINARY_DIR ${__OORT_BUILD_DIR}
PREFIX ${__OORT_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR}
# CONFIGURE_COMMAND ""
# BUILD_COMMAND ${MAKE_COMMAND}
BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a"
# INSTALL_COMMAND ${MAKE_COMMAND} install
)
set(OORT_FOUND TRUE)
add_library(__caffe2_oort INTERFACE)
add_dependencies(__caffe2_oort oort_external)
target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a)
target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include)
endif() # __OORT_INCLUDED

View File

@ -117,6 +117,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_ROCM : ${USE_ROCM}") message(STATUS " USE_ROCM : ${USE_ROCM}")
if(${USE_ROCM}) if(${USE_ROCM})
message(STATUS " ROCM_VERSION : ${ROCM_VERSION}") message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
endif()
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")

View File

@ -20,6 +20,8 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
from typing import List, Tuple, Optional
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
skipIfRocm,
TEST_FAIRSEQ,
run_tests,
parametrize,
@ -117,6 +119,18 @@ def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
return query_ref, key_ref, value_ref
def get_platform_specific_sdpa():
ret = []
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
ret.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
ret.append(SDPBackend.EFFICIENT_ATTENTION)
if not ret:
# Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize"
ret.append(SDPBackend.EFFICIENT_ATTENTION)
return ret
PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
@ -1212,6 +1226,7 @@ class TestTransformers(NNTestCase):
_ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
torch.cuda.synchronize()
@skipIfRocm # Missing EFFICIENT_ATTENTION
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
)
@ -1277,9 +1292,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
PLATFORM_SPECIFIC_SDPA,
)
def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]):
@ -1297,9 +1310,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
PLATFORM_SPECIFIC_SDPA,
)
def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]):
@ -1315,8 +1326,7 @@ class TestSDPAFailureModes(NNTestCase):
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]):
# Passing in a q,k,v with 0 length sequences will error
@ -1330,8 +1340,7 @@ class TestSDPAFailureModes(NNTestCase):
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]):
# Passing in a q,k,v with 0 length sequences will error
@ -1361,9 +1370,7 @@ class TestSDPAFailureModes(NNTestCase):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
@parametrize(
"kernel",
PLATFORM_SPECIFIC_SDPA,
)
def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]):
@ -1436,6 +1443,7 @@ class TestSDPAFailureModes(NNTestCase):
_ = torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False)
# Note: do not truncate the list according to platforms. These tests should always raise errors.
@parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend): def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
with sdp_kernel(**backend_map[kernel]): with sdp_kernel(**backend_map[kernel]):
@ -1467,7 +1475,8 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
@onlyCUDA
@skipIfRocm # Missing EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
def test_fused_kernels_nested_broadcasting_error_cases(self, device):
# one of k,v needs to be broadcasted and other has non consistent seq_len dim
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
@ -1788,6 +1797,9 @@ class TestSDPACudaOnly(NNTestCase):
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if TEST_WITH_ROCM:
return S
b, h, seqlen_q, seqlen_k = S.shape
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
@ -1954,6 +1966,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
@skipIfRocm # Missing nested and EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("type", ["dense", "nested"]) @parametrize("type", ["dense", "nested"])
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
@ -2066,6 +2079,7 @@ class TestSDPACudaOnly(NNTestCase):
# Cast up and compare
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
@skipIfRocm # Small matrices
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
@parametrize("contiguous_inputs", [True, False]) @parametrize("contiguous_inputs", [True, False])
@parametrize("is_causal", [True, False]) @parametrize("is_causal", [True, False])
@ -2118,6 +2132,7 @@ class TestSDPACudaOnly(NNTestCase):
rtol = 7e-4 if dtype == torch.float16 else 7e-3
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
@skipIfRocm # Missing nested and EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
@parametrize("type", ["dense", "nested"]) @parametrize("type", ["dense", "nested"])
def test_fused_sdp_choice(self, device, type: str): def test_fused_sdp_choice(self, device, type: str):
@ -2464,6 +2479,15 @@ class TestSDPACudaOnly(NNTestCase):
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str):
if TEST_WITH_ROCM:
def is_power_of_2(n):
return n & (n - 1) == 0
if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
if head_dim > 128:
self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")
if isSM86or89Device and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
@ -2540,7 +2564,7 @@ class TestSDPACudaOnly(NNTestCase):
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
# See [Note] Fused Tolerances above
output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
# TODO: Investigate why grad_q needs larger tolerances
@ -2559,6 +2583,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
@skipIfRocm # FIXME: "capturing stream has unjoined work"
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("batch_size", [1, 8]) @parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [256, 512, 1024]) @parametrize("seq_len_q", [256, 512, 1024])
@ -2568,7 +2593,7 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dropout_p", [0.0, 0.22]) @parametrize("dropout_p", [0.0, 0.22])
@parametrize("dtype", [torch.float16,]) @parametrize("dtype", [torch.float16,])
@parametrize("scale", [None, "l1"]) @parametrize("scale", [None, "l1"])
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]) @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, head_dim: int,
is_causal: bool, is_causal: bool,
@ -2721,6 +2746,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@ -2755,6 +2781,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
@skipIfRocm # Nested tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION]) PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@ -2878,6 +2905,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)
@onlyCUDA
@skipIfRocm # Nested tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("batch_size", [8, 32]) @parametrize("batch_size", [8, 32])
@parametrize("max_seq_len_q", [32, 256]) @parametrize("max_seq_len_q", [32, 256])
@ -3036,6 +3064,7 @@ class TestAttnMasks(NNTestCase):
torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
@skipIfRocm # No support for the second variant for now
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize( @parametrize(
"shape", "shape",
@ -3064,6 +3093,7 @@ class TestAttnMasks(NNTestCase):
self.run_test(device, False, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
@skipIfRocm # No support for the second variant for now
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT]) @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize( @parametrize(
"shape", "shape",

View File

@ -90,7 +90,14 @@ includes = [
"aten/src/ATen/native/nested/cuda/*", "aten/src/ATen/native/nested/cuda/*",
"aten/src/ATen/native/sparse/cuda/*", "aten/src/ATen/native/sparse/cuda/*",
"aten/src/ATen/native/quantized/cuda/*", "aten/src/ATen/native/quantized/cuda/*",
"aten/src/ATen/native/transformers/cuda/*", "aten/src/ATen/native/transformers/cuda/attention_backward.cu",
"aten/src/ATen/native/transformers/cuda/attention.cu",
"aten/src/ATen/native/transformers/cuda/sdp_utils.cpp",
"aten/src/ATen/native/transformers/cuda/sdp_utils.h",
"aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h",
"aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h",
"aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h",
"aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h",
"aten/src/THC/*", "aten/src/THC/*",
"aten/src/ATen/test/*", "aten/src/ATen/test/*",
# CMakeLists.txt isn't processed by default, but there are a few # CMakeLists.txt isn't processed by default, but there are a few

View File

@ -6,6 +6,7 @@ import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os
CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
@ -28,7 +29,23 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
def evaluate_gfx90a_exact():
if not torch.cuda.is_available():
return False
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
return arch == 'gfx90a:sramecc+:xnack-'
GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact())
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
return evaluate_gfx90a_exact()
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
return False
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM)
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

View File

@ -14265,6 +14265,15 @@ op_db: List[OpInfo] = [
device_type='cpu'),
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
device_type='cpu'),
# TODO: Do not work even on MI200 because of stride mismatching.
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
device_type='cuda', dtypes=[torch.float16, torch.bfloat16],
active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
device_type='cuda', dtypes=[torch.float16, torch.bfloat16],
active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_amp',
device_type='cuda', active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
# When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward',
device_type='cpu'),

View File

@ -8572,6 +8572,8 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
C10_MAPPINGS = collections.OrderedDict(
[
("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),
("CUDA_LAUNCH_BLOCKING=1", ("AMD_SERIALIZE_KERNEL=3", API_C10)),
("CUDA_LAUNCH_BLOCKING", ("AMD_SERIALIZE_KERNEL", API_C10)),
("cuda::compat::", ("hip::compat::", API_C10)), ("cuda::compat::", ("hip::compat::", API_C10)),
("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)), ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)),
("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)), ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)),