mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Grouped Query Attention (#128898)
### Approach: Using the current function declaration **Constraint:** Q_Heads % KV_Heads == 0 **Major change:** - Added a new argument enable_gqa: bool to the sdpa function call - It gives meaning to the third-to-last dimension (the number of heads). Sample use cases this would enable: Llama 3 ``` # LLama3 8b call to SDPA query = torch.rand(batch, 32, seq_len_q, D) key = torch.rand(batch, 8, seq_len_kv, D) value = torch.rand(batch, 8, seq_len_kv, D) output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True) # Output Shape (batch, 32, seq_len_q, D) ``` ### Design Choice: - Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or Query.size(-3) % Key.size(-3) == 0 - The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their numbers of heads are not equal, facilitating correct and efficient computation in attention mechanisms. - By default, the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged. ### Benchmarks: - **sdpa.py: #130634** For different batch sizes enable_gqa=True shows a substantial improvement in the run time of sdpa | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False | | ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- | | 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 | | 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 | | 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 | | 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |  - **TorchTitan: https://github.com/pytorch/torchtitan/pull/458** Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898 Approved by: https://github.com/drisspg
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							05a8540041
						
					
				
				
					commit
					d039b14207
				
			| @ -14709,21 +14709,21 @@ | ||||
|     CUDA, NestedTensorCUDA: native_multi_head_attention_cuda | ||||
|   autogen: _native_multi_head_attention.out | ||||
|  | ||||
| - func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor | ||||
| - func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor | ||||
|   python_module: nn | ||||
|   variants: function | ||||
|   autogen: scaled_dot_product_attention.out | ||||
|   tags: nondeterministic_seeded | ||||
|  | ||||
| # This aten function is kept so that we can test the choice function from Python | ||||
| - func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int | ||||
| - func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int | ||||
|   dispatch: | ||||
|     Meta: _fused_sdp_choice_meta | ||||
|     CPU, NestedTensorCPU: _fused_sdp_choice_cpp | ||||
|     CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda | ||||
|   tags: nondeterministic_seeded | ||||
|  | ||||
| - func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor) | ||||
| - func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) | ||||
|   variants: function | ||||
|   tags: nondeterministic_seeded | ||||
|  | ||||
|  | ||||
| @ -431,8 +431,8 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cpu( | ||||
| } | ||||
|  | ||||
| int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){ | ||||
|   sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal}; | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){ | ||||
|   sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; | ||||
|   auto backend = sdp::select_sdp_backend_cpp(kernel_params); | ||||
|   if (backend == sdp::SDPBackend::error) { | ||||
|     TORCH_CHECK( | ||||
| @ -456,12 +456,13 @@ int64_t _fused_sdp_choice_meta( | ||||
|     const std::optional<Tensor>& attn_mask_, | ||||
|     double dropout_p, | ||||
|     bool is_causal, | ||||
|     std::optional<double> scale) { | ||||
|     std::optional<double> scale, | ||||
|     bool enable_gqa) { | ||||
|   auto query_key_set = query_.key_set(); | ||||
| #if defined(USE_ROCM) | ||||
|   bool has_rocm = query_key_set.has(c10::DispatchKey::HIP); | ||||
|   if (has_rocm) { | ||||
|     auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale); | ||||
|     auto choice_int = _fused_sdp_choice_stub(at::kHIP, query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa); | ||||
|     return choice_int; | ||||
|   } | ||||
| #else | ||||
| @ -475,7 +476,8 @@ int64_t _fused_sdp_choice_meta( | ||||
|         attn_mask_, | ||||
|         dropout_p, | ||||
|         is_causal, | ||||
|         scale); | ||||
|         scale, | ||||
|         enable_gqa); | ||||
|     return choice_int; | ||||
|   } | ||||
| #endif | ||||
| @ -608,6 +610,36 @@ bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tens | ||||
|   return any_inputs_require_grad && gradmode_enabled; | ||||
| } | ||||
|  | ||||
| // Prepares key/value for grouped-query attention (GQA) on the math path. | ||||
| // | ||||
| // Returns key and value unchanged when `enable_gqa` is false, or when query, | ||||
| // key and value already agree on the size of dim -3 (the number of heads). | ||||
| // Otherwise key and value are each repeated along dim -3 with | ||||
| // repeat_interleave so that their head count matches query's. | ||||
| // | ||||
| // Raises (via TORCH_CHECK) when the head count of key or value is zero or | ||||
| // does not evenly divide the head count of query. | ||||
| std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input( | ||||
|     const at::Tensor& query, | ||||
|     const at::Tensor& key, | ||||
|     const at::Tensor& value, | ||||
|     const bool enable_gqa) { | ||||
|   if (!enable_gqa) { | ||||
|     return std::make_tuple(key, value); | ||||
|   } | ||||
|   const auto q_num_heads = query.sym_size(-3); | ||||
|   const auto k_num_heads = key.sym_size(-3); | ||||
|   const auto v_num_heads = value.sym_size(-3); | ||||
|  | ||||
|   // Equal head counts need no expansion. Returning before the divisibility | ||||
|   // check also keeps the modulo below from running when every count is zero. | ||||
|   const bool all_equal = q_num_heads == k_num_heads && k_num_heads == v_num_heads; | ||||
|   if (all_equal) { | ||||
|     return std::make_tuple(key, value); | ||||
|   } | ||||
|  | ||||
|   // Short-circuit on the non-zero test so the modulo never sees a zero divisor. | ||||
|   const bool kv_heads_nonzero = !(k_num_heads == 0) && !(v_num_heads == 0); | ||||
|   TORCH_CHECK( | ||||
|       kv_heads_nonzero && q_num_heads % k_num_heads == 0 && | ||||
|           q_num_heads % v_num_heads == 0, | ||||
|       "Number of heads in key and value must divide the number of heads in query"); | ||||
|  | ||||
|   // Each key/value head is shared by this many consecutive query heads. | ||||
|   const auto key_repeats = q_num_heads / k_num_heads; | ||||
|   const auto value_repeats = q_num_heads / v_num_heads; | ||||
|  | ||||
|   at::Tensor key_repeated = key.repeat_interleave_symint(key_repeats, -3); | ||||
|   at::Tensor value_repeated = value.repeat_interleave_symint(value_repeats, -3); | ||||
|   return std::make_tuple(std::move(key_repeated), std::move(value_repeated)); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| // Computes scaled dot product attention on query, key and value tensors, using | ||||
| @ -646,12 +678,13 @@ Tensor scaled_dot_product_attention( | ||||
|     const std::optional<Tensor>& attn_mask_, | ||||
|     double dropout_p, | ||||
|     bool is_causal, | ||||
|     std::optional<double> scale) { | ||||
|     std::optional<double> scale, | ||||
|     bool enable_gqa) { | ||||
|   validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal, scale); | ||||
|   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math); | ||||
|   if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) { | ||||
|     choice_int = _fused_sdp_choice_stub(query_.device().type(), | ||||
|           query_, key, value, attn_mask_, dropout_p, is_causal, scale); | ||||
|           query_, key, value, attn_mask_, dropout_p, is_causal, scale, enable_gqa); | ||||
|   } | ||||
|   sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int); | ||||
|   std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); | ||||
| @ -713,8 +746,9 @@ Tensor scaled_dot_product_attention( | ||||
|           attn_mask, | ||||
|           dropout_p, | ||||
|           is_causal, | ||||
|           std::nullopt, /*dropout_mask*/ | ||||
|           scale)); | ||||
|           c10::nullopt, /*dropout_mask*/ | ||||
|           scale, | ||||
|           enable_gqa)); | ||||
|     default: | ||||
|       TORCH_CHECK( | ||||
|           false, | ||||
| @ -726,7 +760,7 @@ Tensor scaled_dot_product_attention( | ||||
| std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math( | ||||
|         const Tensor& query_, const Tensor& key, const Tensor& value, | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, | ||||
|         const std::optional<Tensor>& dropout_mask, std::optional<double> scale) { | ||||
|         const std::optional<Tensor>& dropout_mask, std::optional<double> scale, bool enable_gqa) { | ||||
|   C10_LOG_API_USAGE_ONCE("torch.sdpa.math_fallback"); | ||||
|   if (query_.is_nested() || key.is_nested() || value.is_nested()) { | ||||
|     TORCH_CHECK( | ||||
| @ -753,7 +787,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math( | ||||
|         attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril(); | ||||
|         attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype()); | ||||
|     } | ||||
|     auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor); | ||||
|  | ||||
|  | ||||
|     // MQA/GQA handling | ||||
|     auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key, value, enable_gqa); | ||||
|     auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor); | ||||
|     if (attn_mask.has_value()) { | ||||
|       if (at::areAnyTensorSubclassLike({attn, *attn_mask})) { | ||||
|         attn = attn.add(*attn_mask); | ||||
| @ -769,13 +807,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math( | ||||
|         TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes."); | ||||
|         attn = attn.masked_fill(dropout_mask->logical_not(), 0.0); | ||||
|         auto dropout_scaling = 1.0 / (1 - dropout_p); | ||||
|         return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn); | ||||
|         return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling), attn); | ||||
|       } else { | ||||
|         attn = at::dropout(attn, dropout_p, true); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     return std::make_tuple(at::matmul(attn, value), attn); | ||||
|     return std::make_tuple(at::matmul(attn, value_expanded), attn); | ||||
| } | ||||
|  | ||||
| std::tuple<at::Tensor, at::Tensor> | ||||
|  | ||||
| @ -9,7 +9,7 @@ namespace at { | ||||
| namespace native { | ||||
|  | ||||
| using fused_sdp_choice_fn = int64_t (*)(const Tensor& query_, const Tensor& key, const Tensor& value, | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale); | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa); | ||||
|  | ||||
| DECLARE_DISPATCH(fused_sdp_choice_fn, _fused_sdp_choice_stub); | ||||
|  | ||||
|  | ||||
| @ -868,8 +868,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti | ||||
| } | ||||
|  | ||||
| int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale){ | ||||
|   sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal}; | ||||
|         const std::optional<Tensor>& attn_mask_, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){ | ||||
|   sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; | ||||
|   auto backend = select_sdp_backend(kernel_params); | ||||
|   if (backend == sdp::SDPBackend::error) { | ||||
|     TORCH_CHECK( | ||||
|  | ||||
| @ -598,7 +598,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { | ||||
|   } | ||||
|   if (has_only_dense_inputs(params)) { | ||||
|     constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>( | ||||
|         check_batch_size_and_num_heads_dense, | ||||
|         check_batch_size_and_num_heads_dense<true /*supports_grouped_query_attention=*/>, | ||||
|         check_nonzero_sequence_lengths_dense, | ||||
|         check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>); | ||||
|     for (auto& constraint : dense_constraints) { | ||||
| @ -655,9 +655,9 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { | ||||
|   } | ||||
|   if (has_only_dense_inputs(params)) { | ||||
|     constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>( | ||||
|         check_batch_size_and_num_heads_dense, | ||||
|         check_nonzero_sequence_lengths_dense, | ||||
|         check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>); | ||||
|         check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>, | ||||
|         check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention=*/>); | ||||
|     for (auto& constraint : dense_constraints) { | ||||
|       if (!constraint(params, debug)) { | ||||
|         return false; | ||||
|  | ||||
| @ -42,7 +42,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) { | ||||
|       check_nested_tensor, | ||||
|       check_for_dropout, | ||||
|       check_tensor_shapes, | ||||
|       check_batch_size_and_num_heads_dense, | ||||
|       check_batch_size_and_num_heads_dense<false /*supports_grouped_query_attention*/>, | ||||
|       check_attn_mask_shape, | ||||
|       check_head_dim_size_cpp, | ||||
|       check_nonzero_sequence_lengths_dense, | ||||
|  | ||||
| @ -48,6 +48,7 @@ struct sdp_params { | ||||
|   std::optional<at::Tensor> attn_mask; | ||||
|   double dropout; | ||||
|   bool is_causal; | ||||
|   bool enable_gqa; | ||||
| }; | ||||
|  | ||||
| SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params); | ||||
| @ -353,6 +354,46 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| // Validates the head counts (size of dim -3) for grouped-query attention. | ||||
| // | ||||
| // Returns true only when key and value have the same head count and that | ||||
| // count evenly divides the head count of query. On failure returns false | ||||
| // and, when `debug` is set, emits a TORCH_WARN describing the mismatch. | ||||
| inline bool check_grouped_query_attention(sdp_params const& params, bool debug) { | ||||
|   const auto q_num_heads = params.query.sym_size(-3); | ||||
|   const auto k_num_heads = params.key.sym_size(-3); | ||||
|   const auto v_num_heads = params.value.sym_size(-3); | ||||
|   const bool same_kv_heads = k_num_heads == v_num_heads; | ||||
|  | ||||
|   if (!same_kv_heads) { | ||||
|     if (debug) { | ||||
|       TORCH_WARN( | ||||
|           "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", | ||||
|           "Key sizes: ", | ||||
|           params.key.sizes(), | ||||
|           ", Value sizes: ", | ||||
|           params.value.sizes(), | ||||
|           ", Query sizes: ", | ||||
|           params.query.sizes(), | ||||
|           " instead."); | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|   // Key/value heads must evenly divide the query heads so that each | ||||
|   // key/value head maps onto a whole group of query heads. | ||||
|   if (q_num_heads % k_num_heads != 0) { | ||||
|     if (debug) { | ||||
|       // The first literal ends with a space so the two sentences do not run | ||||
|       // together in the emitted warning. | ||||
|       TORCH_WARN( | ||||
|           "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query. ", | ||||
|           "Got input Key sizes(): ", | ||||
|           k_num_heads, | ||||
|           ", Value sizes(): ", | ||||
|           v_num_heads, | ||||
|           ", Query sizes(): ", | ||||
|           q_num_heads, | ||||
|           " instead."); | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <bool supports_gqa> | ||||
| inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { | ||||
|   // This is expected to be called after check_tensor_shapes ensuring that the | ||||
|   // size() calls won't error since the inputs are all 4 dimensional | ||||
| @ -364,16 +405,36 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool | ||||
|   bool same_batch_size = | ||||
|       q_batch_size == k_batch_size && q_batch_size == v_batch_size; | ||||
|  | ||||
|   auto q_num_heads = params.query.sym_size(1); | ||||
|   auto k_num_heads = params.key.sym_size(1); | ||||
|   auto v_num_heads = params.value.sym_size(1); | ||||
|   auto q_num_heads = params.query.sym_size(-3); | ||||
|   auto k_num_heads = params.key.sym_size(-3); | ||||
|   auto v_num_heads = params.value.sym_size(-3); | ||||
|  | ||||
|   bool same_num_heads = | ||||
|       q_num_heads == k_num_heads && q_num_heads == v_num_heads; | ||||
|  | ||||
|   if (!(same_batch_size && same_num_heads)) { | ||||
|   if (!same_batch_size){ | ||||
|     if(debug) { | ||||
|       TORCH_WARN( | ||||
|           "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ", | ||||
|           "Query.sizes(): ", | ||||
|           params.query.sizes(), | ||||
|           ", Key.sizes(): ", | ||||
|           params.key.sizes(), | ||||
|           ", Value.sizes(): ", | ||||
|           params.value.sizes(), | ||||
|           " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel."); | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   if(params.enable_gqa && supports_gqa){ | ||||
|     return check_grouped_query_attention(params, debug); | ||||
|   } | ||||
|  | ||||
|   if (!same_num_heads){ | ||||
|     if (debug) { | ||||
|       TORCH_WARN( | ||||
|           "For dense inputs, both fused kernels require query, key and value to have the same batch_size and num_heads. ", | ||||
|           "For dense input, both fused kernels require query, key and value to have the same num_heads. ", | ||||
|           "Query.sizes(): ", | ||||
|           params.query.sizes(), | ||||
|           ", Key sizes(): ", | ||||
| @ -384,6 +445,7 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|   // If all checks pass, return true | ||||
|   return true; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -128,7 +128,7 @@ void quantize_tensor_per_tensor_affine_privateuse1( | ||||
| } | ||||
|  | ||||
| // SDPA backend chooser for the PrivateUse1 test backend. | ||||
| // Ignores every argument and always reports SDPBackend::overrideable, | ||||
| // returned as its int64_t value. The signature only needs to match | ||||
| // the other _fused_sdp_choice_* functions, including `enable_gqa`. | ||||
| int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, | ||||
|     const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){ | ||||
|     const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale, bool enable_gqa){ | ||||
|   auto backend = sdp::SDPBackend::overrideable; | ||||
|   return static_cast<int64_t>(backend); | ||||
| } | ||||
|  | ||||
| @ -31,7 +31,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|  | ||||
|             @torch.compile(fullgraph=True, backend=counter) | ||||
|             def fn(q, k, v, m): | ||||
|                 return SDPAParams(q, k, v, m, 0.1, True) | ||||
|                 return SDPAParams(q, k, v, m, 0.1, True, False) | ||||
|  | ||||
|             q = torch.randn(10) | ||||
|             k = torch.randn(10) | ||||
| @ -39,7 +39,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|             m = torch.randn(10) | ||||
|             o = fn(q, k, v, m) | ||||
|             self.assertTrue(isinstance(o, SDPAParams)) | ||||
|             self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) | ||||
|             self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) | ||||
|             self.assertEqual(counter.frame_count, 1) | ||||
|  | ||||
|     def test_graph_break_SDPAParams(self): | ||||
| @ -48,7 +48,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|  | ||||
|             @torch.compile(backend=counter) | ||||
|             def fn(q, k, v, m): | ||||
|                 z = SDPAParams(q, k, v, m, 0.1, True) | ||||
|                 z = SDPAParams(q, k, v, m, 0.1, True, False) | ||||
|                 torch._dynamo.graph_break() | ||||
|                 return z, q + 1 | ||||
|  | ||||
| @ -58,7 +58,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|             m = torch.randn(10) | ||||
|             o, _ = fn(q, k, v, m) | ||||
|             self.assertTrue(isinstance(o, SDPAParams)) | ||||
|             self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True)) | ||||
|             self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) | ||||
|             self.assertEqual(counter.frame_count, 2) | ||||
|  | ||||
|     def test_input_SDPAParams(self): | ||||
| @ -74,7 +74,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|             k = torch.randn(10) | ||||
|             v = torch.randn(10) | ||||
|             m = torch.randn(10) | ||||
|             s = SDPAParams(q, k, v, m, 0.1, True) | ||||
|             s = SDPAParams(q, k, v, m, 0.1, True, False) | ||||
|             o, _ = fn(s, q) | ||||
|             self.assertIs(o, s) | ||||
|             self.assertEqual(counter.frame_count, 1) | ||||
| @ -86,7 +86,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|             @torch.compile(fullgraph=True, backend=counter) | ||||
|             def fn(q, k, v, m): | ||||
|                 q += 1 | ||||
|                 z = SDPAParams(q, k, v, m, 0.1, True) | ||||
|                 z = SDPAParams(q, k, v, m, 0.1, True, False) | ||||
|                 a = z.query | ||||
|                 return a + 1, z, q | ||||
|  | ||||
| @ -95,7 +95,7 @@ class TestSDPA(torch._dynamo.test_case.TestCase): | ||||
|             v = torch.randn(10) | ||||
|             m = torch.randn(10) | ||||
|             _, o, _ = fn(q, k, v, m) | ||||
|             expected = SDPAParams(q, k, v, m, 0.1, True) | ||||
|             expected = SDPAParams(q, k, v, m, 0.1, True, False) | ||||
|             self.assert_ref_equals_params(o, expected) | ||||
|             self.assertEqual(counter.frame_count, 1) | ||||
|  | ||||
|  | ||||
| @ -1561,6 +1561,36 @@ class TestSDPAFailureModes(NNTestCase): | ||||
|                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( | ||||
|                     q, k, v, None, 0.0, False)) | ||||
|  | ||||
|     # Failure-mode test: with SDPA restricted to the parametrized fused kernel | ||||
|     # (currently only EFFICIENT_ATTENTION), a call with enable_gqa=True and | ||||
|     # mismatched head counts must raise "No available kernel". | ||||
|     @onlyCUDA | ||||
|     @skipIfRocm  # Nested Tensor | ||||
|     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") | ||||
|     @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) | ||||
|     def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel): | ||||
|         # Query has 8 heads in dim 1; key and value have 4. | ||||
|         rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|         rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|         rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|  | ||||
|         with sdpa_kernel(fused_kernel): | ||||
|             with self.assertRaisesRegex(RuntimeError, "No available kernel"): | ||||
|                 # NOTE(review): the RuntimeError propagates through this inner context, | ||||
|                 # so the warning regex is presumably never enforced - confirm. | ||||
|                 with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, " | ||||
|                                            "key and value to have"): | ||||
|                     F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0, | ||||
|                                                    is_causal=False, enable_gqa=True) | ||||
|  | ||||
|     # Failure-mode test: with SDPA restricted to FLASH_ATTENTION on CPU, a call | ||||
|     # with enable_gqa=True and mismatched head counts must raise | ||||
|     # "No available kernel". | ||||
|     @onlyCPU | ||||
|     @skipIfRocm  # Nested Tensor | ||||
|     def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device): | ||||
|         # Query has 8 heads in dim 1; key and value have 4. | ||||
|         rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|         rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|         rand_value = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True) | ||||
|  | ||||
|         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): | ||||
|             with self.assertRaisesRegex(RuntimeError, "No available kernel"): | ||||
|                 # NOTE(review): the RuntimeError propagates through this inner context, | ||||
|                 # so the warning regex is presumably never enforced - confirm. | ||||
|                 with self.assertWarnsRegex(UserWarning, "For dense inputs, both fused kernels require query, " | ||||
|                                            "key and value to have"): | ||||
|                     F.scaled_dot_product_attention(rand_query, rand_key, rand_value, dropout_p=0.0, | ||||
|                                                    is_causal=False, enable_gqa=True) | ||||
|  | ||||
|     @onlyCUDA | ||||
|     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention") | ||||
|     @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) | ||||
| @ -1712,7 +1742,8 @@ class TestSDPAFailureModes(NNTestCase): | ||||
|         seq_len_list = [2, 4, 5, 6, 7] | ||||
|         shape = SdpaShape(5, 8, seq_len_list, 57) | ||||
|         make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype) | ||||
|         q, k, v = make_tensor(), make_tensor(), make_tensor() | ||||
|         q, k, v = make_tensor().transpose(1, 2), make_tensor().transpose(1, 2), make_tensor().transpose(1, 2) | ||||
|  | ||||
|         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): | ||||
|             with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"): | ||||
|                 self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( | ||||
| @ -1792,7 +1823,7 @@ class TestSDPAFailureModes(NNTestCase): | ||||
|         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): | ||||
|             with self.assertWarnsRegex(UserWarning, "Both fused kernels do not support training with broadcasted NT inputs"): | ||||
|                 with self.assertRaisesRegex(RuntimeError, "No available kernel"): | ||||
|                     out = torch.nn.functional.scaled_dot_product_attention( | ||||
|                     torch.nn.functional.scaled_dot_product_attention( | ||||
|                         query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) | ||||
|  | ||||
|     @onlyCUDA | ||||
| @ -2949,23 +2980,32 @@ class TestSDPACudaOnly(NNTestCase): | ||||
|     @parametrize("dropout_p", [0.0, 0.22, 0.48]) | ||||
|     @parametrize("dtype", [torch.float16, torch.bfloat16]) | ||||
|     @parametrize("scale", [None, "l1"]) | ||||
|     @parametrize("enable_gqa", [True, False]) | ||||
|     @parametrize("n_heads", [[16, 8], [10, 2]]) | ||||
|     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, | ||||
|                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, | ||||
|                                                scale: str): | ||||
|                                                scale: str, enable_gqa: bool, n_heads: List[int]): | ||||
|         if isSM8XDevice and head_dim in range(193, 256 + 1): | ||||
|             self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") | ||||
|         if is_causal and seq_len_q != seq_len_k: | ||||
|             self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k") | ||||
|         if TEST_WITH_ROCM and seq_len_q >= 1024 and seq_len_k >= 1024 and batch_size > 1: | ||||
|             torch.cuda.empty_cache()  # Prevent memory fragmentation | ||||
|         if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: | ||||
|             unittest.skip("Reference implementation OOM") | ||||
|             return | ||||
|  | ||||
|         scale = scale if scale is None else (1 / head_dim) | ||||
|         n_heads = 4 | ||||
|         query = torch.rand(batch_size, n_heads, seq_len_q, head_dim, | ||||
|         num_heads_q = num_heads_kv = 4 | ||||
|         if enable_gqa: | ||||
|             num_heads_q = n_heads[0] | ||||
|             num_heads_kv = n_heads[1] | ||||
|  | ||||
|         query = torch.rand(batch_size, num_heads_q, seq_len_q, head_dim, | ||||
|                            device=device, dtype=dtype, requires_grad=True) | ||||
|         key = torch.rand(batch_size, n_heads, seq_len_k, head_dim, device=device, | ||||
|         key = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, device=device, | ||||
|                          dtype=dtype, requires_grad=True) | ||||
|         value = torch.rand(batch_size, n_heads, seq_len_k, head_dim, | ||||
|         value = torch.rand(batch_size, num_heads_kv, seq_len_k, head_dim, | ||||
|                            device=device, dtype=dtype, requires_grad=True) | ||||
|  | ||||
|         higher_precision_dtype = torch.float64 if dtype == torch.float32 else torch.float32 | ||||
| @ -2975,14 +3015,15 @@ class TestSDPACudaOnly(NNTestCase): | ||||
|  | ||||
|         if not is_dropout: | ||||
|             with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): | ||||
|                 out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) | ||||
|                 out = F.scaled_dot_product_attention( | ||||
|                     query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa) | ||||
|             with sdpa_kernel(backends=[SDPBackend.MATH]): | ||||
|                 # High Precision Math Reference | ||||
|                 out_ref = F.scaled_dot_product_attention( | ||||
|                     query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale) | ||||
|                     query_ref, key_ref, value_ref, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa) | ||||
|                 # Low Precision Math Reference | ||||
|                 out_lp_ref = F.scaled_dot_product_attention( | ||||
|                     query, key, value, is_causal=is_causal, scale=scale) | ||||
|                     query, key, value, is_causal=is_causal, scale=scale, enable_gqa=enable_gqa) | ||||
|         else: | ||||
|             # Problem: We pad sizes in the composite region of the top level SDPA. But we need the | ||||
|             # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout | ||||
| @ -3009,11 +3050,12 @@ class TestSDPACudaOnly(NNTestCase): | ||||
|             dropout_mask = softmax_mask >= 0 | ||||
|             # High Precision Math Reference | ||||
|             out_ref = torch.ops.aten._scaled_dot_product_attention_math( | ||||
|                 query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] | ||||
|                 query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, | ||||
|                 scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0] | ||||
|             # Low Precision Math Reference | ||||
|             out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( | ||||
|                 query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, | ||||
|                 dropout_mask=dropout_mask)[0] | ||||
|                 dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0] | ||||
|  | ||||
|         upstream_grad = torch.rand_like(out, requires_grad=False) | ||||
|  | ||||
| @ -3033,7 +3075,7 @@ class TestSDPACudaOnly(NNTestCase): | ||||
|                 'out': 1.5, | ||||
|                 'grad_query': 13.0, | ||||
|                 'grad_key': 2.0, | ||||
|                 'grad_value': 1.5, | ||||
|                 'grad_value': 1.75, | ||||
|             } | ||||
|         ) | ||||
|  | ||||
| @ -3185,6 +3227,7 @@ class TestSDPACudaOnly(NNTestCase): | ||||
|                 } | ||||
|             ) | ||||
|  | ||||
|  | ||||
|     @skipIfRocm  # Nested Tensor | ||||
|     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") | ||||
|     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if | ||||
|  | ||||
| @ -466,6 +466,7 @@ def gen_nn_functional(fm: FileManager) -> None: | ||||
|                             "dropout_p: float = 0.0", | ||||
|                             "is_causal: bool = False", | ||||
|                             "scale: Optional[float] = None", | ||||
|                             "enable_gqa: bool = False", | ||||
|                         ] | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
| @ -1956,6 +1956,7 @@ class _SDPAParams: | ||||
|     attn_mask: Optional[Tensor] | ||||
|     dropout: _float | ||||
|     is_causal: _bool | ||||
|     enable_gqa: _bool | ||||
|     def __init__( | ||||
|         self, | ||||
|         query: Tensor, | ||||
| @ -1963,7 +1964,8 @@ class _SDPAParams: | ||||
|         value: Tensor, | ||||
|         attn_mask: Optional[Tensor], | ||||
|         dropout: _float, | ||||
|         is_causal: _bool) -> None: ... | ||||
|         is_causal: _bool, | ||||
|         enable_gqa: _bool) -> None: ... | ||||
|  | ||||
| class _SDPBackend(Enum): | ||||
|     ERROR = -1 | ||||
|  | ||||
| @ -33,6 +33,9 @@ class SDPAParamsVariable(VariableTracker): | ||||
|         is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))( | ||||
|             value.is_causal | ||||
|         ) | ||||
|         enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))( | ||||
|             value.enable_gqa | ||||
|         ) | ||||
|         param_vars = [ | ||||
|             query_var, | ||||
|             key_var, | ||||
| @ -40,6 +43,7 @@ class SDPAParamsVariable(VariableTracker): | ||||
|             attn_mask_var, | ||||
|             dropout_var, | ||||
|             is_causal_var, | ||||
|             enable_gqa_var, | ||||
|         ] | ||||
|         return TorchInGraphFunctionVariable(SDPAParams).call_function( | ||||
|             tx, param_vars, {} | ||||
|  | ||||
| @ -1953,16 +1953,24 @@ Call this whenever a new thread is created in order to propagate values from | ||||
|                        at::Tensor const& value, | ||||
|                        std::optional<at::Tensor> attn_mask, | ||||
|                        double dropout, | ||||
|                        bool is_causal) { | ||||
|                        bool is_causal, | ||||
|                        bool enable_gqa) { | ||||
|         return sdp::sdp_params{ | ||||
|             query, key, value, std::move(attn_mask), dropout, is_causal}; | ||||
|             query, | ||||
|             key, | ||||
|             value, | ||||
|             std::move(attn_mask), | ||||
|             dropout, | ||||
|             is_causal, | ||||
|             enable_gqa}; | ||||
|       })) | ||||
|       .def_readonly("query", &sdp::sdp_params::query) | ||||
|       .def_readonly("key", &sdp::sdp_params::key) | ||||
|       .def_readonly("value", &sdp::sdp_params::value) | ||||
|       .def_readonly("attn_mask", &sdp::sdp_params::attn_mask) | ||||
|       .def_readonly("dropout", &sdp::sdp_params::dropout) | ||||
|       .def_readonly("is_causal", &sdp::sdp_params::is_causal); | ||||
|       .def_readonly("is_causal", &sdp::sdp_params::is_causal) | ||||
|       .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa); | ||||
|  | ||||
|   py::enum_<sdp::SDPBackend>( | ||||
|       py_module, | ||||
|  | ||||
| @ -261,7 +261,7 @@ def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool: | ||||
|     return True | ||||
|  | ||||
|  | ||||
| def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): | ||||
| def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal, enable_gqa): | ||||
|     if ( | ||||
|         not flash_sdp_enabled() | ||||
|         and not mem_efficient_sdp_enabled() | ||||
| @ -275,7 +275,7 @@ def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal): | ||||
|         SDPBackend.MATH, | ||||
|     ) | ||||
|  | ||||
|     params = SDPAParams(query, key, value, attn_mask, dropout, is_causal) | ||||
|     params = SDPAParams(query, key, value, attn_mask, dropout, is_causal, enable_gqa) | ||||
|  | ||||
|     for backend in ordering: | ||||
|         if backend == SDPBackend.FLASH_ATTENTION: | ||||
| @ -622,6 +622,7 @@ def jagged_scaled_dot_product_attention( | ||||
|     dropout_p=0.0, | ||||
|     is_causal=False, | ||||
|     scale=None, | ||||
|     enable_gqa=False, | ||||
| ): | ||||
|     _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale) | ||||
|     # for mypy, ugh | ||||
| @ -652,7 +653,7 @@ def jagged_scaled_dot_product_attention( | ||||
|     compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad | ||||
|  | ||||
|     backend_choice = _select_sdp_backend( | ||||
|         query, key, value, attn_mask, dropout_p, is_causal | ||||
|         query, key, value, attn_mask, dropout_p, is_causal, enable_gqa | ||||
|     ) | ||||
|  | ||||
|     if backend_choice == SDPBackend.FLASH_ATTENTION: | ||||
|  | ||||
| @ -173,6 +173,7 @@ class CausalBias(torch.Tensor): | ||||
|         dropout_p: float = 0.0, | ||||
|         is_causal: bool = False, | ||||
|         scale: Optional[float] = None, | ||||
|         enable_gqa: bool = False, | ||||
|     ) -> torch.Tensor: | ||||
|         r""" | ||||
|         Handles the logic for computing attention with the specified causal bias. | ||||
| @ -189,6 +190,7 @@ class CausalBias(torch.Tensor): | ||||
|                 are set. | ||||
|             scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set | ||||
|                 to :math:`\frac{1}{\sqrt{E}}`. | ||||
|             enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled. By default it is set to False. | ||||
|  | ||||
|         Returns: | ||||
|             output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. | ||||
| @ -212,10 +214,13 @@ class CausalBias(torch.Tensor): | ||||
|                 dropout_p=dropout_p, | ||||
|                 is_causal=True, | ||||
|                 scale=scale, | ||||
|                 enable_gqa=enable_gqa, | ||||
|             ) | ||||
|         elif attn_mask.variant == CausalVariant.LOWER_RIGHT: | ||||
|             _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale) | ||||
|             sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal) | ||||
|             sdpa_params = SDPAParams( | ||||
|                 query, key, value, None, dropout_p, is_causal, enable_gqa | ||||
|             ) | ||||
|             if can_use_flash_attention(sdpa_params): | ||||
|                 needs_padding = query.size(-1) % 8 != 0 | ||||
|                 og_head_size = query.size(-1) | ||||
| @ -264,6 +269,7 @@ class CausalBias(torch.Tensor): | ||||
|                     dropout_p=dropout_p, | ||||
|                     is_causal=False, | ||||
|                     scale=scale, | ||||
|                     enable_gqa=enable_gqa, | ||||
|                 ) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|  | ||||
| @ -5606,137 +5606,163 @@ def _in_projection( | ||||
|  | ||||
| scaled_dot_product_attention = _add_docstr( | ||||
|     torch._C._nn.scaled_dot_product_attention, | ||||
|     r""" | ||||
| scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> Tensor: | ||||
|     r"""scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, | ||||
|         is_causal=False, scale=None, enable_gqa=False) -> Tensor: | ||||
|  | ||||
| Computes scaled dot product attention on query, key and value tensors, using | ||||
| an optional attention mask if passed, and applying dropout if a probability | ||||
| greater than 0.0 is specified. The optional scale argument can only be specified as a keyword argument. | ||||
|  | ||||
| .. code-block:: python | ||||
|  | ||||
|     # Efficient implementation equivalent to the following: | ||||
|     def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: | ||||
|         L, S = query.size(-2), key.size(-2) | ||||
|         scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale | ||||
|         attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) | ||||
|         if is_causal: | ||||
|             assert attn_mask is None | ||||
|             temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) | ||||
|             attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||||
|             attn_bias.to(query.dtype) | ||||
|  | ||||
|         if attn_mask is not None: | ||||
|             if attn_mask.dtype == torch.bool: | ||||
|                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) | ||||
|             else: | ||||
|                 attn_bias = attn_mask + attn_bias | ||||
|         attn_weight = query @ key.transpose(-2, -1) * scale_factor | ||||
|         attn_weight += attn_bias | ||||
|         attn_weight = torch.softmax(attn_weight, dim=-1) | ||||
|         attn_weight = torch.dropout(attn_weight, dropout_p, train=True) | ||||
|         return attn_weight @ value | ||||
|  | ||||
| .. warning:: This function is beta and subject to change. | ||||
|  | ||||
| .. warning:: | ||||
|  | ||||
|     This function always applies dropout according to the specified ``dropout_p`` argument. | ||||
|     To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module | ||||
|     that makes the function call is not in training mode. | ||||
|  | ||||
|     For example: | ||||
|     Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, | ||||
|     and applying dropout if a probability greater than 0.0 is specified. The optional scale argument can only be | ||||
|     specified as a keyword argument. | ||||
|  | ||||
|     .. code-block:: python | ||||
|  | ||||
|         class MyModel(nn.Module): | ||||
|             def __init__(self, p=0.5): | ||||
|                 super().__init__() | ||||
|                 self.p = p | ||||
|         # Efficient implementation equivalent to the following: | ||||
|         def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, | ||||
|                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: | ||||
|             L, S = query.size(-2), key.size(-2) | ||||
|             scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale | ||||
|             attn_bias = torch.zeros(L, S, dtype=query.dtype) | ||||
|             if is_causal: | ||||
|                 assert attn_mask is None | ||||
|                 temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) | ||||
|                 attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||||
|                 attn_bias.to(query.dtype) | ||||
|  | ||||
|             def forward(self, ...): | ||||
|                 return F.scaled_dot_product_attention(..., dropout_p=(self.p if self.training else 0.0)) | ||||
|             if attn_mask is not None: | ||||
|                 if attn_mask.dtype == torch.bool: | ||||
|                     attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) | ||||
|                 else: | ||||
|                     attn_bias += attn_mask | ||||
|  | ||||
| Note: | ||||
|             if enable_gqa: | ||||
|                 key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) | ||||
|                 value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) | ||||
|  | ||||
|     There are currently three supported implementations of scaled dot product attention: | ||||
|             attn_weight = query @ key.transpose(-2, -1) * scale_factor | ||||
|             attn_weight += attn_bias | ||||
|             attn_weight = torch.softmax(attn_weight, dim=-1) | ||||
|             attn_weight = torch.dropout(attn_weight, dropout_p, train=True) | ||||
|             return attn_weight @ value | ||||
|  | ||||
|         - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ | ||||
|         - `Memory-Efficient Attention`_ | ||||
|         - A PyTorch implementation defined in C++ matching the above formulation | ||||
|     .. warning:: | ||||
|         This function is beta and subject to change. | ||||
|  | ||||
|     The function may call optimized kernels for improved performance when using the CUDA backend. | ||||
|     For all other backends, the PyTorch implementation will be used. | ||||
|     .. warning:: | ||||
|         This function always applies dropout according to the specified ``dropout_p`` argument. | ||||
|         To disable dropout during evaluation, be sure to pass a value of ``0.0`` when the module | ||||
|         that makes the function call is not in training mode. | ||||
|  | ||||
|     All implementations are enabled by default. Scaled dot product attention attempts to automatically select the | ||||
|     most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation | ||||
|     is used, the following functions are provided for enabling and disabling implementations. | ||||
|     The context manager is the preferred mechanism: | ||||
|         For example: | ||||
|  | ||||
|         - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. | ||||
|         - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. | ||||
|         - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables  Memory-Efficient Attention. | ||||
|         - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables  the PyTorch C++ implementation. | ||||
|         .. code-block:: python | ||||
|  | ||||
|     Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, | ||||
|     disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. | ||||
|     In the event that a fused implementation is not available, a warning will be raised with the | ||||
|     reasons why the fused implementation cannot run. | ||||
|             class MyModel(nn.Module): | ||||
|                 def __init__(self, p=0.5): | ||||
|                     super().__init__() | ||||
|                     self.p = p | ||||
|  | ||||
|     Due to the nature of fusing floating point operations, the output of this function may be different | ||||
|     depending on what backend kernel is chosen. | ||||
|     The c++ implementation supports torch.float64 and can be used when higher precision is required. | ||||
|     For more information please see :doc:`/notes/numerical_accuracy` | ||||
|                 def forward(self, ...): | ||||
|                     return F.scaled_dot_product_attention(..., | ||||
|                         dropout_p=(self.p if self.training else 0.0)) | ||||
|  | ||||
| Note: | ||||
|     {cudnn_reproducibility_note} | ||||
| """.format( | ||||
|     Note: | ||||
|  | ||||
|         There are currently three supported implementations of scaled dot product attention: | ||||
|  | ||||
|             - `FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning`_ | ||||
|             - `Memory-Efficient Attention`_ | ||||
|             - A PyTorch implementation defined in C++ matching the above formulation | ||||
|  | ||||
|         The function may call optimized kernels for improved performance when using the CUDA backend. | ||||
|         For all other backends, the PyTorch implementation will be used. | ||||
|  | ||||
|         All implementations are enabled by default. Scaled dot product attention attempts to automatically select the | ||||
|         most optimal implementation based on the inputs. In order to provide more fine-grained control over what implementation | ||||
|         is used, the following functions are provided for enabling and disabling implementations. | ||||
|         The context manager is the preferred mechanism: | ||||
|  | ||||
|             - :func:`torch.nn.attention.sdpa_kernel`: A context manager used to enable or disable any of the implementations. | ||||
|             - :func:`torch.backends.cuda.enable_flash_sdp`: Globally enables or disables FlashAttention. | ||||
|             - :func:`torch.backends.cuda.enable_mem_efficient_sdp`: Globally enables or disables  Memory-Efficient Attention. | ||||
|             - :func:`torch.backends.cuda.enable_math_sdp`: Globally enables or disables  the PyTorch C++ implementation. | ||||
|  | ||||
|         Each of the fused kernels has specific input limitations. If the user requires the use of a specific fused implementation, | ||||
|         disable the PyTorch C++ implementation using :func:`torch.nn.attention.sdpa_kernel`. | ||||
|         In the event that a fused implementation is not available, a warning will be raised with the | ||||
|         reasons why the fused implementation cannot run. | ||||
|  | ||||
|         Due to the nature of fusing floating point operations, the output of this function may be different | ||||
|         depending on what backend kernel is chosen. | ||||
|         The c++ implementation supports torch.float64 and can be used when higher precision is required. | ||||
|         For more information please see :doc:`/notes/numerical_accuracy` | ||||
|  | ||||
|         Grouped Query Attention (GQA) is an experimental feature. It currently works only with the FlashAttention | ||||
|         and math kernels on CUDA tensors, and does not support nested tensors. | ||||
|         Constraints for GQA: | ||||
|  | ||||
|             - number_of_heads_query % number_of_heads_key_value == 0 and, | ||||
|             - number_of_heads_key == number_of_heads_value | ||||
|  | ||||
|     Note: | ||||
|  | ||||
|         {cudnn_reproducibility_note} | ||||
|     """.format( | ||||
|         **reproducibility_notes | ||||
|     ) | ||||
|     + r""" | ||||
| Args: | ||||
|     query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`. | ||||
|     key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`. | ||||
|     value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`. | ||||
|     attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, | ||||
|         which is :math:`(N,..., L, S)`. Two types of masks are supported. | ||||
|         A boolean mask where a value of True indicates that the element *should* take part in attention. | ||||
|         A float mask of the same type as query, key, value that is added to the attention score. | ||||
|     dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied | ||||
|     is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a | ||||
|         square matrix. The attention masking has the form of the upper left causal bias due to the alignment | ||||
|         (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. | ||||
|         An error is thrown if both attn_mask and is_causal are set. | ||||
|     scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set | ||||
|         to :math:`\frac{1}{\sqrt{E}}`. | ||||
|     Args: | ||||
|         query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. | ||||
|         key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. | ||||
|         value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. | ||||
|         attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, | ||||
|             which is :math:`(N,..., L, S)`. Two types of masks are supported. | ||||
|             A boolean mask where a value of True indicates that the element *should* take part in attention. | ||||
|             A float mask of the same type as query, key, value that is added to the attention score. | ||||
|         dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied | ||||
|         is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a | ||||
|             square matrix. The attention masking has the form of the upper left causal bias due to the alignment | ||||
|             (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. | ||||
|             An error is thrown if both attn_mask and is_causal are set. | ||||
|         scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set | ||||
|             to :math:`\frac{1}{\sqrt{E}}`. | ||||
|         enable_gqa (bool): If set to True, Grouped Query Attention (GQA) is enabled. By default it is set to False. | ||||
|  | ||||
|     Returns: | ||||
|         output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. | ||||
|  | ||||
|     Shape legend: | ||||
|         - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` | ||||
|         - :math:`S: \text{Source sequence length}` | ||||
|         - :math:`L: \text{Target sequence length}` | ||||
|         - :math:`E: \text{Embedding dimension of the query and key}` | ||||
|         - :math:`Ev: \text{Embedding dimension of the value}` | ||||
|         - :math:`Hq: \text{Number of heads of query}` | ||||
|         - :math:`H: \text{Number of heads of key and value}` | ||||
|  | ||||
|     Examples: | ||||
|  | ||||
|         >>> # Optionally use the context manager to ensure one of the fused kernels is run | ||||
|         >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> with torch.backends.cuda.sdp_kernel(enable_math=False): | ||||
|         >>>     F.scaled_dot_product_attention(query,key,value) | ||||
|  | ||||
|  | ||||
| Returns: | ||||
|     output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`. | ||||
|  | ||||
| Shape legend: | ||||
|     - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` | ||||
|     - :math:`S: \text{Source sequence length}` | ||||
|     - :math:`L: \text{Target sequence length}` | ||||
|     - :math:`E: \text{Embedding dimension of the query and key}` | ||||
|     - :math:`Ev: \text{Embedding dimension of the value}` | ||||
|  | ||||
| Examples: | ||||
|  | ||||
|     >>> # Optionally use the context manager to ensure one of the fused kernels is run | ||||
|     >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|     >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|     >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|     >>> with torch.backends.cuda.sdp_kernel(enable_math=False): | ||||
|     >>>     F.scaled_dot_product_attention(query,key,value) | ||||
|         >>> # Sample for GQA for llama3 | ||||
|         >>> query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") | ||||
|         >>> with sdpa_kernel(backends=[SDPBackend.MATH]): | ||||
|         >>>     F.scaled_dot_product_attention(query,key,value,enable_gqa=True) | ||||
|  | ||||
|  | ||||
| .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: | ||||
|     https://arxiv.org/abs/2307.08691 | ||||
| .. _Memory-Efficient Attention: | ||||
|     https://github.com/facebookresearch/xformers | ||||
|  | ||||
| """, | ||||
|     .. _FlashAttention-2\: Faster Attention with Better Parallelism and Work Partitioning: | ||||
|         https://arxiv.org/abs/2307.08691 | ||||
|     .. _Memory-Efficient Attention: | ||||
|         https://github.com/facebookresearch/xformers | ||||
|     .. _Grouped-Query Attention: | ||||
|         https://arxiv.org/pdf/2305.13245 | ||||
|     """, | ||||
| ) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -8688,6 +8688,7 @@ def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): | ||||
| def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): | ||||
|     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) | ||||
|     batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 | ||||
|     num_heads_q_gqa, num_heads_kv_gqa = 32, 8 | ||||
| 
 | ||||
