Update FlashAttention to work with sm90 Gpus (#97051)
# Summary
FlashAttention was confirmed to work on H100 and sm90 hardware, so we update the checks to account for this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97051
Approved by: https://github.com/cpuhrsch
committed by PyTorch MergeBot
parent 37cde56658
commit 90537a779c
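As a reading aid for the hunks below, here is a minimal, self-contained sketch of the architecture gate this change extends, written against the plain CUDA runtime API rather than ATen; the helper name is hypothetical and not part of the patch.

```cpp
#include <cuda_runtime.h>

// Hypothetical helper: true if the current device's compute capability is one
// FlashAttention accepts after this patch (sm75, sm8x, or sm90).
bool flash_attention_arch_supported(int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;  // treat a failed query as "unsupported"
  }
  const bool is_sm75 = prop.major == 7 && prop.minor == 5;
  const bool is_sm8x = prop.major == 8 && prop.minor >= 0;  // A100, sm86, ...
  const bool is_sm90 = prop.major == 9 && prop.minor == 0;  // H100
  return is_sm90 || is_sm8x || is_sm75;
}
```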
@@ -214,7 +214,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -369,14 +370,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto launch = &run_fmha_bwd;

     auto stream = at::cuda::getCurrentCUDAStream().stream();

     auto q_dtype = q.dtype();

-    TORCH_CHECK(q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16));
+    TORCH_CHECK(q_dtype == at::kHalf || ((is_sm8x || is_sm90) && q_dtype == at::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);
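The dtype check in the hunk above now admits bf16 on sm90 as well as sm8x. A hedged sketch of that gate as a standalone helper (the function name and error message are illustrative, not from the patch):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Illustrative helper: fp16 is accepted on every supported arch; bf16 only on
// sm8x/sm90, mirroring the TORCH_CHECK in mha_bwd above.
void check_flash_input_dtype(const at::Tensor& q, bool is_sm8x, bool is_sm90) {
  auto q_dtype = q.dtype();
  TORCH_CHECK(
      q_dtype == at::kHalf || ((is_sm8x || is_sm90) && q_dtype == at::kBFloat16),
      "FlashAttention expects fp16 inputs, or bf16 on sm8x/sm90 devices");
}
```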
@@ -417,7 +419,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }

     CHECK_SHAPE(q, total_q, num_heads, head_size);
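The head-dimension constraint above is unchanged except that d > 64 is now accepted on sm90 in addition to sm80. A small sketch with a hypothetical helper name:

```cpp
#include <c10/util/Exception.h>

// Illustrative helper mirroring the check in mha_bwd: head_size must be a
// multiple of 8 and at most 128, and d > 64 is only supported on sm80/sm90.
void check_flash_head_size(int head_size, bool is_sm80, bool is_sm90) {
  TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
  if (head_size > 64) {
    // Per the TODO in the original code, SM86/SM70 with d=128 may come later.
    TORCH_CHECK(is_sm80 || is_sm90);
  }
}
```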
@@ -11,20 +11,24 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
         } else if (params.seqlen_k >= 256) {
-            if (dprops->major == 8 && dprops->minor == 0) {
-                // Don't share smem for K & V, and don't keep V in registers
-                // This speeds things up by 2-3% by avoiding register spills, but it
-                // uses more shared memory, which is fine on A100 but not other GPUs.
-                // For other GPUs, we keep V in registers.
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-            } else if (dprops->major == 8 && dprops->minor > 0) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-            } else if (dprops->major == 7 && dprops->minor == 5) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-            }
+            if ((dprops->major == 8 && dprops->minor == 0) ||
+                (dprops->major == 9 && dprops->minor == 0)) {
+                // Don't share smem for K & V, and don't keep V in registers
+                // This speeds things up by 2-3% by avoiding register spills, but it
+                // uses more shared memory, which is fine on A100 and H100 but not other
+                // GPUs. For other GPUs, we keep V in registers.
+                using Kernel_traits =
+                    FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+            } else if (dprops->major == 8 && dprops->minor > 0) {
+                using Kernel_traits =
+                    FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+            } else if (dprops->major == 7 && dprops->minor == 5) {
+                using Kernel_traits =
+                    FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+            }
         }
     }));
 }
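The backward-kernel dispatch above selects a larger shared-memory configuration on A100 and, after this patch, H100, and a smaller one elsewhere: the extra smem avoids register spills but does not fit on other GPUs. Below is a minimal sketch of that dispatch shape only, with stand-in types, since the real FMHA_kernel_traits parameters are not spelled out here.

```cpp
#include <cuda_runtime.h>

// Stand-in for FMHA_kernel_traits: only the smem-related flag is modeled.
template <unsigned kSmemFlag>
struct BwdKernelConfig {
  static constexpr unsigned smem_flag = kSmemFlag;
};

// Stand-in for run_fmha_bwd_loop<Kernel_traits>(params, stream, configure).
template <typename Config>
void run_bwd_loop(cudaStream_t /*stream*/) { /* launch the kernel for Config */ }

void dispatch_bwd_hdim64(const cudaDeviceProp* dprops, cudaStream_t stream) {
  const bool big_smem_ok = (dprops->major == 8 && dprops->minor == 0) ||  // A100
                           (dprops->major == 9 && dprops->minor == 0);    // H100
  if (big_smem_ok) {
    // Don't share smem for K & V, and don't keep V in registers.
    run_bwd_loop<BwdKernelConfig<0x100u>>(stream);
  } else {
    // Other GPUs: keep V in registers to stay within the smem budget.
    run_bwd_loop<BwdKernelConfig<0x08u>>(stream);
  }
}
```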
@@ -492,10 +492,11 @@ inline bool check_gpu_sm75_or_greater(sdp_params params, bool debug) {
   auto dprops = at::cuda::getCurrentDeviceProperties();
   bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
   bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-  if (!(is_sm8x || is_sm75)) {
+  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+  if (!(is_sm90 || is_sm8x || is_sm75)) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention only supports sm75 and sm8x gpu architectures. Attempting to run on a sm ",
+          "Flash attention only supports {sm75, sm8x, sm90} gpu architectures. Attempting to run on a sm ",
           dprops->major,
           ".",
           dprops->minor,
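Finally, the SDP dispatcher check above only warns (in debug mode) and reports the architecture as unsupported rather than throwing. A hedged sketch of that check-and-warn pattern with a simplified signature:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>

// Simplified stand-in for check_gpu_sm75_or_greater after the patch: return
// false (optionally warning) when the device is not sm75, sm8x, or sm90.
bool gpu_supports_flash_attention(bool debug) {
  auto dprops = at::cuda::getCurrentDeviceProperties();
  const bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  const bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  const bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  if (!(is_sm90 || is_sm8x || is_sm75)) {
    if (debug) {
      TORCH_WARN(
          "Flash attention only supports {sm75, sm8x, sm90} gpu architectures. ",
          "Attempting to run on a sm ", dprops->major, ".", dprops->minor, " gpu.");
    }
    return false;
  }
  return true;
}
```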