Always use high precision for SDPA math backend (#128922)

Summary:
feikou observed large numerical gaps when using the math backend on AMD and NVIDIA GPUs. This is mainly because we were not using higher-precision FP32 for the intermediate accumulated/materialized values.

Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, I think it is worth upcasting FP16/BF16 inputs to FP32, performing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.
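
To make the gap concrete, here is a small self-contained illustration (our sketch, not code from this PR; shapes are arbitrary and bfloat16 is used so it also runs on CPU): the same attention math accumulated in low precision drifts measurably from the identical computation carried out in FP32.

import torch

torch.manual_seed(0)
q, k, v = (torch.randn(4, 8, 256, 64, dtype=torch.bfloat16) for _ in range(3))
scale = q.shape[-1] ** -0.5

def attn(q, k, v):
    # softmax(Q K^T * scale) V, carried out in whatever dtype the inputs have
    w = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return w @ v

out_lp = attn(q, k, v)                          # all intermediates in bfloat16
out_hp = attn(q.float(), k.float(), v.float())  # intermediates in float32
print((out_lp.float() - out_hp).abs().max())    # nonzero gap from low-precision accumulation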

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
Authored by Jianyu Huang on 2024-08-01 18:55:48 +00:00, committed by PyTorch MergeBot
parent 0eea2b3947
commit fbf3bc0a60
4 changed files with 61 additions and 36 deletions


@@ -73,8 +73,7 @@
 #endif
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
-namespace at {
-namespace native {
+namespace at::native {
 DEFINE_DISPATCH(_fused_sdp_choice_stub);
@@ -768,29 +767,57 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
         value.is_contiguous(),
         "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
   }
-  auto attn_mask = attn_mask_;
-  // Naive, composite implementation defined here.
+  auto origin_dtype = query_.scalar_type();
+  // Keep query, key, value in high precision for accuracy
+  // NestedTensor reports issues for backward with autograd so disabled: must be
+  // contiguous to get buffer.
+  auto query_acc = (query_.scalar_type() == at::kHalf ||
+                    query_.scalar_type() == at::kBFloat16) &&
+          !query_.is_nested()
+      ? query_.to(at::kFloat)
+      : query_;
+  auto key_acc =
+      (key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
+          !key.is_nested()
+      ? key.to(at::kFloat)
+      : key;
+  auto value_acc = (value.scalar_type() == at::kHalf ||
+                    value.scalar_type() == at::kBFloat16) &&
+          !value.is_nested()
+      ? value.to(at::kFloat)
+      : value;
+  auto attn_mask = attn_mask_;
+  // Naive, composite implementation defined here.
-  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
-  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
+  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
+  // math
+  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
+  const auto scaling_factor =
+      sdp::calculate_scale(
+          query_acc, is_negative_scaling ? std::abs(scale.value()) : scale)
+          .sqrt();
-  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
-  if (is_causal) {
-    TORCH_CHECK(!attn_mask.has_value(),
-        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
-        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
+  const auto query = query_acc *
+      (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor
+                           : scaling_factor);
+  if (is_causal) {
+    TORCH_CHECK(
+        !attn_mask.has_value(),
+        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
+    TORCH_CHECK(
+        !query.is_nested() && !key_acc.is_nested(),
+        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
-    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.sym_size(-2), S = key.sym_size(-2);
-    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
-    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
+    // Replace attn_mask with causal mask; lower triangular elements take part
+    // in attention.
+    const auto L = query.sym_size(-2), S = key_acc.sym_size(-2);
+    attn_mask =
+        at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
+    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
   }
   // MQA/GQA handling
-  auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key, value, enable_gqa);
+  auto [key_expanded, value_expanded] = pre_process_group_query_attention_input(query, key_acc, value_acc, enable_gqa);
   auto attn = at::matmul(query, key_expanded.transpose(-2, -1) * scaling_factor);
   if (attn_mask.has_value()) {
     if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
@@ -807,13 +834,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
       attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
       auto dropout_scaling = 1.0 / (1 - dropout_p);
-      return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling), attn);
+      return std::make_tuple(at::matmul(attn, value_expanded * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
     } else {
       attn = at::dropout(attn, dropout_p, true);
     }
   }
-  return std::make_tuple(at::matmul(attn, value_expanded), attn);
+  return std::make_tuple(at::matmul(attn, value_expanded).to(origin_dtype), attn.to(origin_dtype));
 }
 std::tuple<at::Tensor, at::Tensor>
@@ -1036,6 +1063,4 @@ Tensor triton_multi_head_attention(
 #endif
   return proj;
 }
-} // namespace native
-} // namespace at
+} // namespace at::native
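
Read as pseudocode, the updated math path is roughly equivalent to the following Python sketch (our simplification: no nested tensors, no dropout, no GQA expansion, and the helper name is ours, not PyTorch's):

import torch

def sdpa_math_sketch(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Rough Python mirror of _scaled_dot_product_attention_math after this change.
    origin_dtype = query.dtype
    if origin_dtype in (torch.float16, torch.bfloat16):
        query, key, value = query.float(), key.float(), value.float()

    # The scale is split as sqrt(scale), applied once to q and once to k for stability.
    scaling = (scale if scale is not None else query.size(-1) ** -0.5) ** 0.5
    query = query * scaling

    if is_causal:
        L, S = query.size(-2), key.size(-2)
        keep = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        attn_mask = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        attn_mask = attn_mask.masked_fill(~keep, float("-inf"))

    attn = query @ (key.transpose(-2, -1) * scaling)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)

    # Both outputs are cast back to the original input dtype.
    return (attn @ value).to(origin_dtype), attn.to(origin_dtype)

Note that the softmax and both matmuls now run in FP32 even for FP16/BF16 inputs; only the final output and attention weights are downcast.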


@@ -31,7 +31,6 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     is_iterable_of_tensors,
-    IS_MACOS,
     run_tests,
     skipIfCrossRef,
     skipIfTorchDynamo,
@@ -1175,7 +1174,7 @@ class DecompOneOffTests(TestCase):
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
-               dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
+               dtypes=[torch.half],
            ),
        ],
    )


@@ -2865,8 +2865,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 2.0 ,
-                'grad_query': 18.0 ,
+                'out': 3.0 ,
+                'grad_query': 150.0 ,
                 'grad_key': 25.0,
                 'grad_value': 8.5,
             }
@@ -2962,8 +2962,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                "out": 1.75,
-                "grad_query": 18.0,
+                "out": 4,
+                "grad_query": 150.0,
                 "grad_key": 25.0,
                 "grad_value": 8.0,
                 "grad_attn_mask": 45.0,
@@ -3072,10 +3072,10 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 1.5,
-                'grad_query': 13.0,
-                'grad_key': 2.0,
-                'grad_value': 1.75,
+                'out': 2.2,
+                'grad_query': 160.0,
+                'grad_key': 8.0,
+                'grad_value': 4,
             }
         )
@@ -3221,8 +3221,8 @@ class TestSDPACudaOnly(NNTestCase):
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
                 'out': 2.0,
-                'grad_query': 12.0,
-                'grad_key': 2.0,
+                'grad_query': 100.0,
+                'grad_key': 8.0,
                 'grad_value': 2.0,
             }
         )
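
These fudge factors widen because the low-precision math reference (out_lp_ref, grads_ref_lp) now accumulates in FP32 and lands much closer to the FP32 reference; the base tolerance derived from the gap between the two references therefore shrinks, so the multiplier on top of it must grow for the fused kernels to keep passing. Schematically (our sketch of the idea, not the actual test helper):

import torch

def assert_close_with_fudge(ref_fp32, ref_lp, actual, fudge):
    # Sketch of a fudge-factor check: the fused kernel may deviate from the
    # FP32 reference by at most `fudge` times the deviation of the
    # low-precision reference from that same FP32 reference.
    base = (ref_fp32 - ref_lp.to(torch.float32)).abs().max().item()
    err = (ref_fp32 - actual.to(torch.float32)).abs().max().item()
    # small epsilon guards the base == 0 case
    assert err <= fudge * base + 1e-7, f"{err} > {fudge} * {base}"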


@@ -5693,6 +5693,7 @@ scaled_dot_product_attention = _add_docstr(
     Due to the nature of fusing floating point operations, the output of this function may be different
     depending on what backend kernel is chosen.
     The c++ implementation supports torch.float64 and can be used when higher precision is required.
+    For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
     For more information please see :doc:`/notes/numerical_accuracy`
     Grouped Query Attention (GQA) is an experimental feature. It currently works only for Flash_attention
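
As a usage note (our example, not part of the patch; it assumes a CUDA device and the torch.nn.attention.sdpa_kernel context manager), this is how a caller would pin SDPA to the math backend and get the higher-precision reference behavior documented above:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

with sdpa_kernel(SDPBackend.MATH):
    # With this change, fp16/bf16 inputs are upcast to fp32 inside the math
    # backend, so this call can serve as a higher-precision reference.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.dtype)  # torch.float16 -- the result is cast back to the input dtype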