|     dim_3_q_shape = (batch, seq_q, head_dim) | ||||
|     dim_3_kv_shape = (batch, seq_kv, head_dim) | ||||
| @ -8698,8 +8699,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ | ||||
| 
 | ||||
|     qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple] | ||||
|     samples = [] | ||||
|     for qkv_shape, is_causal, dropout_p in product( | ||||
|             qkv_shapes, [True, False], [0.0, 0.5]): | ||||
|     for qkv_shape, is_causal, dropout_p, enable_gqa in product( | ||||
|             qkv_shapes, [True, False], [0.0, 0.5], [True, False]): | ||||
|         shape_q, shape_kv = qkv_shape | ||||
|         samples.append(SampleInput( | ||||
|             make(shape_q), | ||||
| @ -8729,6 +8730,15 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ | ||||
|             dropout_p=0.0) | ||||
|     ) | ||||
| 
 | ||||
|     samples.append( | ||||
|         SampleInput( | ||||
|             make((batch, num_heads_q_gqa, seq_q, head_dim)), | ||||
|             make((batch, num_heads_kv_gqa, seq_kv, head_dim)), | ||||
|             make((batch, num_heads_kv_gqa, seq_kv, head_dim)), | ||||
|             enable_gqa=True | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|     yield from samples | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user