BM FM FlashAttention Test (#151974)
Reviewed By: joebos

Differential Revision: D72880307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151974
Approved by: https://github.com/yoyoyocmu, https://github.com/Skylion007, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 8542d55f0c
Commit: 9336608307
@@ -276,7 +276,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
-    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
+    TORCH_CHECK(head_size_8x <= 256, "CK FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
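For context, the round_multiple lambda visible at the end of the hunk rounds an integer up to the nearest multiple of m; head_size_8x, which the relaxed TORCH_CHECK now bounds at 256, is an 8-aligned head dimension of this kind. Below is a minimal standalone C++ sketch of that idiom and the checks in the hunk; the variable head_size and the sample values are illustrative assumptions, not the exact PyTorch source.

    #include <cassert>
    #include <initializer_list>
    #include <iostream>

    // Round x up to the nearest multiple of m (same idiom as the lambda in the diff).
    int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

    int main() {
        // Hypothetical raw head dimensions; head_size_8x is the 8-aligned size
        // that the TORCH_CHECKs in the hunk validate.
        for (int head_size : {64, 72, 100, 256}) {
            int head_size_8x = round_multiple(head_size, 8);
            assert(head_size_8x % 8 == 0);  // mirrors the "% 8 == 0" check
            assert(head_size_8x <= 256);    // mirrors the relaxed upper bound
            std::cout << head_size << " -> " << head_size_8x << "\n";
        }
        return 0;
    }

Rounding up before the bounds check means an unaligned head dimension such as 100 is padded to 104 rather than rejected; only dimensions whose 8-aligned size exceeds 256 now fail the check.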