BM FM FlashAttention Test (#151974)

Reviewed By: joebos

Differential Revision: D72880307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151974
Approved by: https://github.com/yoyoyocmu, https://github.com/Skylion007, https://github.com/malfet
Author: Chenye Zhao
Date: 2025-04-25 19:24:22 +00:00
Committed by: PyTorch MergeBot
Parent: 8542d55f0c
Commit: 9336608307

@@ -276,7 +276,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
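
For reference, the round_multiple lambda in the surrounding context rounds a dimension up to the next multiple of m, which is how the 8-aligned head size checked against the new <= 256 limit is produced. A minimal standalone sketch of that rounding behavior (the sample values below are illustrative only and are not taken from the patch):

#include <cassert>

// Round x up to the nearest multiple of m, mirroring the lambda in mha_bwd_ck.
static int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main() {
    // A head dimension of 250 pads to 256, which passes the relaxed check;
    // 257 would pad to 264 and still be rejected. 128 is already aligned.
    assert(round_multiple(250, 8) == 256);
    assert(round_multiple(257, 8) == 264);
    assert(round_multiple(128, 8) == 128);
    return 0;
}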