Update FlashAttention to work with sm90 Gpus (#97051)

# Summary
FlashAttention has been confirmed to work on H100 / sm90 hardware, so we update the architecture checks to account for this; a minimal sketch of the resulting gate follows the commit metadata below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97051
Approved by: https://github.com/cpuhrsch
Author: Driss Guessous
Authored: 2023-03-19 19:33:57 +00:00
Committed by: PyTorch MergeBot
Parent: 37cde56658
Commit: 90537a779c
3 changed files with 27 additions and 20 deletions
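
For context, every check touched by this patch reduces to the same compute-capability gate on the current device. Below is a minimal, hypothetical sketch of that gate written against the raw CUDA runtime API rather than ATen's `at::cuda::getCurrentDeviceProperties()`; the helper name `flash_attention_arch_supported` is ours, not PyTorch's.

```cpp
// Hypothetical sketch of the compute-capability gate added by this patch,
// using the plain CUDA runtime instead of ATen's device-property helper.
#include <cstdio>
#include <cuda_runtime.h>

bool flash_attention_arch_supported(const cudaDeviceProp& p) {
  const bool is_sm75 = p.major == 7 && p.minor == 5;  // Turing (e.g. T4)
  const bool is_sm8x = p.major == 8 && p.minor >= 0;  // Ampere / Ada
  const bool is_sm90 = p.major == 9 && p.minor == 0;  // Hopper (H100)
  return is_sm90 || is_sm8x || is_sm75;
}

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::printf("no CUDA device found\n");
    return 1;
  }
  std::printf("sm%d%d -> FlashAttention %s\n", prop.major, prop.minor,
              flash_attention_arch_supported(prop) ? "supported" : "unsupported");
  return 0;
}
```

Compiled with `nvcc` and run on the target machine, this reports whether the patched checks would accept the device.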


@@ -214,7 +214,8 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -369,14 +370,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto launch = &run_fmha_bwd;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16));
+    TORCH_CHECK(q_dtype == at::kHalf || ((is_sm8x || is_sm90) && q_dtype == at::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);
@@ -417,7 +419,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     if (head_size > 64) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }
     CHECK_SHAPE(q, total_q, num_heads, head_size);


@@ -11,20 +11,24 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
       using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
       run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
     } else if (params.seqlen_k >= 256) {
-      if (dprops->major == 8 && dprops->minor == 0) {
-        // Don't share smem for K & V, and don't keep V in registers
-        // This speeds things up by 2-3% by avoiding register spills, but it
-        // uses more shared memory, which is fine on A100 but not other GPUs.
-        // For other GPUs, we keep V in registers.
-        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-      } else if (dprops->major == 8 && dprops->minor > 0) {
-        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-      } else if (dprops->major == 7 && dprops->minor == 5) {
-        using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
-      }
+      if ((dprops->major == 8 && dprops->minor == 0) ||
+          (dprops->major == 9 && dprops->minor == 0)) {
+        // Don't share smem for K & V, and don't keep V in registers
+        // This speeds things up by 2-3% by avoiding register spills, but it
+        // uses more shared memory, which is fine on A100 and H100 but not other
+        // GPUs. For other GPUs, we keep V in registers.
+        using Kernel_traits =
+            FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
+        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+      } else if (dprops->major == 8 && dprops->minor > 0) {
+        using Kernel_traits =
+            FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
+        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+      } else if (dprops->major == 7 && dprops->minor == 5) {
+        using Kernel_traits =
+            FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
+        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+      }
     }
   }));
 }
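
The comment in the hunk above is a shared-memory tradeoff: the larger-smem kernel traits (the `0x100u` variant) are only chosen when `minor == 0`, i.e. on A100 (sm80) and now H100 (sm90), which expose more shared memory per SM than the other parts these checks exclude. A small hypothetical probe of those limits, not part of the patch:

```cpp
// Hypothetical probe, not part of the patch: report the shared-memory limits
// that make the larger-smem kernel traits viable on A100/H100 but not on
// other sm8x parts.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::printf("no CUDA device found\n");
    return 1;
  }
  std::printf("sm%d%d: %zu bytes shared memory per SM, %zu bytes opt-in per block\n",
              prop.major, prop.minor,
              prop.sharedMemPerMultiprocessor, prop.sharedMemPerBlockOptin);
  return 0;
}
```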


@@ -492,10 +492,11 @@ inline bool check_gpu_sm75_or_greater(sdp_params params, bool debug) {
   auto dprops = at::cuda::getCurrentDeviceProperties();
   bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
   bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-  if (!(is_sm8x || is_sm75)) {
+  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+  if (!(is_sm90 || is_sm8x || is_sm75)) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention only supports sm75 and sm8x gpu architectures. Attempting to run on a sm ",
+          "Flash attention only supports {sm75, sm8x, sm90} gpu architectures. Attempting to run on a sm ",
           dprops->major,
           ".",
           dprops->minor,