Initial Flash Attention support on ROCM (#114309)
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency that provides the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and is not shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- [ ] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
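Below is a minimal usage sketch (not part of this PR) of how the new backend can be exercised once PyTorch is built for ROCm with flash attention enabled; the shapes and dtype are chosen to satisfy the limitations listed above (power-of-two sequence length, head dimension in {16, 32, 64, 128}, fp16/bf16 inputs).

```python
import torch
from torch.backends.cuda import sdp_kernel

# Hypothetical example: on ROCm the HIP device is still exposed as "cuda".
# Shapes follow the stated limits: power-of-two seqlen, head_dim in {16, 32, 64, 128}.
q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the flash-attention backend so a silent fallback to the
# math path does not mask a configuration problem.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 256, 64])
```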
@@ -735,10 +735,21 @@ endif()
 include(cmake/Dependencies.cmake)
 
 # Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not avaialble until now
+# TODO: Merge this into cmake_dependent_option as "NOT MSVC AND (USE_CUDA OR USE_ROCM)"
+# once cmake_minimum_required is bumped to 3.22
+# See https://cmake.org/cmake/help/latest/policy/CMP0127.html for the feature required here.
+if(MSVC)
+  set(CONFIG_FA OFF)
+elseif(USE_ROCM OR USE_CUDA)
+  set(CONFIG_FA ON)
+else()
+  set(CONFIG_FA OFF)
+endif()
 cmake_dependent_option(
     USE_FLASH_ATTENTION
     "Whether to build the flash_attention kernel for scaled dot product attention" ON
-    "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+    "CONFIG_FA" OFF)
 
 # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
 cmake_dependent_option(
@@ -164,6 +164,10 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
 file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
 file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
 
+# flash_attention sources
+file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
+file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
+
 #Mem_eff attention sources
 file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
 file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu")
@@ -175,6 +179,9 @@ if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
   list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu})
   list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})
+
+  list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
+  list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
 endif()
 
 if(USE_MEM_EFF_ATTENTION)
@@ -284,10 +291,34 @@ endif()
 
 if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+  list(APPEND ATen_HIP_SRCS
+    ${ATen_HIP_SRCS}
+    ${hip_hip}
+    ${native_hip_hip}
+    ${native_nested_hip_hip}
+    ${native_sparse_hip_hip}
+    ${native_quantized_hip_hip}
+    ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
+  )
   # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
-  set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${native_quantized_cudnn_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
-  set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
+  list(APPEND all_hip_cpp
+    ${native_nested_hip_cpp}
+    ${native_sparse_hip_cpp}
+    ${native_quantized_hip_cpp}
+    ${native_transformers_hip_cpp}
+    ${native_quantized_cudnn_hip_cpp}
+    ${hip_cpp}
+    ${native_hip_cpp}
+    ${native_hip_linalg_cpp}
+    ${cuda_generated_sources}
+    ${ATen_HIP_SRCS}
+    ${native_miopen_cpp}
+    ${native_cudnn_hip_cpp}
+    ${miopen_cpp}
+    ${all_hip_cpp}
+  )
 endif()
 
 list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/..)
@@ -445,6 +445,13 @@ int64_t _fused_sdp_choice_meta(
     bool is_causal,
     c10::optional<double> scale) {
   auto query_key_set = query_.key_set();
+#if defined(USE_ROCM)
+  bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
+  if (has_rocm) {
+    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
+    return choice_int;
+  }
+#else
   bool has_cuda = query_key_set.has(c10::DispatchKey::CUDA);
   if (has_cuda) {
     auto choice_int = _fused_sdp_choice_stub(
@@ -458,6 +465,7 @@ int64_t _fused_sdp_choice_meta(
         scale);
     return choice_int;
   }
+#endif
   return static_cast<int64_t>(sdp::SDPBackend::math);
 }
 namespace {
@@ -625,7 +633,8 @@ Tensor scaled_dot_product_attention(
   validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA
-      || query_.device().type() == DeviceType::CPU){
+      || query_.device().type() == DeviceType::CPU
+      || query_.device().type() == DeviceType::HIP){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
           query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   }
@@ -14,6 +14,7 @@
 #include <c10/util/Exception.h>
 #include <c10/util/env.h>
 #include <c10/util/irange.h>
+#include <c10/util/CallOnce.h>
 
 #include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
@@ -176,11 +177,42 @@ bool check_sm_version(cudaDeviceProp * dprops) {
   return is_gte_lower_bound && is_lte_upper_bound;
 }
 
+#if USE_ROCM
+c10::once_flag gcn_arch_override_flag;
+const char* over_arch = nullptr;
+
+void init_gcn_arch_override() {
+  over_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
+  if (over_arch) {
+    TORCH_WARN("SDPA functions only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
+               "Later changes to this environment variable with os.environ "
+               "(or other methods) will not affect SDPA function's behavior.");
+  }
+}
+#endif
+
 bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
   // Check that the gpu is capable of running flash attention
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
   auto dprops = at::cuda::getCurrentDeviceProperties();
+#if USE_ROCM
+  constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
+  const char* real_arch = dprops->gcnArchName;
+  c10::call_once(gcn_arch_override_flag, init_gcn_arch_override);
+  const char* arch = over_arch ? over_arch : real_arch;
+  if (mi200 != arch) {
+    if (debug) {
+      TORCH_WARN(
+          "Flash attention only supports gpu architecture gfx90a, for now. Attempting to run on a ",
+          arch,
+          ".",
+          over_arch ? " This is overrided by PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE. Real architecture is " : "",
+          over_arch ? real_arch : "");
+    }
+    return false;
+  }
+#else
   if (!check_sm_version<sm80, sm90>(dprops)) {
     if (debug) {
       TORCH_WARN(
@@ -192,6 +224,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
     }
     return false;
   }
+#endif
   return true;
 }
 
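A small sketch (an assumption, not taken from the diff) of how the debug override described by the warning above would be used from Python; since the value is read only once per process via `c10::call_once`, it has to be in place before the first SDPA call.

```python
import os

# Must be set before the first scaled_dot_product_attention call in the process;
# later changes to os.environ have no effect because the value is cached once.
os.environ["PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE"] = "gfx90a:sramecc+:xnack-"

import torch

q = k = v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
# With the override set, the hardware check treats the device as an MI200 even if
# torch.cuda.get_device_properties(0).gcnArchName reports a different architecture.
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```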
aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip (new file, 651 lines)
@@ -0,0 +1,651 @@
/******************************************************************************
 * Copyright (c) 2023, Advanced Micro Devices, Inc.
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <cstdint>
#include <tuple>

#include <ATen/ops/zeros.h>

#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/reshape.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h>
#endif

#include <ATen/native/transformers/hip/flash_attn/flash_api.h>

#include <c10/util/Exception.h>
#include <c10/util/CallOnce.h>

// OORT headers
#include <oort/attn_fwd.h>
#include <oort/bwd_kernel_dk_dv.h>
#include <oort/bwd_kernel_dq.h>
#include <oort/bwd_preprocess.h>

namespace pytorch_flash {

namespace {

c10::once_flag fa_gcn_arch_override_flag;
const char* fa_override_arch = nullptr;

void init_fa_override_arch() {
  fa_override_arch = std::getenv("PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE");
  if (fa_override_arch) {
    TORCH_WARN("ROCM flash attention backend only loads value from PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE once. "
               "Later changes to this environment variable with os.environ "
               "(or other methods) will not affect this backend's behavior.");
  }
}

void check_gpu_arch() {
  auto dprops = at::cuda::getCurrentDeviceProperties();

  constexpr std::string_view mi200 = "gfx90a:sramecc+:xnack-";
  c10::call_once(fa_gcn_arch_override_flag, init_fa_override_arch);
  if (fa_override_arch) {
    TORCH_CHECK(mi200 == fa_override_arch,
                "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName) + " override as " + fa_override_arch);
  } else {
    TORCH_CHECK(mi200 == dprops->gcnArchName,
                "FlashAttention only supports MI200/MI250 GPUs (gfx90a:sramecc+:xnack-), current gcnArchName: " + std::string(dprops->gcnArchName));
  }
}

}

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q,             // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,             // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,             // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        const int window_size_left,
        int window_size_right,
        const bool return_softmax,
        c10::optional<at::Generator> gen_) {
  check_gpu_arch();

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

  // FIXME: ROCM probably does not need this
  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  int seqlen_q = sizes[1];
  int num_heads = sizes[2];
  const int head_size_og = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);
  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
  if (is_causal) { window_size_right = 0; }

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

  at::Tensor q_padded, k_padded, v_padded;
  q_padded = q;
  k_padded = k;
  v_padded = v;

  at::Tensor out;
  if (out_.has_value()) {
    out = out_.value();
    TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
    CHECK_DEVICE(out);
    TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
    if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
  } else {
    out = at::empty_like(q_padded);
  }

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size = round_multiple(head_size_og, 8);
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  // We want to checkpoint and save the RNG state for backward if dropout
  // We get the default generator and return the seed and offset which will
  // be used in the backward function
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
  at::Tensor seed_t, offset_t;

  if (p_dropout > 0.0)  {
    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = batch_size * num_heads * 32;
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
    if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
      auto [seed, offset] = at::cuda::philox::unpack(philox_state);
      seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
      offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
    } else {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    }
  } else {
    if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
    } else {
      seed_t = at::empty({}, at::dtype(at::kLong));
      offset_t = at::empty({}, at::dtype(at::kLong));
    }

  }

  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  //reorder tensors and make contiguous
  at::Tensor q_t = q_padded.permute({0,2,1,3}).contiguous();
  at::Tensor k_t = k_padded.permute({0,2,1,3}).contiguous();
  at::Tensor v_t = v_padded.permute({0,2,1,3}).contiguous();
  at::Tensor output_t = out.permute({0,2,1,3}).contiguous();

  at::Tensor M = at::empty({batch_size, num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse

  constexpr int BLOCK_M = 16;
  constexpr int BLOCK_N = 16;
  dim3 grid;
  grid.x = (q_t.sizes()[2] + BLOCK_M - 1) / BLOCK_M;
  grid.y = q_t.sizes()[0] * q_t.sizes()[1];
  grid.z = 1;
  dim3 block { 64 * 4, 1, 1 }; // compiled triton kernel intrinsic

  at::Tensor softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k},
                                      at::dtype(q.dtype()).device(q.device()));

  hipError_t err; // TODO: Error handling
#define CALL_FWD(FP, STAGE, BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX) \
  do { \
    oort::attn_fwd<STAGE,BLOCK_M, BLOCK_DMODEL, BLOCK_N, pre_load_v, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX> fwd_opt; \
    err = fwd_opt(grid, block, \
                  (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
                  softmax_scale, (float*)M.data_ptr(), (FP*)output_t.data_ptr(), \
                  q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
                  k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
                  v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
                  output_t.stride(0), output_t.stride(1), output_t.stride(2), output_t.stride(3), \
                  q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
                  *(uint64_t*)(seed_t.data_ptr()), *(uint32_t*)(offset_t.data_ptr()), \
                  (FP*)(softmax_fa_t.data_ptr()), \
                  stream); \
  } while(0)

  // TODO: Ugly but works
  constexpr int kFwdUseCausal = 3;
  constexpr int kFwdNoCausal = 1;
  int d_head = q_t.sizes()[3];
  constexpr int BM = BLOCK_M;
  constexpr int BN = BLOCK_N;
  if (q_dtype == at::kHalf) {
    if (is_causal) {
      if (d_head == 16)
        CALL_FWD(__fp16,kFwdUseCausal,BM,16,BN,true,true,true);
      else if (d_head == 32)
        CALL_FWD(__fp16,kFwdUseCausal,BM,32,BN,true,true,true);
      else if (d_head == 64)
        CALL_FWD(__fp16,kFwdUseCausal,BM,64,BN,true,true,true);
      else if (d_head == 128)
        CALL_FWD(__fp16,kFwdUseCausal,BM,128,BN,true,true,true);
    } else {
      if (d_head == 16)
        CALL_FWD(__fp16,kFwdNoCausal,BM,16,BN,true,true,true);
      else if (d_head == 32)
        CALL_FWD(__fp16,kFwdNoCausal,BM,32,BN,true,true,true);
      else if (d_head == 64)
        CALL_FWD(__fp16,kFwdNoCausal,BM,64,BN,true,true,true);
      else if (d_head == 128)
        CALL_FWD(__fp16,kFwdNoCausal,BM,128,BN,true,true,true);
    }
  } else if (q_dtype == at::kBFloat16) {
    if (is_causal) {
      if (d_head == 16)
        CALL_FWD(__bf16,kFwdUseCausal,BM,16,BN,true,true,true);
      else if (d_head == 32)
        CALL_FWD(__bf16,kFwdUseCausal,BM,32,BN,true,true,true);
      else if (d_head == 64)
        CALL_FWD(__bf16,kFwdUseCausal,BM,64,BN,true,true,true);
      else if (d_head == 128)
        CALL_FWD(__bf16,kFwdUseCausal,BM,128,BN,true,true,true);
    } else {
      if (d_head == 16)
        CALL_FWD(__bf16,kFwdNoCausal,BM,16,BN,true,true,true);
      else if (d_head == 32)
        CALL_FWD(__bf16,kFwdNoCausal,BM,32,BN,true,true,true);
      else if (d_head == 64)
        CALL_FWD(__bf16,kFwdNoCausal,BM,64,BN,true,true,true);
      else if (d_head == 128)
        CALL_FWD(__bf16,kFwdNoCausal,BM,128,BN,true,true,true);
    }
  }

  //undo reorder tensors
  q_padded = q_t.permute({0,2,1,3}).contiguous();
  k_padded = k_t.permute({0,2,1,3}).contiguous();
  v_padded = v_t.permute({0,2,1,3}).contiguous();
  out = output_t.permute({0,2,1,3}).contiguous();

  return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t};
#undef CALL_FWD
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               const int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               const int window_size_left,
               int window_size_right,
               const bool return_softmax,
               c10::optional<at::Generator> gen_) {

  TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");

  at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
  at::Tensor p = at::empty({}, at::dtype(at::kFloat));
  at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
  at::Tensor out = at::empty({}, at::dtype(at::kFloat));

  return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_size_og
        const at::Tensor &q,     // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,   // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        const float p_dropout,           // probability to drop
        const float softmax_scale,
        const bool is_causal,
        const int window_size_left,
        int window_size_right,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset) {
  check_gpu_arch();

  bool is_dropout = p_dropout > 0.0;
  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

  auto q_dtype = q.dtype();
  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
              "FlashAttention only support fp16 and bf16 data type");
  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

  const auto sizes = q.sizes();

  const int batch_size = sizes[0];
  const int seqlen_q = sizes[1];
  const int num_heads = sizes[2];
  const int head_size_og = dout.size(3);
  const int head_size = sizes[3];
  const int seqlen_k = k.size(1);
  const int num_heads_k = k.size(2);

  if (is_causal){
    TORCH_CHECK((seqlen_q == seqlen_k), "For backwards kernel seqlen_q must equal seqlen_k for causal kernels");
  }

  TORCH_CHECK(batch_size > 0, "batch size must be positive");
  TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
  TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_rounded = round_multiple(head_size, 32);
  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

  TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

  at::Tensor dq, dk, dv;
  if (dq_.has_value()) {
    dq = dq_.value();
    TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
    CHECK_DEVICE(dq);
    TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
    CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  } else {
    dq = at::empty_like(q);
  }
  if (dk_.has_value()) {
    dk = dk_.value();
    TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
    CHECK_DEVICE(dk);
    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
    CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dk = at::empty_like(k);
  }
  if (dv_.has_value()) {
    dv = dv_.value();
    TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
    CHECK_DEVICE(dv);
    TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
    CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  } else {
    dv = at::empty_like(k);
  }

  // const at::Tensor& dout_padded = dout;

  // bool loop = seqlen_k > blocksize_c;
  // TODO: change later, for now set to true for simplicity
  bool loop = true;

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};

  auto opts = q.options();
  auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  at::Tensor dq_accum;
  at::Tensor dk_accum, dv_accum;
  if (loop) {
    dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
    // dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    // dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  }

  at::Tensor dk_expanded, dv_expanded;
  if (num_heads_k != num_heads) {  // MQA / GQA
    dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  } else {
    dk_expanded = dk;
    dv_expanded = dv;
  }

  at::PhiloxCudaState philox_args;
  if (is_dropout) {
    if (at::cuda::currentStreamCaptureStatus() ==
        at::cuda::CaptureStatus::None)
    {
      philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
    } else { // dropout + capture
      philox_args = at::PhiloxCudaState(
          philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
    }
  }

  //JCG TODO WE GO IN HERE TODO backwards
  //reorder tensors and make contiguous
  at::Tensor q_t = q.permute({0,2,1,3}).contiguous();
  at::Tensor k_t = k.permute({0,2,1,3}).contiguous();
  at::Tensor v_t = v.permute({0,2,1,3}).contiguous();
  at::Tensor out_t = out.permute({0,2,1,3}).contiguous();

  //reorder tensors and make contiguous
  at::Tensor dq_t = dq.permute({0,2,1,3}).contiguous();
  at::Tensor dk_t = dk.permute({0,2,1,3}).contiguous();
  at::Tensor dv_t = dv.permute({0,2,1,3}).contiguous();
  at::Tensor dout_t = dout.permute({0,2,1,3}).contiguous();

  dim3 block { 64 * 4, 1, 1 };

  at::Tensor new_do = at::empty_like(dout_t).contiguous();
  at::Tensor delta = at::empty_like(softmax_lse).contiguous();

  int d_head = head_size_og;
  hipError_t err; // TODO: Error handling
#define CALL_BWD_PP(FP, PP_BLOCK, PP_DMODEL) \
  do { \
    dim3 pp_grid; \
    pp_grid.x = batch_size * num_heads * ((dout_t.size(2) + PP_BLOCK - 1) / PP_BLOCK); \
    pp_grid.y = 1; \
    pp_grid.z = 1; \
    oort::bwd_preprocess<PP_BLOCK, PP_DMODEL> pre_opt; \
    err = pre_opt(pp_grid, block, \
                  (FP*)(out_t.data_ptr()), \
                  (FP*)(dout_t.data_ptr()), \
                  (FP*)(new_do.data_ptr()), \
                  (float*)(delta.data_ptr()), \
                  stream); \
  } while (0)

#define CALL_BWD_PP_DMODEL(FP, PP_BLOCK) \
  do { \
    if (d_head == 16) \
      CALL_BWD_PP(FP, PP_BLOCK, 16); \
    else if (d_head == 32) \
      CALL_BWD_PP(FP, PP_BLOCK, 32); \
    else if (d_head == 64) \
      CALL_BWD_PP(FP, PP_BLOCK, 64); \
    else if (d_head == 128) \
      CALL_BWD_PP(FP, PP_BLOCK, 128); \
  } while (0)

  if(q_dtype == at::kHalf) {
    if (seqlen_q >= 64)
      CALL_BWD_PP_DMODEL(__fp16, 16);
    else
      CALL_BWD_PP_DMODEL(__fp16, 16);
  } else if (q_dtype == at::kBFloat16) {
    if (seqlen_q >= 64)
      CALL_BWD_PP_DMODEL(__bf16, 16);
    else
      CALL_BWD_PP_DMODEL(__bf16, 16);
  }
#undef CALL_BWD_PP

#define CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT) \
  do { \
    dim3 grid; \
    grid.x = (seqlen_k + BLOCK_M - 1) / BLOCK_M; \
    grid.y = batch_size * num_heads; \
    grid.z = 1; \
    oort::bwd_kernel_dk_dv<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dk_dv_opt; \
    err = dk_dv_opt(grid, block, \
                    (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
                    softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
                    (FP*)dk_t.data_ptr(),(FP*)dv_t.data_ptr(), \
                    (float*)(softmax_lse.data_ptr()), \
                    (float*)(delta.data_ptr()), \
                    q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
                    k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
                    v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
                    q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
                    (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
    grid.x = (seqlen_q + BLOCK_M - 1) / BLOCK_M; \
    oort::bwd_kernel_dq<BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, ENABLE_DROPOUT> dq_opt; \
    err = dq_opt(grid, block, \
                 (FP*)(q_t.data_ptr()), (FP*)(k_t.data_ptr()), (FP*)(v_t.data_ptr()), \
                 softmax_scale, (FP*)out_t.data_ptr(), (FP*)dout_t.data_ptr(), \
                 (FP*)dq_t.data_ptr(), \
                 (float*)(softmax_lse.data_ptr()), \
                 (float*)(delta.data_ptr()), \
                 q_t.stride(0), q_t.stride(1), q_t.stride(2), q_t.stride(3), \
                 k_t.stride(0), k_t.stride(1), k_t.stride(2), k_t.stride(3), \
                 v_t.stride(0), v_t.stride(1), v_t.stride(2), v_t.stride(3), \
                 q_t.sizes()[0], q_t.sizes()[1], seqlen_q, seqlen_k, p_dropout, \
                 (uint64_t)(philox_args.seed_.val), (uint32_t)(philox_args.offset_.val), stream); \
  } while(0)

#define CALL_BWD_DROPOUT(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL) \
  do { \
    if (p_dropout > 0.0) { \
      CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, true); \
    } else { \
      CALL_BWD(FP, BLOCK_M, BLOCK_DMODEL, BLOCK_N, CAUSAL, false); \
    } \
  } while (0)

#define CALL_BWD_DROPOUT_DMODEL(FP, BLOCK_M, BLOCK_N, CAUSAL) \
  do { \
    if (d_head == 16) \
      CALL_BWD_DROPOUT(FP, BLOCK_M, 16, BLOCK_N, CAUSAL); \
    else if (d_head == 32) \
      CALL_BWD_DROPOUT(FP, BLOCK_M, 32, BLOCK_N, CAUSAL); \
    else if (d_head == 64) \
      CALL_BWD_DROPOUT(FP, BLOCK_M, 64, BLOCK_N, CAUSAL); \
    else if (d_head == 128) \
      CALL_BWD_DROPOUT(FP, BLOCK_M, 128, BLOCK_N, CAUSAL); \
  } while (0)

  if (q_dtype == at::kHalf) {
    if (is_causal) {
      CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, true);
    } else {
      CALL_BWD_DROPOUT_DMODEL(__fp16, 16, 16, false);
    }
  } else if (q_dtype == at::kBFloat16) {
    if (is_causal) {
      CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, true);
    } else {
      CALL_BWD_DROPOUT_DMODEL(__bf16, 16, 16, false);
    }
  }

  //undo reorder tensors for returns
  dq = dq_t.permute({0,2,1,3}).contiguous();
  dk = dk_t.permute({0,2,1,3}).contiguous();
  dv = dv_t.permute({0,2,1,3}).contiguous();

  // For MQA/GQA we need to sum dK and dV across the groups
  if (num_heads_k != num_heads) {
    at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  }
  return { dq, dk, dv, softmax_d };
#undef CALL_BWD_DROPOUT
#undef CALL_BWD
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
               const at::Tensor &q,     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,  // b x h x s   softmax logsumexp
               c10::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float p_dropout,           // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               const int window_size_left,
               int window_size_right,
               const at::Tensor philox_seed,
               const at::Tensor philox_offset) {
  TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");

  at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));

  return { q, k, v, softmax_d };
}
} // namespace pytorch_fmha

#endif
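The backward path above (`bwd_preprocess`, `bwd_kernel_dk_dv`, `bwd_kernel_dq`) is reached through ordinary autograd. A speculative sketch, assuming an MI200 device and a flash-attention build; it is not taken from the PR:

```python
import torch
from torch.backends.cuda import sdp_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=0.1, is_causal=True)
    # Backward runs mha_bwd, which launches the dk/dv and dq kernels defined above.
    out.sum().backward()

print(q.grad.shape, k.grad.shape, v.grad.shape)
```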
@@ -955,6 +955,7 @@ endif()
 if(USE_ROCM)
   set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
   hip_add_library(torch_hip ${Caffe2_HIP_SRCS})
+  target_link_libraries(torch_hip PRIVATE __caffe2_oort)
   set(CUDA_LINK_LIBRARIES_KEYWORD)
   torch_compile_options(torch_hip)  # see cmake/public/utils.cmake
   # TODO: Not totally sure if this is live or not
@@ -1305,6 +1306,9 @@ if(USE_ROCM)
     /opt/rocm/rocblas/include
     /opt/rocm/hipsparse/include
     )
+  if(USE_FLASH_ATTENTION)
+    target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
+  endif()
 endif()
 
 if(BUILD_LITE_INTERPRETER)
@@ -1291,6 +1291,7 @@ if(USE_ROCM)
     message(STATUS "Disabling Kernel Assert for ROCm")
   endif()
 
+  include(${CMAKE_CURRENT_LIST_DIR}/External/oort.cmake)
 else()
   caffe2_update_option(USE_ROCM OFF)
 endif()
cmake/External/oort.cmake (new vendored file, 25 lines)
@@ -0,0 +1,25 @@
if(NOT __OORT_INCLUDED)
  set(__OORT_INCLUDED TRUE)

  set(__OORT_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/src")
  set(__OORT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/oort/build")
  set(__OORT_INSTALL_DIR "${PROJECT_SOURCE_DIR}/torch")
  ExternalProject_Add(oort_external
    GIT_REPOSITORY https://github.com/ROCmSoftwarePlatform/triton.git
    GIT_TAG 29e1252c1ac8e6a54deb883701e553e5b201a1ba
    SOURCE_DIR ${__OORT_SOURCE_DIR}
    SOURCE_SUBDIR mathaot
    BINARY_DIR ${__OORT_BUILD_DIR}
    PREFIX ${__OORT_INSTALL_DIR}
    CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__OORT_INSTALL_DIR}
    # CONFIGURE_COMMAND ""
    # BUILD_COMMAND ${MAKE_COMMAND}
    BUILD_BYPRODUCTS "${__OORT_INSTALL_DIR}/lib/liboort.a"
    # INSTALL_COMMAND ${MAKE_COMMAND} install
  )
  set(OORT_FOUND TRUE)
  add_library(__caffe2_oort INTERFACE)
  add_dependencies(__caffe2_oort oort_external)
  target_link_libraries(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/lib/liboort.a)
  target_include_directories(__caffe2_oort INTERFACE ${__OORT_INSTALL_DIR}/include)
endif() # __OORT_INCLUDED
@@ -117,6 +117,7 @@ function(caffe2_print_configuration_summary)
   message(STATUS " USE_ROCM : ${USE_ROCM}")
   if(${USE_ROCM})
     message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
+    message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
   endif()
   message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
   message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
@@ -20,6 +20,8 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
 from typing import List, Tuple, Optional
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
+    TEST_WITH_ROCM,
+    skipIfRocm,
     TEST_FAIRSEQ,
     run_tests,
     parametrize,
@@ -117,6 +119,18 @@ def query_key_value_clones(query: torch.Tensor, key: torch.Tensor, value: torch.
     value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
     return query_ref, key_ref, value_ref
 
+
+def get_platform_specific_sdpa():
+    ret = []
+    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
+        ret.append(SDPBackend.FLASH_ATTENTION)
+    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
+        ret.append(SDPBackend.EFFICIENT_ATTENTION)
+    if not ret:
+        # Add a placeholder, an empty list causes "An empty arg_values was passed to @parametrize"
+        ret.append(SDPBackend.EFFICIENT_ATTENTION)
+    return ret
+
+PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
+
 def rand_sdpa_tensor(shape: SdpaShape, device: str, dtype: torch.dtype, type: str,
                      requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
@@ -1212,6 +1226,7 @@ class TestTransformers(NNTestCase):
             _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
             torch.cuda.synchronize()
 
+    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -1277,9 +1292,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     @parametrize(
         "kernel",
-        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
-        if PLATFORM_SUPPORTS_FLASH_ATTENTION
-        else [SDPBackend.EFFICIENT_ATTENTION],
+        PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
@@ -1297,9 +1310,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
     @parametrize(
         "kernel",
-        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
-        if PLATFORM_SUPPORTS_FLASH_ATTENTION
-        else [SDPBackend.EFFICIENT_ATTENTION],
+        PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
@@ -1315,8 +1326,7 @@ class TestSDPAFailureModes(NNTestCase):
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
-    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
-                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
     def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
             # Passing in a q,k,v with 0 length sequences will error
@@ -1330,8 +1340,7 @@ class TestSDPAFailureModes(NNTestCase):
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
-    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
-                 PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
     def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
             # Passing in a q,k,v with 0 length sequences will error
@@ -1361,9 +1370,7 @@ class TestSDPAFailureModes(NNTestCase):
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
    @parametrize(
         "kernel",
-        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
-        if PLATFORM_SUPPORTS_FLASH_ATTENTION
-        else [SDPBackend.EFFICIENT_ATTENTION],
+        PLATFORM_SPECIFIC_SDPA,
     )
     def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
@@ -1436,6 +1443,7 @@ class TestSDPAFailureModes(NNTestCase):
             _ = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, None, 0.0, False)
 
+    # Note: do not truncate the list according to platforms. These tests should always raise errors.
     @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
     def test_invalid_inputs_different_datatypes(self, device, kernel: SDPBackend):
         with sdp_kernel(**backend_map[kernel]):
@@ -1467,7 +1475,8 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
+    @skipIfRocm  # Missing EFFICIENT_ATTENTION
+    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     def test_fused_kernels_nested_broadcasting_error_cases(self, device):
         # one of k,v needs to be broadcasted and other has non consistent seq_len dim
         rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
@ -1788,6 +1797,9 @@ class TestSDPACudaOnly(NNTestCase):
|
|||||||
query_padding_mask: (batch_size, seqlen_q)
|
query_padding_mask: (batch_size, seqlen_q)
|
||||||
key_padding_mask: (batch_size, seqlen_k)
|
key_padding_mask: (batch_size, seqlen_k)
|
||||||
"""
|
"""
|
||||||
|
if TEST_WITH_ROCM:
|
||||||
|
return S
|
||||||
|
|
||||||
b, h, seqlen_q, seqlen_k = S.shape
|
b, h, seqlen_q, seqlen_k = S.shape
|
||||||
warps_n = 4
|
warps_n = 4
|
||||||
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
|
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
|
||||||
@@ -1954,6 +1966,7 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)

+    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if

@@ -2066,6 +2079,7 @@ class TestSDPACudaOnly(NNTestCase):
         # Cast up and compare
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)

+    @skipIfRocm  # Small matrices
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
     @parametrize("contiguous_inputs", [True, False])
     @parametrize("is_causal", [True, False])

@@ -2118,6 +2132,7 @@ class TestSDPACudaOnly(NNTestCase):
         rtol = 7e-4 if dtype == torch.float16 else 7e-3
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)

+    @skipIfRocm  # Missing nested and EFFICIENT_ATTENTION
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
     @parametrize("type", ["dense", "nested"])
     def test_fused_sdp_choice(self, device, type: str):
@@ -2464,6 +2479,15 @@ class TestSDPACudaOnly(NNTestCase):
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str):
+        if TEST_WITH_ROCM:
+            def is_power_of_2(n):
+                return n & (n - 1) == 0
+            if not is_power_of_2(seq_len_q) or not is_power_of_2(seq_len_k) or not is_power_of_2(head_dim):
+                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q seq_len_k headdim, for now.")
+            if head_dim < 16 or seq_len_q < 16 or seq_len_k < 16:
+                self.skipTest("Flash attention on ROCM only supports power of two seq_len_q, seq_len_k, headdim >= 16, for now.")
+            if head_dim > 128:
+                self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")

         if isSM86or89Device and head_dim in range(193, 256 + 1):
             self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
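The ROCm branch above skips configurations the current kernel cannot handle: non power-of-two sizes, dimensions below 16, and head_dim above 128, matching the known limitations listed in the PR description. A small standalone sketch of the same gate, with a hypothetical helper name that is not part of the diff, shows which parametrizations survive:

# Hypothetical helper mirroring the ROCm skip logic above; illustration only.
def rocm_flash_config_supported(seq_len_q, seq_len_k, head_dim):
    def is_power_of_2(n):
        return n > 0 and (n & (n - 1)) == 0
    if not (is_power_of_2(seq_len_q) and is_power_of_2(seq_len_k) and is_power_of_2(head_dim)):
        return False                       # non power-of-two sequence length or head_dim
    if min(seq_len_q, seq_len_k, head_dim) < 16:
        return False                       # too small for the Triton kernel
    return head_dim <= 128                 # head dimensions above 128 are not supported yet

print(rocm_flash_config_supported(1024, 1024, 64))    # True: runs on MI200
print(rocm_flash_config_supported(1024, 1024, 256))   # False: skipped (head_dim > 128)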
@@ -2540,7 +2564,7 @@ class TestSDPACudaOnly(NNTestCase):
         out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))

         # See [Note] Fused Tolerances above
-        output_fudge_factor = 3 if head_dim % 8 != 0 else 1
+        output_fudge_factor = 3 if head_dim % 8 != 0 or TEST_WITH_ROCM else 1
         output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)

         # TODO: Investigate why grad_q needs larger tolerances

@@ -2559,6 +2583,7 @@ class TestSDPACudaOnly(NNTestCase):
         self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                          atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

+    @skipIfRocm  # FIXME: "capturing stream has unjoined work"
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [256, 512, 1024])

@@ -2568,7 +2593,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dropout_p", [0.0, 0.22])
     @parametrize("dtype", [torch.float16,])
     @parametrize("scale", [None, "l1"])
-    @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
                                                          is_causal: bool,

@@ -2721,6 +2746,7 @@ class TestSDPACudaOnly(NNTestCase):
         self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                          atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)

+    @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                  PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@@ -2755,6 +2781,7 @@ class TestSDPACudaOnly(NNTestCase):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)


+    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                  PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])

@@ -2878,6 +2905,7 @@ class TestSDPACudaOnly(NNTestCase):
         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)

     @onlyCUDA
+    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
     @parametrize("batch_size", [8, 32])
     @parametrize("max_seq_len_q", [32, 256])

@@ -3036,6 +3064,7 @@ class TestAttnMasks(NNTestCase):
         torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
         torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)

+    @skipIfRocm  # No support for the second variant for now
     @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
     @parametrize(
         "shape",

@@ -3064,6 +3093,7 @@ class TestAttnMasks(NNTestCase):

         self.run_test(device, False, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)

+    @skipIfRocm  # No support for the second variant for now
     @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
     @parametrize(
         "shape",
@@ -90,7 +90,14 @@ includes = [
     "aten/src/ATen/native/nested/cuda/*",
     "aten/src/ATen/native/sparse/cuda/*",
     "aten/src/ATen/native/quantized/cuda/*",
-    "aten/src/ATen/native/transformers/cuda/*",
+    "aten/src/ATen/native/transformers/cuda/attention_backward.cu",
+    "aten/src/ATen/native/transformers/cuda/attention.cu",
+    "aten/src/ATen/native/transformers/cuda/sdp_utils.cpp",
+    "aten/src/ATen/native/transformers/cuda/sdp_utils.h",
+    "aten/src/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h",
+    "aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h",
+    "aten/src/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h",
+    "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h",
     "aten/src/THC/*",
     "aten/src/ATen/test/*",
     # CMakeLists.txt isn't processed by default, but there are a few
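The blanket native/transformers/cuda/* pattern is replaced by an explicit file list, which keeps the CUDA flash-attention kernel sources out of hipification while still converting the attention entry points and sdp_utils. A rough illustration of the effect, assuming the include patterns are matched glob-style as fnmatch does (the kernel path below is a hypothetical file name, not taken from the PR):

import fnmatch

candidates = [
    "aten/src/ATen/native/transformers/cuda/attention.cu",
    "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h",
    "aten/src/ATen/native/transformers/cuda/flash_attn/kernels/example_kernel.cu",  # hypothetical
]
old_includes = ["aten/src/ATen/native/transformers/cuda/*"]
new_includes = [
    "aten/src/ATen/native/transformers/cuda/attention.cu",
    "aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h",
]

def matched(path, patterns):
    return any(fnmatch.fnmatch(path, p) for p in patterns)

print([p for p in candidates if matched(p, old_includes)])  # old glob pulls in everything, kernels included
print([p for p in candidates if matched(p, new_includes)])  # explicit list converts only the named files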
@@ -6,6 +6,7 @@ import torch.cuda
 from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
 import inspect
 import contextlib
+import os


 CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()
@@ -28,7 +29,23 @@ SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
 SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
 SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))

-PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and (not TEST_WITH_ROCM) and (not IS_WINDOWS) and SM80OrLater)
+def evaluate_gfx90a_exact():
+    if not torch.cuda.is_available():
+        return False
+    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
+    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
+    return arch == 'gfx90a:sramecc+:xnack-'
+
+GFX90A_Exact = LazyVal(lambda: evaluate_gfx90a_exact())
+
+def evaluate_platform_supports_flash_attention():
+    if TEST_WITH_ROCM:
+        return evaluate_gfx90a_exact()
+    if TEST_CUDA:
+        return not IS_WINDOWS and SM80OrLater
+    return False
+
+PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
 PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM)
 # This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
 PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
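The new gate keys off the exact MI200 gcnArchName string and can be forced through the PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE environment variable, which is convenient when experimenting on other hardware. A hedged usage sketch for local debugging only; it only flips the test-side gate, not any runtime kernel checks:

# Pretend the current ROCm device is an MI200 so the test-side flash-attention
# gate evaluates to True. Set the variable before the flag is first evaluated.
import os
os.environ['PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE'] = 'gfx90a:sramecc+:xnack-'

from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
print(bool(PLATFORM_SUPPORTS_FLASH_ATTENTION))   # True on a ROCm build with a visible GPU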
@@ -14265,6 +14265,15 @@ op_db: List[OpInfo] = [
                            device_type='cpu'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
                            device_type='cpu'),
+              # TODO: Do not work even on MI200 because of stride mismatching.
+              DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace',
+                           device_type='cuda', dtypes=[torch.float16, torch.bfloat16],
+                           active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
+              DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace',
+                           device_type='cuda', dtypes=[torch.float16, torch.bfloat16],
+                           active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
+              DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_amp',
+                           device_type='cuda', active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION),
               # When changing input from Tensor to CompositeCompliantTensor, input.requires_grad() changes from true to false
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward',
                            device_type='cpu'),
@@ -8572,6 +8572,8 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
 C10_MAPPINGS = collections.OrderedDict(
     [
         ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),
+        ("CUDA_LAUNCH_BLOCKING=1", ("AMD_SERIALIZE_KERNEL=3", API_C10)),
+        ("CUDA_LAUNCH_BLOCKING", ("AMD_SERIALIZE_KERNEL", API_C10)),
         ("cuda::compat::", ("hip::compat::", API_C10)),
         ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)),
         ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)),
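The two new entries translate the CUDA debugging knob into its ROCm counterpart, both when it appears with an explicit value and when it appears bare. A rough, order-sensitive illustration of how such a string-mapping table behaves; the real hipify matcher is more sophisticated (it does not just loop over str.replace), so treat this purely as a sketch of why the more specific "=1" entry is listed alongside the bare name:

from collections import OrderedDict

# Simplified stand-in for the two C10_MAPPINGS entries above.
MAPPINGS = OrderedDict([
    ("CUDA_LAUNCH_BLOCKING=1", "AMD_SERIALIZE_KERNEL=3"),
    ("CUDA_LAUNCH_BLOCKING", "AMD_SERIALIZE_KERNEL"),
])

def hipify_string(text):
    # Naive sequential substitution: the longer, more specific key is applied
    # first so the generic key cannot rewrite it into the wrong value.
    for cuda, hip in MAPPINGS.items():
        text = text.replace(cuda, hip)
    return text

print(hipify_string('setenv("CUDA_LAUNCH_BLOCKING=1");'))   # -> setenv("AMD_SERIALIZE_KERNEL=3");
print(hipify_string('getenv("CUDA_LAUNCH_BLOCKING");'))     # -> getenv("AMD_SERIALIZE_KERNEL");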