Always use high precision for SDPA math backend (#128922)
Summary: feikou observed big numerical gaps when using the math backend on AMD and NVIDIA GPUs. This is mainly because we were not using higher-precision FP32 for the intermediate accumulated/materialized values. Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16. Differential Revision: D58710805 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922 Approved by: https://github.com/xw285cornell, https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: 0eea2b3947
Commit: fbf3bc0a60
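As a rough illustration of the gap described in the summary, the sketch below compares the upcast-compute-downcast path with an all-half-precision call through the public API. The shapes and dtypes are illustrative assumptions, and which backend serves the second call depends on the build and device; it is not the patch itself.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes; use any device/backend that supports float16 SDPA.
    q, k, v = (torch.randn(2, 8, 128, 64, dtype=torch.float16) for _ in range(3))

    # Upcast -> compute -> downcast: the behavior this patch gives the math backend.
    out_hi = F.scaled_dot_product_attention(q.float(), k.float(), v.float()).to(torch.float16)

    # Half-precision call: before this patch, the math backend kept every intermediate
    # (logits, softmax, weighted sum) in float16, which is where the gap came from.
    out_lo = F.scaled_dot_product_attention(q, k, v)

    print((out_hi.float() - out_lo.float()).abs().max())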
@@ -73,8 +73,7 @@
 #endif
 
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 
-namespace at {
-namespace native {
+namespace at::native {
 
 DEFINE_DISPATCH(_fused_sdp_choice_stub);
@@ -768,29 +767,57 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
         value.is_contiguous(),
         "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
   }
-  auto attn_mask = attn_mask_;
-  // Naive, composite implementation defined here.
+  auto origin_dtype = query_.scalar_type();
+  // Keep query, key, value in high precision for accuracy
+  // NestedTensor reports issues for backward with autograd so disabled: must be
+  // contiguous to get buffer.
+  auto query_acc = (query_.scalar_type() == at::kHalf ||
+                    query_.scalar_type() == at::kBFloat16) &&
+          !query_.is_nested()
+      ? query_.to(at::kFloat)
+      : query_;
+  auto key_acc =
+      (key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
+          !key.is_nested()
+      ? key.to(at::kFloat)
+      : key;
+  auto value_acc = (value.scalar_type() == at::kHalf ||
+                    value.scalar_type() == at::kBFloat16) &&
+          !value.is_nested()
+      ? value.to(at::kFloat)
+      : value;
+  auto attn_mask = attn_mask_;
+  // Naive, composite implementation defined here.
 
-  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
-  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
+  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
+  // math
+  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
+  const auto scaling_factor =
+      sdp::calculate_scale(
+          query_acc, is_negative_scaling ? std::abs(scale.value()) : scale)
+          .sqrt();
 
-  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
-  if (is_causal) {
-    TORCH_CHECK(!attn_mask.has_value(),
-        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
-        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
+  const auto query = query_acc *
+      (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor
+                           : scaling_factor);
+  if (is_causal) {
+    TORCH_CHECK(
+        !attn_mask.has_value(),
+        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
+    TORCH_CHECK(
+        !query.is_nested() && !key_acc.is_nested(),
+        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
 
-    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.sym_size(-2), S = key.sym_size(-2);
-    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
-    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
+    // Replace attn_mask with causal mask; lower triangular elements take part
+    // in attention.
+    const auto L = query.sym_size(-2), S = key_acc.sym_size(-2);
+    attn_mask =
+        at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
+    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
   }
 
   // MQA/GQA handling
-  auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key, value, enable_gqa);
+  auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key_acc, value_acc, enable_gqa);
   auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
   if (attn_mask.has_value()) {
     if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
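The "Scale q, k before matmul for stability" comment in the hunk above rests on a simple identity: multiplying q and k each by sqrt(scale) yields the same logits as scaling the q @ k^T product once, while keeping the intermediate products smaller, which helps avoid overflow in reduced precision. A minimal Python sketch (shapes are made up, dense tensors, no mask):

    import math
    import torch

    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    scale = 1.0 / math.sqrt(q.size(-1))  # the default softmax scale
    s = math.sqrt(scale)

    logits_once = (q @ k.transpose(-2, -1)) * scale       # scale applied to the product
    logits_split = (q * s) @ (k.transpose(-2, -1) * s)    # sqrt(scale) folded into each operand

    # Identical up to floating-point rounding.
    torch.testing.assert_close(logits_once, logits_split, rtol=1e-4, atol=1e-5)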
@@ -807,13 +834,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
       attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
       auto dropout_scaling = 1.0 / (1 - dropout_p);
-      return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling), attn);
+      return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
     } else {
       attn = at::dropout(attn, dropout_p, true);
     }
   }
 
-  return std::make_tuple(at::matmul(attn, value_expanded), attn);
+  return std::make_tuple(at::matmul(attn, value_expanded).to(origin_dtype), attn.to(origin_dtype));
 }
 
 std::tuple<at::Tensor, at::Tensor>
@@ -1036,6 +1063,4 @@ Tensor triton_multi_head_attention(
 #endif
   return proj;
 }
-
-} // namespace native
-} // namespace at
+} // namespace at::native
@@ -31,7 +31,6 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     is_iterable_of_tensors,
-    IS_MACOS,
     run_tests,
     skipIfCrossRef,
     skipIfTorchDynamo,
@@ -1175,7 +1174,7 @@ class DecompOneOffTests(TestCase):
         [
             xfail(
                 "nn.functional.scaled_dot_product_attention",
-                dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
+                dtypes=[torch.half],
             ),
         ],
     )
@@ -2865,8 +2865,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 2.0,
-                'grad_query': 18.0,
+                'out': 3.0,
+                'grad_query': 150.0,
                 'grad_key': 25.0,
                 'grad_value': 8.5,
             }
@@ -2962,8 +2962,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                "out": 1.75,
-                "grad_query": 18.0,
+                "out": 4,
+                "grad_query": 150.0,
                 "grad_key": 25.0,
                 "grad_value": 8.0,
                 "grad_attn_mask": 45.0,
@@ -3072,10 +3072,10 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 1.5,
-                'grad_query': 13.0,
-                'grad_key': 2.0,
-                'grad_value': 1.75,
+                'out': 2.2,
+                'grad_query': 160.0,
+                'grad_key': 8.0,
+                'grad_value': 4,
             }
         )
 
@@ -3221,8 +3221,8 @@ class TestSDPACudaOnly(NNTestCase):
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
                 'out': 2.0,
-                'grad_query': 12.0,
-                'grad_key': 2.0,
+                'grad_query': 100.0,
+                'grad_key': 8.0,
                 'grad_value': 2.0,
             }
         )
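The fudge factors in these tests grow substantially (for example, grad_query from 18.0 to 150.0), presumably because the low-precision reference run (out_lp_ref) now goes through the fp32 math path and lands much closer to the high-precision reference, so the baseline deviation that tolerances are derived from shrinks and the fused kernels need a larger multiple of it. Below is a rough sketch of that kind of check; the helper name and tolerance rule are assumptions for illustration, not the actual test utilities.

    import torch

    def assert_close_with_fudge(out, out_ref, out_lp_ref, fudge):
        # Baseline error: how far the low-precision reference already strays from
        # the high-precision reference. (Hypothetical rule; the real helpers differ.)
        base = (out_ref.float() - out_lp_ref.float()).abs().max().item()
        atol = max(base * fudge, 1e-5)
        torch.testing.assert_close(out.float(), out_ref.float(), atol=atol, rtol=0)

    # e.g. assert_close_with_fudge(grad_q, grad_q_ref, grad_q_lp_ref, fudge=150.0)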
@@ -5693,6 +5693,7 @@ scaled_dot_product_attention = _add_docstr(
         Due to the nature of fusing floating point operations, the output of this function may be different
         depending on what backend kernel is chosen.
         The c++ implementation supports torch.float64 and can be used when higher precision is required.
+        For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
         For more information please see :doc:`/notes/numerical_accuracy`
 
     Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
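The new docstring line describes user-visible behavior: with half or bfloat16 inputs, the math backend now computes its intermediates in torch.float and casts the result back. A small usage sketch (the sdpa_kernel/SDPBackend selector is the public API in recent PyTorch releases; the shapes here are made up):

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q, k, v = (torch.randn(1, 4, 32, 64, dtype=torch.bfloat16) for _ in range(3))

    with sdpa_kernel(SDPBackend.MATH):
        out = F.scaled_dot_product_attention(q, k, v)

    # Inputs and output stay bfloat16; with this change the logits/softmax inside
    # the math backend are computed in float32 and downcast at the end.
    print(out.dtype)  # torch.bfloat16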