Grouped Query Attention (#132689)

### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new keyword argument `enable_gqa: bool` to the `scaled_dot_product_attention` (SDPA) call.
- It gives the third-to-last (head) dimension additional meaning: with GQA enabled, the number of query heads may be a multiple of the number of key/value heads.

Sample use case this would enable:
Llama 3

```python
import torch
from torch.nn.functional import scaled_dot_product_attention

# Llama 3 8B call to SDPA: 32 query heads, 8 key/value heads
batch, seq_len_q, seq_len_kv, D = 1, 2048, 2048, 128  # example sizes
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# output.shape == (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0` (and likewise for the value heads).
- If the head counts differ, the function expands the key and value tensors to match the query tensor's number of heads using `repeat_interleave` (see the sketch after this list), keeping the attention computation correct and efficient.
- By default the `enable_gqa` flag is set to `False`, which keeps regular SDPA behavior unchanged.
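
A minimal Python sketch of that expansion step, mirroring the `pre_process_group_query_attention_input` helper added on the C++ side in this PR (the Python function name below is illustrative only):

```python
import torch

def expand_kv_heads(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                    enable_gqa: bool):
    # Repeat key/value heads along dim -3 so they match the query's head count.
    if not enable_gqa:
        return key, value
    q_heads, k_heads, v_heads = query.size(-3), key.size(-3), value.size(-3)
    if q_heads == k_heads == v_heads:
        return key, value
    assert q_heads % k_heads == 0 and q_heads % v_heads == 0
    key = key.repeat_interleave(q_heads // k_heads, dim=-3)
    value = value.repeat_interleave(q_heads // v_heads, dim=-3)
    return key, value
```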

### Benchmarks:

- **sdpa.py: #130634**
  Across different batch sizes, `enable_gqa=True` shows a substantial improvement in the SDPA forward runtime (a minimal timing sketch follows this list):

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time, enable_gqa=True | forward_time, enable_gqa=False |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1  | 32 | 8 | 2048 | 2048 | 2048 | 100.71  | 119.70  |
| 8  | 32 | 8 | 2048 | 2048 | 2048 | 539.78  | 628.83  |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**
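
The kind of comparison behind the table above can be reproduced along these lines. This is a minimal sketch, not the actual `sdpa.py` benchmark script from #130634; the dtype, device, baseline construction (key/value materialized with the full 32 heads when `enable_gqa=False`), and head_dim of 64 (embed_dim 2048 split across 32 query heads) are assumptions:

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

def time_sdpa_us(batch, q_heads, kv_heads, seq_len, head_dim, enable_gqa):
    # Baseline (enable_gqa=False): key/value are materialized with the full
    # number of query heads, i.e. the pre-GQA way of calling SDPA.
    kv = kv_heads if enable_gqa else q_heads
    q = torch.rand(batch, q_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.rand(batch, kv, seq_len, head_dim, device="cuda", dtype=torch.float16)
    v = torch.rand_like(k)
    timer = Timer(
        stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)",
        globals={"F": F, "q": q, "k": k, "v": v, "enable_gqa": enable_gqa},
    )
    return timer.timeit(50).mean * 1e6  # mean forward time per call, in microseconds

for bsz in (1, 8, 16, 32):
    print(bsz, time_sdpa_us(bsz, 32, 8, 2048, 64, True), time_sdpa_us(bsz, 32, 8, 2048, 64, False))
```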

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
Author: Apurva Jain
Date: 2024-08-07 05:35:36 +00:00
Committed by: PyTorch MergeBot
Parent: 527f104a69
Commit: 8bc5ef563e
19 changed files with 372 additions and 170 deletions


@ -14711,21 +14711,21 @@
CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
autogen: _native_multi_head_attention.out
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
python_module: nn
variants: function
autogen: scaled_dot_product_attention.out
tags: nondeterministic_seeded
# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
dispatch:
Meta: _fused_sdp_choice_meta
CPU, NestedTensorCPU: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
variants: function
tags: nondeterministic_seeded


@ -430,8 +430,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
}
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
auto backend = sdp::select_sdp_backend_cpp(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(
@ -455,12 +455,13 @@ int64_t _fused_sdp_choice_meta(
const std::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale) {
std::optional<double> scale,
bool enable_gqa) {
auto query_key_set = query_.key_set();
#if defined(USE_ROCM)
bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
if (has_rocm) {
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
return choice_int;
}
#else
@ -474,7 +475,8 @@ int64_t _fused_sdp_choice_meta(
attn_mask_,
dropout_p,
is_causal,
scale);
scale,
enable_gqa);
return choice_int;
}
#endif
@ -607,6 +609,36 @@ bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tens
return any_inputs_require_grad && gradmode_enabled;
}
std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const bool enable_gqa) {
if (!enable_gqa) {
return std::make_tuple(key, value);
}
const auto q_num_heads = query.sym_size(-3);
const auto k_num_heads = key.sym_size(-3);
const auto v_num_heads = value.sym_size(-3);
bool all_equal = q_num_heads == k_num_heads && k_num_heads == v_num_heads;
bool key_divisible = q_num_heads % k_num_heads == 0;
bool value_divisible = q_num_heads % v_num_heads == 0;
TORCH_CHECK(all_equal || (key_divisible && value_divisible),
"Number of heads in key and value must divide the number of heads in ");
if (all_equal){
return std::make_tuple(key, value);
}
auto repeat_key_shape = query.sym_size(-3) / key.sym_size(-3);
auto repeat_value_shape = query.sym_size(-3) / value.sym_size(-3);
at::Tensor key_repeated = key.repeat_interleave_symint(repeat_key_shape, -3);
at::Tensor value_repeated = value.repeat_interleave_symint(repeat_value_shape, -3);
return std::make_tuple(std::move(key_repeated), std::move(value_repeated));
}
} // namespace
// Computes scaled dot product attention on query, key and value tensors, using
@ -645,12 +677,13 @@ Tensor scaled_dot_product_attention(
const std::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale) {
std::optional<double> scale,
bool enable_gqa) {
validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
choice_int = _fused_sdp_choice_stub(query_.device().type(),
query_, key, value, attn_mask_, dropout_p, is_causal, scale);
query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
}
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
@ -712,8 +745,9 @@ Tensor scaled_dot_product_attention(
attn_mask,
dropout_p,
is_causal,
std::nullopt, /*dropout_mask*/
scale));
c10::nullopt, /*dropout_mask*/
scale,
enable_gqa));
default:
TORCH_CHECK(
false,
@ -725,7 +759,7 @@ Tensor scaled_dot_product_attention(
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
const std::optional<Tensor>& dropout_mask, std::optional<double> scale, bool enable_gqa) {
C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
if (query_.is_nested() || key.is_nested() || value.is_nested()) {
TORCH_CHECK(
@ -781,7 +815,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
}
auto attn = at::matmul(query, key_acc.transpose(-2, -1) * scaling_factor);
// MQA/GQA handling
auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key_acc, value_acc, enable_gqa);
auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
if (attn_mask.has_value()) {
if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
attn = attn.add(*attn_mask);
@ -797,13 +835,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
auto dropout_scaling = 1.0 / (1 - dropout_p);
return std::make_tuple(at::matmul(attn, value_acc * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
} else {
attn = at::dropout(attn, dropout_p, true);
}
}
return std::make_tuple(at::matmul(attn, value_acc).to(origin_dtype), attn.to(origin_dtype));
return std::make_tuple(at::matmul(attn, value_expanded).to(origin_dtype), attn.to(origin_dtype));
}
std::tuple<at::Tensor, at::Tensor>


@ -9,7 +9,7 @@ namespace at {
namespace native {
using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale);
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);
DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);


@ -560,7 +560,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false};
sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false};
auto backend = select_sdp_backend(kernel_params);
// strides from packed projection for nested tensors when seq_len is 1 will be
// and will trigger a contiguous call in the kernel, so we prevent this
@ -868,8 +868,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
}
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
auto backend = select_sdp_backend(kernel_params);
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(


@ -607,7 +607,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
check_batch_size_and_num_heads_dense<true /*supports_grouped_query_attention=*/>,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
for (auto& constraint : dense_constraints) {
@ -665,9 +665,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense,
check_nonzero_sequence_lengths_dense,
check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>,
check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention=*/>);
for (auto& constraint : dense_constraints) {
if (!constraint(params, debug)) {
return false;


@ -42,7 +42,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
check_nested_tensor,
check_for_dropout,
check_tensor_shapes,
check_batch_size_and_num_heads_dense,
check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
check_attn_mask_shape,
check_head_dim_size_cpp,
check_nonzero_sequence_lengths_dense,


@ -48,6 +48,7 @@ struct sdp_params {
std::optional<at::Tensor> attn_mask;
double dropout;
bool is_causal;
bool enable_gqa;
};
SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
@ -353,6 +354,46 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
return true;
}
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
const auto q_num_heads = params.query.sym_size(-3);
const auto k_num_heads = params.key.sym_size(-3);
const auto v_num_heads = params.value.sym_size(-3);
const bool same_kv_heads = k_num_heads == v_num_heads;
if (!(same_kv_heads)){
if (debug) {
TORCH_WARN(
"Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
"Key sizes: ",
params.key.sizes(),
", Value sizes: ",
params.value.sizes(),
", Query sizes: ",
params.query.sizes(),
" instead.");
}
return false;
}
// Check if grouped query attention is supported and validate the number of
// heads
if (q_num_heads % k_num_heads != 0) {
if (debug) {
TORCH_WARN(
"FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
"Got input Key sizes(): ",
params.key.sym_size(-3),
", Value sizes(): ",
params.value.sym_size(-3),
", Query sizes(): ",
params.query.sym_size(-3),
" instead.");
}
return false;
}
return true;
}
template <bool supports_gqa>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
// This is expected to be called after check_tensor_shapes ensuring that the
// size() calls won't error since the inputs are all 4 dimensional
@ -364,16 +405,36 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
bool same_batch_size =
q_batch_size == k_batch_size && q_batch_size == v_batch_size;
auto q_num_heads = params.query.sym_size(1);
auto k_num_heads = params.key.sym_size(1);
auto v_num_heads = params.value.sym_size(1);
auto q_num_heads = params.query.sym_size(-3);
auto k_num_heads = params.key.sym_size(-3);
auto v_num_heads = params.value.sym_size(-3);
bool same_num_heads =
q_num_heads == k_num_heads && q_num_heads == v_num_heads;
if (!(same_batch_size && same_num_heads)) {
if (!same_batch_size){
if(debug) {
TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
"Query.sizes(): ",
params.query.sizes(),
", Key.sizes(): ",
params.key.sizes(),
", Value.sizes(): ",
params.value.sizes(),
" instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
}
return false;
}
if(params.enable_gqa && supports_gqa){
return check_grouped_query_attention(params, debug);
}
if (!same_num_heads){
if (debug) {
TORCH_WARN(
"For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ",
"For dense input, both fused kernels require query, key and value to have the same num_heads. ",
"Query.sizes(): ",
params.query.sizes(),
", Key sizes(): ",
@ -384,6 +445,7 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
}
return false;
}
// If all checks pass, return true
return true;
}


@ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1(
}
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}


@ -302,8 +302,9 @@ class DistMatrixOpsTest(DTensorTestBase):
# TODO: Add test cases where is_causal=False and an attention mask is provided.
# Gaps include missing op support for aten.masked_fill_.Scalar.
is_causal = True
enable_gqa = False
params = torch.backends.cuda.SDPAParams(
query, key, value, None, dropout_p, is_causal
query, key, value, None, dropout_p, is_causal, enable_gqa
)
if torch.backends.cuda.can_use_flash_attention(params, debug=False):
available_backends.append(SDPBackend.FLASH_ATTENTION)


@ -31,7 +31,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m):
return SDPAParams(q, k, v, m, 0.1, True)
return SDPAParams(q, k, v, m, 0.1, True, False)
q = torch.randn(10)
k = torch.randn(10)
@ -39,7 +39,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10)
o = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assertEqual(counter.frame_count, 1)
def test_graph_break_SDPAParams(self):
@ -48,7 +48,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(backend=counter)
def fn(q, k, v, m):
z = SDPAParams(q, k, v, m, 0.1, True)
z = SDPAParams(q, k, v, m, 0.1, True, False)
torch._dynamo.graph_break()
return z, q + 1
@ -58,7 +58,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
m = torch.randn(10)
o, _ = fn(q, k, v, m)
self.assertTrue(isinstance(o, SDPAParams))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
self.assertEqual(counter.frame_count, 2)
def test_input_SDPAParams(self):
@ -74,7 +74,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
k = torch.randn(10)
v = torch.randn(10)
m = torch.randn(10)
s = SDPAParams(q, k, v, m, 0.1, True)
s = SDPAParams(q, k, v, m, 0.1, True, False)
o, _ = fn(s, q)
self.assertIs(o, s)
self.assertEqual(counter.frame_count, 1)
@ -86,7 +86,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True, backend=counter)
def fn(q, k, v, m):
q += 1
z = SDPAParams(q, k, v, m, 0.1, True)
z = SDPAParams(q, k, v, m, 0.1, True, False)
a = z.query
return a + 1, z, q
@ -95,7 +95,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
v = torch.randn(10)
m = torch.randn(10)
_, o, _ = fn(q, k, v, m)
expected = SDPAParams(q, k, v, m, 0.1, True)
expected = SDPAParams(q, k, v, m, 0.1, True, False)
self.assert_ref_equals_params(o, expected)
self.assertEqual(counter.frame_count, 1)


@ -1561,6 +1561,36 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))
@onlyCUDA
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(fused_kernel):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCPU
@skipIfRocm # Nested Tensor
def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
"key and value to have"):
F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
is_causal=False, enable_gqa=True)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
@parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
@ -1712,7 +1742,8 @@ class TestSDPAFailureModes(NNTestCase):
seq_len_list = [2, 4, 5, 6, 7]
shape = SdpaShape(5, 8, seq_len_list, 57)
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor()
q, k, v = make_tensor().transpose(1, 2), make_tensor().transpose(1, 2), make_tensor().transpose(1, 2)
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
@ -1792,7 +1823,7 @@ class TestSDPAFailureModes(NNTestCase):
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
with self.assertRaisesRegex(RuntimeError, "No available kernel"):
out = torch.nn.functional.scaled_dot_product_attention(
torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
@onlyCUDA
@ -2949,23 +2980,32 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dropout_p", [0.0, 0.22, 0.48])
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str):
scale: str, enable_gqa: bool, n_heads: List[int]):
if isSM8XDevice and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
torch.cuda.empty_cache() # Prevent memory fragmentation
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return
scale = scale if scale is None else (1 / head_dim)
n_heads = 4
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
num_heads_q = num_heads_kv = 4
if enable_gqa:
num_heads_q = n_heads[0]
num_heads_kv = n_heads[1]
query = torch.rand(batch_size, num_heads_q, seq_len_q, head_dim,
device=device, dtype=dtype, requires_grad=True)
key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
key = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, device=device,
dtype=dtype, requires_grad=True)
value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
value = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim,
device=device, dtype=dtype, requires_grad=True)
higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
@ -2975,14 +3015,15 @@ class TestSDPACudaOnly(NNTestCase):
if not is_dropout:
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
out = F.scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
with sdpa_kernel(backends=[SDPBackend.MATH]):
# High Precision Math Reference
out_ref = F.scaled_dot_product_attention(
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
# Low Precision Math Reference
out_lp_ref = F.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, scale=scale)
query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
else:
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
@ -3009,11 +3050,12 @@ class TestSDPACudaOnly(NNTestCase):
dropout_mask = softmax_mask >= 0
# High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
# Low Precision Math Reference
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
dropout_mask=dropout_mask)[0]
dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
upstream_grad = torch.rand_like(out, requires_grad=False)
@ -3185,6 +3227,7 @@ class TestSDPACudaOnly(NNTestCase):
}
)
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if


@ -466,6 +466,7 @@ def gen_nn_functional(fm: FileManager) -> None:
"dropout_p: float = 0.0",
"is_causal: bool = False",
"scale: Optional[float] = None",
"enable_gqa: bool = False",
]
)
)


@ -1958,6 +1958,7 @@ class _SDPAParams:
attn_mask: Optional[Tensor]
dropout: _float
is_causal: _bool
enable_gqa: _bool
def __init__(
self,
query: Tensor,
@ -1965,7 +1966,8 @@ class _SDPAParams:
value: Tensor,
attn_mask: Optional[Tensor],
dropout: _float,
is_causal: _bool) -> None: ...
is_causal: _bool,
enable_gqa: _bool) -> None: ...
class _SDPBackend(Enum):
ERROR = -1


@ -34,6 +34,9 @@ class SDPAParamsVariable(VariableTracker):
is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
value.is_causal
)
enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
value.enable_gqa
)
param_vars = [
query_var,
key_var,
@ -41,6 +44,7 @@ class SDPAParamsVariable(VariableTracker):
attn_mask_var,
dropout_var,
is_causal_var,
enable_gqa_var,
]
return TorchInGraphFunctionVariable(SDPAParams).call_function(
tx, param_vars, {}


@ -1955,16 +1955,24 @@ Call this whenever a new thread is created in order to propagate values from
at::Tensor const& value,
std::optional<at::Tensor> attn_mask,
double dropout,
bool is_causal) {
bool is_causal,
bool enable_gqa) {
return sdp::sdp_params{
query, key, value, std::move(attn_mask), dropout, is_causal};
query,
key,
value,
std::move(attn_mask),
dropout,
is_causal,
enable_gqa};
}))
.def_readonly("query", &sdp::sdp_params::query)
.def_readonly("key", &sdp::sdp_params::key)
.def_readonly("value", &sdp::sdp_params::value)
.def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
.def_readonly("dropout", &sdp::sdp_params::dropout)
.def_readonly("is_causal", &sdp::sdp_params::is_causal);
.def_readonly("is_causal", &sdp::sdp_params::is_causal)
.def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
py::enum_<sdp::SDPBackend>(
py_module,


@ -262,7 +262,7 @@ def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
return True
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
if (
not flash_sdp_enabled()
and not mem_efficient_sdp_enabled()
@ -276,7 +276,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
SDPBackend.MATH,
)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
for backend in ordering:
if backend == SDPBackend.FLASH_ATTENTION:
@ -623,6 +623,7 @@ def jagged_scaled_dot_product_attention(
dropout_p=0.0,
is_causal=False,
scale=None,
enable_gqa=False,
):
_validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
# for mypy, ugh
@ -653,7 +654,7 @@ def jagged_scaled_dot_product_attention(
compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
backend_choice = _select_sdp_backend(
query, key, value, attn_mask, dropout_p, is_causal
query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
)
if backend_choice == SDPBackend.FLASH_ATTENTION:


@ -175,6 +175,7 @@ class CausalBias(torch.Tensor):
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
r"""
Handles the logic for computing attention with the specified causal bias.
@ -191,6 +192,7 @@ class CausalBias(torch.Tensor):
are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@ -214,10 +216,13 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=enable_gqa,
)
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
sdpa_params = SDPAParams(
query, key, value, None, dropout_p, is_causal, enable_gqa
)
if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0
og_head_size = query.size(-1)
@ -266,6 +271,7 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=enable_gqa,
)
else:
raise ValueError(


@ -5606,138 +5606,164 @@ def _in_projection(
scaled_dot_product_attention = _add_docstr(
torch._C._nn.scaled_dot_product_attention,
r"""
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor:
r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> Tensor:
Computes scaled dot product attention on query, key and value tensors, using
an optional attention mask if passed, and applying dropout if a probability
greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument.
.. code-block:: python
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
.. warning:: This function is beta and subject to change.
.. warning::
This function always applies dropout according to the specified ``dropout_p`` argument.
To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
that makes the function call is not in training mode.
For example:
Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
specified as a keyword argument.
.. code-block:: python
class MyModel(nn.Module):
def __init__(self, p=0.5):
super().__init__()
self.p = p
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
def forward(self, ...):
return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
Note:
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
There are currently three supported implementations of scaled dot product attention:
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
- `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
- `Memory-Efficient Attention`_
- A PyTorch implementation defined in C++ matching the above formulation
.. warning::
This function is beta and subject to change.
The function may call optimized kernels for improved performance when using the CUDA backend.
For all other backends, the PyTorch implementation will be used.
.. warning::
This function always applies dropout according to the specified ``dropout_p`` argument.
To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
that makes the function call is not in training mode.
All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation
is used, the following functions are provided for enabling and disabling implementations.
The context manager is the preferred mechanism:
For example:
- :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
- :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
- :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention.
- :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.
.. code-block:: python
Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
In the event that a fused implementation is not available, a warning will be raised with the
reasons why the fused implementation cannot run.
class MyModel(nn.Module):
def __init__(self, p=0.5):
super().__init__()
self.p = p
Due to the nature of fusing floating point operations, the output of this function may be different
depending on what backend kernel is chosen.
The c++ implementation supports torch.float64 and can be used when higher precision is required.
For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
def forward(self, ...):
return F.scaled_dot_product_attention(...,
dropout_p=(self.p if self.training else 0.0))
Note:
There are currently three supported implementations of scaled dot product attention:
- `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_
- `Memory-Efficient Attention`_
- A PyTorch implementation defined in C++ matching the above formulation
The function may call optimized kernels for improved performance when using the CUDA backend.
For all other backends, the PyTorch implementation will be used.
All implementations are enabled by default. Scaled dot product attention attempts to automatically select the
most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation
is used, the following functions are provided for enabling and disabling implementations.
The context manager is the preferred mechanism:
- :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations.
- :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention.
- :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables Memory-Efficient Attention.
- :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables the PyTorch C++ implementation.
Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation,
disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`.
In the event that a fused implementation is not available, a warning will be raised with the
reasons why the fused implementation cannot run.
Due to the nature of fusing floating point operations, the output of this function may be different
depending on what backend kernel is chosen.
The c++ implementation supports torch.float64 and can be used when higher precision is required.
For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
For more information please see :doc:`/notes/numerical_accuracy`
Note:
{cudnn_reproducibility_note}
""".format(
Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
and math kernel on CUDA tensor, and does not support Nested tensor.
Constraints for GQA:
- number_of_heads_query % number_of_heads_key_value == 0 and,
- number_of_heads_key == number_of_heads_value
Note:
{cudnn_reproducibility_note}
""".format(
**reproducibility_notes
)
+ r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
square matrix. The attention masking has the form of the upper left causal bias due to the alignment
(see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
which is :math:`(N,..., L, S)`. Two types of masks are supported.
A boolean mask where a value of True indicates that the element *should* take part in attention.
A float mask of the same type as query, key, value that is added to the attention score.
dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied
is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a
square matrix. The attention masking has the form of the upper left causal bias due to the alignment
(see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix.
An error is thrown if both attn_mask and is_causal are set.
scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`
- :math:`Hq: \text{Number of heads of query}`
- :math:`H: \text{Number of heads of key and value}`
Examples:
>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with torch.backends.cuda.sdp_kernel(enable_math=False):
>>> F.scaled_dot_product_attention(query,key,value)
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
Shape legend:
- :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
- :math:`S: \text{Source sequence length}`
- :math:`L: \text{Target sequence length}`
- :math:`E: \text{Embedding dimension of the query and key}`
- :math:`Ev: \text{Embedding dimension of the value}`
Examples:
>>> # Optionally use the context manager to ensure one of the fused kernels is run
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with torch.backends.cuda.sdp_kernel(enable_math=False):
>>> F.scaled_dot_product_attention(query,key,value)
>>> # Sample for GQA for llama3
>>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with sdpa_kernel(backends=[SDPBackend.MATH]):
>>> F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
https://arxiv.org/abs/2307.08691
.. _Memory-Efficient Attention:
https://github.com/facebookresearch/xformers
""",
.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
https://arxiv.org/abs/2307.08691
.. _Memory-Efficient Attention:
https://github.com/facebookresearch/xformers
.. _Grouped-Query Attention:
https://arxiv.org/pdf/2305.13245
""",
)


@ -8689,6 +8689,7 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
num_heads_q_gqa, num_heads_kv_gqa = 32, 8
dim_3_q_shape = (batch, seq_q, head_dim)
dim_3_kv_shape = (batch, seq_kv, head_dim)
@ -8699,8 +8700,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
samples = []
for qkv_shape, is_causal, dropout_p in product(
qkv_shapes, [True, False], [0.0, 0.5]):
for qkv_shape, is_causal, dropout_p, enable_gqa in product(
qkv_shapes, [True, False], [0.0, 0.5], [True, False]):
shape_q, shape_kv = qkv_shape
samples.append(SampleInput(
make(shape_q),
@ -8730,6 +8731,15 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
dropout_p=0.0)
)
samples.append(
SampleInput(
make((batch, num_heads_q_gqa, seq_q, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
enable_gqa=True
)
)
yield from samples