Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Revert "Grouped Query Attention (#128898)"
This reverts commit d039b14207fe659d664c590efc06cc0a2abc96c0. Reverted https://github.com/pytorch/pytorch/pull/128898 on behalf of https://github.com/albanD due to Broken test on main ([comment](https://github.com/pytorch/pytorch/pull/128898#issuecomment-2258314481))
@@ -14709,21 +14709,21 @@
     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
   autogen: _native_multi_head_attention.out
 
-- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
   python_module: nn
   variants: function
   autogen: scaled_dot_product_attention.out
   tags: nondeterministic_seeded
 
 # This aten function is kept so that we can test the choice function from Python
-- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
+- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
   dispatch:
     Meta: _fused_sdp_choice_meta
     CPU, NestedTensorCPU: _fused_sdp_choice_cpp
     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
   variants: function
   tags: nondeterministic_seeded
 
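For context on the schema change above: the enable_gqa flag let scaled_dot_product_attention accept fewer key/value heads than query heads and broadcast them internally. A minimal usage sketch of the behavior being reverted (the shapes follow the llama3-style example removed from the docstring further down; the explicit repeat_interleave fallback mirrors the math-reference code shown later in this diff):

    import torch
    import torch.nn.functional as F

    # 32 query heads sharing 8 key/value heads (GQA-shaped inputs).
    query = torch.rand(2, 32, 16, 8)
    key = torch.rand(2, 8, 16, 8)
    value = torch.rand(2, 8, 16, 8)

    # Pre-revert API (removed by this commit):
    # out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)

    # Equivalent after the revert: expand the KV heads explicitly first.
    key = key.repeat_interleave(query.size(-3) // key.size(-3), dim=-3)
    value = value.repeat_interleave(query.size(-3) // value.size(-3), dim=-3)
    out = F.scaled_dot_product_attention(query, key, value)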
@@ -431,8 +431,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu(
 }
 
 int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
-      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
-  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
+      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
+  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
   auto backend = sdp::select_sdp_backend_cpp(kernel_params);
   if (backend == sdp::SDPBackend::error) {
     TORCH_CHECK(
@@ -456,13 +456,12 @@ int64_t _fused_sdp_choice_meta(
     const std::optional<Tensor>& attn_mask_,
     double dropout_p,
     bool is_causal,
-    std::optional<double> scale,
-    bool enable_gqa) {
+    std::optional<double> scale) {
   auto query_key_set = query_.key_set();
 #if defined(USE_ROCM)
   bool has_rocm = query_key_set.has(c10::DispatchKey::HIP);
   if (has_rocm) {
-    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
+    auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale);
     return choice_int;
   }
 #else
@@ -476,8 +475,7 @@ int64_t _fused_sdp_choice_meta(
         attn_mask_,
         dropout_p,
         is_causal,
-        scale,
-        enable_gqa);
+        scale);
     return choice_int;
   }
 #endif
@@ -610,36 +608,6 @@ bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tens
   return any_inputs_require_grad && gradmode_enabled;
 }
 
-std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
-    const at::Tensor& query,
-    const at::Tensor& key,
-    const at::Tensor& value,
-    const bool enable_gqa) {
-
-  if (!enable_gqa) {
-    return std::make_tuple(key, value);
-  }
-  const auto q_num_heads = query.sym_size(-3);
-  const auto k_num_heads = key.sym_size(-3);
-  const auto v_num_heads = value.sym_size(-3);
-
-  bool all_equal = q_num_heads == k_num_heads && k_num_heads == v_num_heads;
-  bool key_divisible = q_num_heads % k_num_heads == 0;
-  bool value_divisible = q_num_heads % v_num_heads == 0;
-  TORCH_CHECK(all_equal || (key_divisible && value_divisible),
-      "Number of heads in key and value must divide the number of heads in ");
-
-  if (all_equal){
-    return std::make_tuple(key, value);
-  }
-  auto repeat_key_shape = query.sym_size(-3) / key.sym_size(-3);
-  auto repeat_value_shape = query.sym_size(-3) / value.sym_size(-3);
-
-  at::Tensor key_repeated = key.repeat_interleave_symint(repeat_key_shape, -3);
-  at::Tensor value_repeated = value.repeat_interleave_symint(repeat_value_shape, -3);
-  return std::make_tuple(std::move(key_repeated), std::move(value_repeated));
-}
-
 } // namespace
 
 // Computes scaled dot product attention on query, key and value tensors, using
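The helper removed above expanded grouped key/value heads before the math fallback computed attention. A rough Python equivalent of that preprocessing, for reference only (the function name and the use of plain tensors are illustrative, not part of the codebase):

    import torch

    def expand_kv_heads_for_gqa(query, key, value, enable_gqa):
        # Mirrors the reverted C++ helper: repeat KV heads so they match the query heads.
        if not enable_gqa:
            return key, value
        q_heads, k_heads, v_heads = query.size(-3), key.size(-3), value.size(-3)
        if q_heads == k_heads == v_heads:
            return key, value
        # Same constraint the removed TORCH_CHECK enforced.
        assert q_heads % k_heads == 0 and q_heads % v_heads == 0, \
            "number of key/value heads must divide the number of query heads"
        key = key.repeat_interleave(q_heads // k_heads, dim=-3)
        value = value.repeat_interleave(q_heads // v_heads, dim=-3)
        return key, value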
@@ -678,13 +646,12 @@ Tensor scaled_dot_product_attention(
     const std::optional<Tensor>& attn_mask_,
     double dropout_p,
     bool is_causal,
-    std::optional<double> scale,
-    bool enable_gqa) {
+    std::optional<double> scale) {
   validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
-      query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa);
+      query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   }
   sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
   std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
@@ -746,9 +713,8 @@ Tensor scaled_dot_product_attention(
           attn_mask,
           dropout_p,
           is_causal,
-          c10::nullopt, /*dropout_mask*/
-          scale,
-          enable_gqa));
+          std::nullopt, /*dropout_mask*/
+          scale));
     default:
       TORCH_CHECK(
           false,
@@ -760,7 +726,7 @@ Tensor scaled_dot_product_attention(
 std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     const Tensor& query_, const Tensor& key, const Tensor& value,
     const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal,
-    const std::optional<Tensor>& dropout_mask, std::optional<double> scale, bool enable_gqa) {
+    const std::optional<Tensor>& dropout_mask, std::optional<double> scale) {
   C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback");
   if (query_.is_nested() || key.is_nested() || value.is_nested()) {
     TORCH_CHECK(
@@ -787,11 +753,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
       attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
     }
-
-
-    // MQA/GQA handling
-    auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key, value, enable_gqa);
-    auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
+    auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
     if (attn_mask.has_value()) {
       if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
         attn = attn.add(*attn_mask);
@@ -807,13 +769,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
         TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
         attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
         auto dropout_scaling = 1.0 / (1 - dropout_p);
-        return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling), attn);
+        return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
       } else {
         attn = at::dropout(attn, dropout_p, true);
       }
     }
 
-    return std::make_tuple(at::matmul(attn, value_expanded), attn);
+    return std::make_tuple(at::matmul(attn, value), attn);
 }
 
 std::tuple<at::Tensor, at::Tensor>
@@ -9,7 +9,7 @@ namespace at {
 namespace native {
 
 using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value,
-      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa);
+      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale);
 
 DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub);
 
@@ -868,8 +868,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
 }
 
 int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
-      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
-  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
+      const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){
+  sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal};
   auto backend = select_sdp_backend(kernel_params);
   if (backend == sdp::SDPBackend::error) {
     TORCH_CHECK(
@@ -598,7 +598,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
   }
   if (has_only_dense_inputs(params)) {
     constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
-        check_batch_size_and_num_heads_dense<true /*supports_grouped_query_attention=*/>,
+        check_batch_size_and_num_heads_dense,
         check_nonzero_sequence_lengths_dense,
         check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
     for (auto& constraint : dense_constraints) {
@@ -655,9 +655,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   }
   if (has_only_dense_inputs(params)) {
     constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
+        check_batch_size_and_num_heads_dense,
         check_nonzero_sequence_lengths_dense,
-        check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>,
-        check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention=*/>);
+        check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
     for (auto& constraint : dense_constraints) {
       if (!constraint(params, debug)) {
         return false;
@@ -42,7 +42,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
       check_nested_tensor,
       check_for_dropout,
       check_tensor_shapes,
-      check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>,
+      check_batch_size_and_num_heads_dense,
       check_attn_mask_shape,
       check_head_dim_size_cpp,
       check_nonzero_sequence_lengths_dense,
@@ -48,7 +48,6 @@ struct sdp_params {
   std::optional<at::Tensor> attn_mask;
   double dropout;
   bool is_causal;
-  bool enable_gqa;
 };
 
 SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params);
@@ -354,46 +353,6 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) {
   return true;
 }
 
-inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
-  const auto q_num_heads = params.query.sym_size(-3);
-  const auto k_num_heads = params.key.sym_size(-3);
-  const auto v_num_heads = params.value.sym_size(-3);
-  const bool same_kv_heads = k_num_heads == v_num_heads;
-
-  if (!(same_kv_heads)){
-    if (debug) {
-      TORCH_WARN(
-          "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
-          "Key sizes: ",
-          params.key.sizes(),
-          ", Value sizes: ",
-          params.value.sizes(),
-          ", Query sizes: ",
-          params.query.sizes(),
-          " instead.");
-    }
-    return false;
-  }
-  // Check if grouped query attention is supported and validate the number of
-  // heads
-  if (q_num_heads % k_num_heads != 0) {
-    if (debug) {
-      TORCH_WARN(
-          "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.",
-          "Got input Key sizes(): ",
-          params.key.sym_size(-3),
-          ", Value sizes(): ",
-          params.value.sym_size(-3),
-          ", Query sizes(): ",
-          params.query.sym_size(-3),
-          " instead.");
-    }
-    return false;
-  }
-  return true;
-}
-
-template <bool supports_gqa>
 inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
   // This is expected to be called after check_tensor_shapes ensuring that the
   // size() calls won't error since the inputs are all 4 dimensional
@@ -405,36 +364,16 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
   bool same_batch_size =
       q_batch_size == k_batch_size && q_batch_size == v_batch_size;
 
-  auto q_num_heads = params.query.sym_size(-3);
-  auto k_num_heads = params.key.sym_size(-3);
-  auto v_num_heads = params.value.sym_size(-3);
-
+  auto q_num_heads = params.query.sym_size(1);
+  auto k_num_heads = params.key.sym_size(1);
+  auto v_num_heads = params.value.sym_size(1);
   bool same_num_heads =
       q_num_heads == k_num_heads && q_num_heads == v_num_heads;
 
-  if (!same_batch_size){
-    if(debug) {
-      TORCH_WARN(
-          "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
-          "Query.sizes(): ",
-          params.query.sizes(),
-          ", Key.sizes(): ",
-          params.key.sizes(),
-          ", Value.sizes(): ",
-          params.value.sizes(),
-          " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
-    }
-    return false;
-  }
-
-  if(params.enable_gqa && supports_gqa){
-    return check_grouped_query_attention(params, debug);
-  }
-
-  if (!same_num_heads){
+  if (!(same_batch_size && same_num_heads)) {
     if (debug) {
       TORCH_WARN(
-          "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
+          "For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ",
           "Query.sizes(): ",
           params.query.sizes(),
           ", Key sizes(): ",
@@ -445,7 +384,6 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool
     }
     return false;
   }
-  // If all checks pass, return true
   return true;
 }
 
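The dense-input check above now again requires identical head counts for query, key and value (read from dim 1 of the 4-D [batch, heads, seq, head_dim] inputs), whereas the removed GQA path only required the key/value head count to divide the query head count. A small illustrative Python predicate summarizing the two policies (the function name is mine, not from the codebase):

    import torch

    def dense_heads_ok(query, key, value, enable_gqa=False):
        q_h, k_h, v_h = query.size(-3), key.size(-3), value.size(-3)
        if enable_gqa:
            # Removed policy: KV heads must match each other and divide the query heads.
            return k_h == v_h and q_h % k_h == 0
        # Reverted policy: all head counts must match exactly.
        return q_h == k_h == v_h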
@@ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1(
 }
 
 int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
-      const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){
+      const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
   auto backend = sdp::SDPBackend::overrideable;
   return static_cast<int64_t>(backend);
 }
@@ -31,7 +31,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
 
         @torch.compile(fullgraph=True, backend=counter)
         def fn(q, k, v, m):
-            return SDPAParams(q, k, v, m, 0.1, True, False)
+            return SDPAParams(q, k, v, m, 0.1, True)
 
         q = torch.randn(10)
         k = torch.randn(10)
@@ -39,7 +39,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
         m = torch.randn(10)
         o = fn(q, k, v, m)
         self.assertTrue(isinstance(o, SDPAParams))
-        self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
+        self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
         self.assertEqual(counter.frame_count, 1)
 
     def test_graph_break_SDPAParams(self):
@@ -48,7 +48,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
 
         @torch.compile(backend=counter)
         def fn(q, k, v, m):
-            z = SDPAParams(q, k, v, m, 0.1, True, False)
+            z = SDPAParams(q, k, v, m, 0.1, True)
             torch._dynamo.graph_break()
             return z, q + 1
 
@@ -58,7 +58,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
         m = torch.randn(10)
         o, _ = fn(q, k, v, m)
         self.assertTrue(isinstance(o, SDPAParams))
-        self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
+        self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True))
         self.assertEqual(counter.frame_count, 2)
 
     def test_input_SDPAParams(self):
@@ -74,7 +74,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
         k = torch.randn(10)
         v = torch.randn(10)
         m = torch.randn(10)
-        s = SDPAParams(q, k, v, m, 0.1, True, False)
+        s = SDPAParams(q, k, v, m, 0.1, True)
         o, _ = fn(s, q)
         self.assertIs(o, s)
         self.assertEqual(counter.frame_count, 1)
@@ -86,7 +86,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
         @torch.compile(fullgraph=True, backend=counter)
         def fn(q, k, v, m):
             q += 1
-            z = SDPAParams(q, k, v, m, 0.1, True, False)
+            z = SDPAParams(q, k, v, m, 0.1, True)
             a = z.query
             return a + 1, z, q
 
@@ -95,7 +95,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase):
         v = torch.randn(10)
         m = torch.randn(10)
         _, o, _ = fn(q, k, v, m)
-        expected = SDPAParams(q, k, v, m, 0.1, True, False)
+        expected = SDPAParams(q, k, v, m, 0.1, True)
         self.assert_ref_equals_params(o, expected)
         self.assertEqual(counter.frame_count, 1)
 
@@ -1561,36 +1561,6 @@ class TestSDPAFailureModes(NNTestCase):
             self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, None, 0.0, False))
 
-    @onlyCUDA
-    @skipIfRocm  # Nested Tensor
-    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
-    @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
-    def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
-        rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-        rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-        rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-
-        with sdpa_kernel(fused_kernel):
-            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
-                with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
-                                           "key and value to have"):
-                    F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
-                                                   is_causal=False, enable_gqa=True)
-
-    @onlyCPU
-    @skipIfRocm  # Nested Tensor
-    def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
-        rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-        rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-        rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
-
-        with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            with self.assertRaisesRegex(RuntimeError, "No available kernel"):
-                with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, "
-                                           "key and value to have"):
-                    F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0,
-                                                   is_causal=False, enable_gqa=True)
-
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
     @parametrize("kernel", PLATFORM_SPECIFIC_SDPA)
@@ -1742,8 +1712,7 @@ class TestSDPAFailureModes(NNTestCase):
         seq_len_list = [2, 4, 5, 6, 7]
         shape = SdpaShape(5, 8, seq_len_list, 57)
         make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
-        q, k, v = make_tensor().transpose(1, 2), make_tensor().transpose(1, 2), make_tensor().transpose(1, 2)
-
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
@@ -1823,7 +1792,7 @@ class TestSDPAFailureModes(NNTestCase):
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
             with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"):
                 with self.assertRaisesRegex(RuntimeError, "No available kernel"):
-                    torch.nn.functional.scaled_dot_product_attention(
+                    out = torch.nn.functional.scaled_dot_product_attention(
                         query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
 
     @onlyCUDA
@@ -2980,32 +2949,23 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dropout_p", [0.0, 0.22, 0.48])
     @parametrize("dtype", [torch.float16, torch.bfloat16])
     @parametrize("scale", [None, "l1"])
-    @parametrize("enable_gqa", [True, False])
-    @parametrize("n_heads", [[16, 8], [10, 2]])
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
-                                               scale: str, enable_gqa: bool, n_heads: List[int]):
+                                               scale: str):
         if isSM8XDevice and head_dim in range(193, 256 + 1):
             self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
             self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
         if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1:
             torch.cuda.empty_cache()  # Prevent memory fragmentation
         if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
             unittest.skip("Reference implementation OOM")
             return
 
         scale = scale if scale is None else (1 / head_dim)
-        num_heads_q = num_heads_kv = 4
-        if enable_gqa:
-            num_heads_q = n_heads[0]
-            num_heads_kv = n_heads[1]
-
-        query = torch.rand(batch_size, num_heads_q, seq_len_q, head_dim,
+        n_heads = 4
+        query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
                            device=device, dtype=dtype, requires_grad=True)
-        key = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, device=device,
+        key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device,
                          dtype=dtype, requires_grad=True)
-        value = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim,
+        value = torch.rand(batch_size, n_heads, seq_len_k, head_dim,
                            device=device, dtype=dtype, requires_grad=True)
 
         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32
@@ -3015,15 +2975,14 @@ class TestSDPACudaOnly(NNTestCase):
 
         if not is_dropout:
             with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-                out = F.scaled_dot_product_attention(
-                    query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
+                out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
             with sdpa_kernel(backends=[SDPBackend.MATH]):
                 # High Precision Math Reference
                 out_ref = F.scaled_dot_product_attention(
-                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
+                    query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale)
                 # Low Precision Math Reference
                 out_lp_ref = F.scaled_dot_product_attention(
-                    query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
+                    query, key, value, is_causal=is_causal, scale=scale)
         else:
             # Problem: We pad sizes in the composite region of the top level SDPA. But we need the
             # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
@@ -3050,12 +3009,11 @@ class TestSDPACudaOnly(NNTestCase):
             dropout_mask = softmax_mask >= 0
             # High Precision Math Reference
             out_ref = torch.ops.aten._scaled_dot_product_attention_math(
-                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
-                scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
+                query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
             # Low Precision Math Reference
             out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
                 query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
-                dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
+                dropout_mask=dropout_mask)[0]
 
         upstream_grad = torch.rand_like(out, requires_grad=False)
 
@@ -3075,7 +3033,7 @@ class TestSDPACudaOnly(NNTestCase):
                 'out': 1.5,
                 'grad_query': 13.0,
                 'grad_key': 2.0,
-                'grad_value': 1.75,
+                'grad_value': 1.5,
             }
         )
 
@@ -3227,7 +3185,6 @@ class TestSDPACudaOnly(NNTestCase):
             }
         )
 
-
     @skipIfRocm  # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
@@ -466,7 +466,6 @@ def gen_nn_functional(fm: FileManager) -> None:
                 "dropout_p: float = 0.0",
                 "is_causal: bool = False",
                 "scale: Optional[float] = None",
-                "enable_gqa: bool = False",
             ]
         )
     )
@@ -1956,7 +1956,6 @@ class _SDPAParams:
     attn_mask: Optional[Tensor]
     dropout: _float
     is_causal: _bool
-    enable_gqa: _bool
     def __init__(
         self,
         query: Tensor,
@@ -1964,8 +1963,7 @@ class _SDPAParams:
         value: Tensor,
         attn_mask: Optional[Tensor],
         dropout: _float,
-        is_causal: _bool,
-        enable_gqa: _bool) -> None: ...
+        is_causal: _bool) -> None: ...
 
 class _SDPBackend(Enum):
     ERROR = -1
@@ -33,9 +33,6 @@ class SDPAParamsVariable(VariableTracker):
         is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
             value.is_causal
         )
-        enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
-            value.enable_gqa
-        )
         param_vars = [
             query_var,
             key_var,
@@ -43,7 +40,6 @@ class SDPAParamsVariable(VariableTracker):
             attn_mask_var,
             dropout_var,
             is_causal_var,
-            enable_gqa_var,
         ]
         return TorchInGraphFunctionVariable(SDPAParams).call_function(
             tx, param_vars, {}
@@ -1953,24 +1953,16 @@ Call this whenever a new thread is created in order to propagate values from
              at::Tensor const& value,
              std::optional<at::Tensor> attn_mask,
              double dropout,
-             bool is_causal,
-             bool enable_gqa) {
+             bool is_causal) {
            return sdp::sdp_params{
-               query,
-               key,
-               value,
-               std::move(attn_mask),
-               dropout,
-               is_causal,
-               enable_gqa};
+               query, key, value, std::move(attn_mask), dropout, is_causal};
          }))
       .def_readonly("query", &sdp::sdp_params::query)
       .def_readonly("key", &sdp::sdp_params::key)
       .def_readonly("value", &sdp::sdp_params::value)
       .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
       .def_readonly("dropout", &sdp::sdp_params::dropout)
-      .def_readonly("is_causal", &sdp::sdp_params::is_causal)
-      .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
+      .def_readonly("is_causal", &sdp::sdp_params::is_causal);
 
   py::enum_<sdp::SDPBackend>(
       py_module,
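After this change the Python-visible SDPAParams binding again takes six constructor arguments and exposes the corresponding read-only fields. A brief illustrative check, assuming the torch.backends.cuda.SDPAParams import used by the dynamo tests in this diff:

    import torch
    from torch.backends.cuda import SDPAParams

    q = k = v = torch.randn(2, 4, 8, 16)
    # No enable_gqa argument after the revert.
    params = SDPAParams(q, k, v, None, 0.0, False)
    print(params.dropout, params.is_causal)  # fields exposed via def_readonly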
@@ -261,7 +261,7 @@ def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
     return True
 
 
-def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa):
+def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
     if (
         not flash_sdp_enabled()
         and not mem_efficient_sdp_enabled()
@@ -275,7 +275,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable
             SDPBackend.MATH,
         )
 
-    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa)
+    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
 
     for backend in ordering:
         if backend == SDPBackend.FLASH_ATTENTION:
@@ -622,7 +622,6 @@ def jagged_scaled_dot_product_attention(
     dropout_p=0.0,
     is_causal=False,
     scale=None,
-    enable_gqa=False,
 ):
     _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
     # for mypy, ugh
@@ -653,7 +652,7 @@ def jagged_scaled_dot_product_attention(
     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad
 
     backend_choice = _select_sdp_backend(
-        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
+        query, key, value, attn_mask, dropout_p, is_causal
     )
 
     if backend_choice == SDPBackend.FLASH_ATTENTION:
@@ -173,7 +173,6 @@ class CausalBias(torch.Tensor):
         dropout_p: float = 0.0,
         is_causal: bool = False,
         scale: Optional[float] = None,
-        enable_gqa: bool = False,
     ) -> torch.Tensor:
         r"""
         Handles the logic for computing attention with the specified causal bias.
@@ -190,7 +189,6 @@ class CausalBias(torch.Tensor):
                 are set.
             scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                 to :math:`\frac{1}{\sqrt{E}}`.
-            enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
 
         Returns:
             output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@@ -214,13 +212,10 @@ class CausalBias(torch.Tensor):
                 dropout_p=dropout_p,
                 is_causal=True,
                 scale=scale,
-                enable_gqa=enable_gqa,
             )
         elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
             _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
-            sdpa_params = SDPAParams(
-                query, key, value, None, dropout_p, is_causal, enable_gqa
-            )
+            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
             if can_use_flash_attention(sdpa_params):
                 needs_padding = query.size(-1) % 8 != 0
                 og_head_size = query.size(-1)
@@ -269,7 +264,6 @@ class CausalBias(torch.Tensor):
                 dropout_p=dropout_p,
                 is_causal=False,
                 scale=scale,
-                enable_gqa=enable_gqa,
             )
         else:
             raise ValueError(
@@ -5606,21 +5606,20 @@ def _in_projection(
 
 scaled_dot_product_attention = _add_docstr(
     torch._C._nn.scaled_dot_product_attention,
-    r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
-        is_causal=False, scale=None, enable_gqa=False) -> Tensor:
+    r"""
+scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor:
 
-    Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed,
-    and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be
-    specified as a keyword argument.
+Computes scaled dot product attention on query, key and value tensors, using
+an optional attention mask if passed, and applying dropout if a probability
+greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument.
 
-    .. code-block:: python
+.. code-block:: python
 
     # Efficient implementation equivalent to the following:
-    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
-            is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
+    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
         L, S = query.size(-2), key.size(-2)
         scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
-        attn_bias = torch.zeros(L, S, dtype=query.dtype)
+        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
         if is_causal:
             assert attn_mask is None
             temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@@ -5631,22 +5630,17 @@ scaled_dot_product_attention = _add_docstr(
         if attn_mask.dtype == torch.bool:
             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
         else:
-            attn_bias += attn_mask
-
-        if enable_gqa:
-            key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
-            value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
-
+            attn_bias = attn_mask + attn_bias
         attn_weight = query @ key.transpose(-2, -1) * scale_factor
         attn_weight += attn_bias
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
         return attn_weight @ value
 
-    .. warning::
-        This function is beta and subject to change.
+.. warning:: This function is beta and subject to change.
 
-    .. warning::
+
+.. warning::
     This function always applies dropout according to the specified ``dropout_p`` argument.
     To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module
     that makes the function call is not in training mode.
@@ -5661,10 +5655,9 @@ scaled_dot_product_attention = _add_docstr(
             self.p = p
 
         def forward(self, ...):
-            return F.scaled_dot_product_attention(...,
-                dropout_p=(self.p if self.training else 0.0))
+            return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0))
 
-    Note:
+Note:
 
     There are currently three supported implementations of scaled dot product attention:
 
@@ -5695,24 +5688,16 @@ scaled_dot_product_attention = _add_docstr(
         The c++ implementation supports torch.float64 and can be used when higher precision is required.
         For more information please see :doc:`/notes/numerical_accuracy`
 
-    Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
-    and math kernel on CUDA tensor, and does not support Nested tensor.
-    Constraints for GQA:
-
-        - number_of_heads_query % number_of_heads_key_value == 0 and,
-        - number_of_heads_key == number_of_heads_value
-
-    Note:
-
+Note:
     {cudnn_reproducibility_note}
-    """.format(
+""".format(
         **reproducibility_notes
     )
     + r"""
-    Args:
-        query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
-        key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`.
-        value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`.
+Args:
+    query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
+    key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
+    value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
     attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights,
         which is :math:`(N,..., L, S)`. Two types of masks are supported.
         A boolean mask where a value of True indicates that the element *should* take part in attention.
@@ -5724,21 +5709,19 @@ scaled_dot_product_attention = _add_docstr(
         An error is thrown if both attn_mask and is_causal are set.
     scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set
         to :math:`\frac{1}{\sqrt{E}}`.
-    enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
 
-    Returns:
-        output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`.
-
-    Shape legend:
+Returns:
+    output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
+
+Shape legend:
     - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}`
     - :math:`S: \text{Source sequence length}`
     - :math:`L: \text{Target sequence length}`
     - :math:`E: \text{Embedding dimension of the query and key}`
     - :math:`Ev: \text{Embedding dimension of the value}`
-    - :math:`Hq: \text{Number of heads of query}`
-    - :math:`H: \text{Number of heads of key and value}`
 
-    Examples:
+Examples:
 
     >>> # Optionally use the context manager to ensure one of the fused kernels is run
     >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
@@ -5748,21 +5731,12 @@ scaled_dot_product_attention = _add_docstr(
     >>> F.scaled_dot_product_attention(query,key,value)
 
 
-    >>> # Sample for GQA for llama3
-    >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
-    >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
-    >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
-    >>> with sdpa_kernel(backends=[SDPBackend.MATH]):
-    >>>     F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
-
-
-    .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
+.. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning:
         https://arxiv.org/abs/2307.08691
-    .. _Memory-Efficient Attention:
+.. _Memory-Efficient Attention:
         https://github.com/facebookresearch/xformers
-    .. _Grouped-Query Attention:
-        https://arxiv.org/pdf/2305.13245
-    """,
+""",
 )
 
@@ -8688,7 +8688,6 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
-    num_heads_q_gqa, num_heads_kv_gqa = 32, 8
 
     dim_3_q_shape = (batch, seq_q, head_dim)
     dim_3_kv_shape = (batch, seq_kv, head_dim)
@@ -8699,8 +8698,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
 
     qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
     samples = []
-    for qkv_shape, is_causal, dropout_p, enable_gqa in product(
-            qkv_shapes, [True, False], [0.0, 0.5], [True, False]):
+    for qkv_shape, is_causal, dropout_p in product(
+            qkv_shapes, [True, False], [0.0, 0.5]):
         shape_q, shape_kv = qkv_shape
         samples.append(SampleInput(
             make(shape_q),
@@ -8730,15 +8729,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
             dropout_p=0.0)
     )
 
-    samples.append(
-        SampleInput(
-            make((batch, num_heads_q_gqa, seq_q, head_dim)),
-            make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
-            make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
-            enable_gqa=True
-        )
-    )
-
     yield from samples
 