mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add a private _safe_softmax (#131060)
# Summary Changes the stance of SDPA on what to do for fully masked out rows ## Current Behavior Several PyTorch users have expressed frustration over this issue: - https://github.com/pytorch/pytorch/issues/41508 - https://github.com/pytorch/pytorch/issues/103749 - https://github.com/pytorch/pytorch/issues/103963 These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here: https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617 Can be paraphrased as follows: When passing in fully masked out rows, attention becomes ambiguous. We have two main options: 1. Uniformly attend to all values: ```python scores[masked_out_rows] = 1 / len(row) out[masked_out_rows] = 1 / len(row) * value ``` 2. Decide that attention between no queries (masked) and no keys (masked) is meaningless: ```python output[fully_masked_rows] = NaN ``` We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs: ``` Python >fill_value = -float("inf") >row0 = torch.randn(4) >row1 = torch.tensor([(fill_value for _ in range(4)]) >matrix = torch.stack([row0, row1]).requires_grad_(True) >out = torch.softmax(matrix, 1) >out = out[0] >print(out) tensor([0.5377, 0.2729, 0.0692, 0.1201]) ``` Cool, problem solved. But what happends when you call backwards.. ```Python >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08], [ nan, nan, nan, nan]]) ``` Those pesky NaNs are back! ## Why do we see NaNs today? The core of the problem revolves around using softmax function in sdpa: ```python > row = torch.tensor([(-float("inf")) for _ in range(4)]) > torch.softmax(row, 0) tensor([nan, nan, nan, nan]) ``` ## Quick Aside: Masking in Attention Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs. We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values. ## Alternative Approaches If we use a very large negative number instead of -inf: ```python > row = torch.tensor([(-1e6) for _ in range(4)]) > torch.softmax(row, 0) tensor([0.2500, 0.2500, 0.2500, 0.2500]) ``` However if users always remembered to "slice" out their outputs i.e.: ```Python >fill_value = -1e6 >... >out.backward(torch.ones_like(out)) >print(matrix.grad) tensor([[-0.0563, -0.0564, 0.1613, -0.0486], [ 0.0000, 0.0000, 0.0000, 0.0000]]) ``` This would bring us back into a better state. ## A Third Option We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation. This PR implements the new semantic for masking w/ attention in fully masked-out rows: ```python out[masked_out_rows] = 0 ``` **Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption. ## Details This PR stack does 3 things: 1. Adds a PRIVATE _safe_softmax op 2. Updates semantic for flash_cpu fused kernel 3. Updates semantic for efficient_cuda fused kernel _safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num. Why I think this is okay? (please find a counter point if avail) There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them? The only case that this can happen is if the input itself had a NaN or an Inf For example: ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = torch.finfo(torch.float16).max print(a.softmax(-1)) ``` Will return `tensor([0., 1., 0., 0.], dtype=torch.float16)` Where ```Python a = torch.ones([4], requires_grad=False, dtype=torch.float16) a[1] = float("inf") a.softmax(-1) ``` returns: `tensor([nan, nan, nan, nan], dtype=torch.float16)` If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this ```Python max = torch.max(a, dim=-1, keepdim=True) exp = torch.exp(a - max.values) denom = torch.sum(exp, dim=-1, keepdim=True) softmax = exp / denom softmax = torch.where(max.values == float('-inf'), 0.0, softmax) ``` however we would be paying for this in math performance. ## Why Now I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060 Approved by: https://github.com/jbschlosser
This commit is contained in:
committed by
PyTorch MergeBot
parent
1f66487c69
commit
1434e0b121
@ -14724,6 +14724,11 @@
|
||||
NestedTensorCUDA: NestedTensor_softmax_dropout_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _safe_softmax
|
||||
NestedTensorCPU, NestedTensorCUDA: _safe_softmax
|
||||
|
||||
# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
|
||||
- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
|
||||
variants: function
|
||||
|
@ -14,7 +14,6 @@
|
||||
#include <ATen/native/nested/NestedTensorUtils.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
@ -30,6 +29,7 @@ Tensor& NestedTensor_abs_(Tensor& self) {
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
Tensor NestedTensor_sgn(const Tensor& self) {
|
||||
return map_nt(self, at::sgn);
|
||||
}
|
||||
|
@ -19,7 +19,6 @@
|
||||
#include <c10/core/DispatchKey.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
|
||||
#include <type_traits>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
@ -70,6 +69,9 @@
|
||||
#include <ATen/ops/where.h>
|
||||
#include <ATen/ops/zeros.h>
|
||||
#include <ATen/ops/zeros_like.h>
|
||||
#include <ATen/ops/_safe_softmax.h>
|
||||
#include <ATen/ops/_safe_softmax_native.h>
|
||||
#include <ATen/ops/all.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
@ -529,7 +531,6 @@ std::optional<Tensor> convert_boolean_attn_mask(const std::optional<Tensor>& att
|
||||
// Convert boolean mask to additive mask; need to invert mask to indicate what
|
||||
// to mask *out*.
|
||||
if (attn_mask->dtype() == at::kBool) {
|
||||
// TODO Use the max type of the input and output
|
||||
return at::where(attn_mask->logical_not(), -std::numeric_limits<double>::infinity(), at::scalar_tensor(0.0, at::TensorOptions().dtype(dtype).device(attn_mask->device())));
|
||||
}
|
||||
// Otherwise, attn_mask represents an additive attention tensor
|
||||
@ -641,6 +642,15 @@ std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor _safe_softmax(
|
||||
const Tensor& self,
|
||||
int64_t dim,
|
||||
std::optional<ScalarType> dtype) {
|
||||
auto out = at::softmax(self, dim, dtype);
|
||||
const auto masked = self.eq(-std::numeric_limits<float>::infinity());
|
||||
const auto masked_rows = all(masked, dim, true);
|
||||
return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
|
||||
}
|
||||
// Computes scaled dot product attention on query, key and value tensors, using
|
||||
// an optional attention mask if passed, and applying dropout if a probability
|
||||
// greater than 0.0 is specified.
|
||||
|
@ -1791,6 +1791,9 @@ class TestOperators(TestCase):
|
||||
), # NYI: forward-AD for soft_margin_loss_backward
|
||||
xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
|
||||
xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
|
||||
xfail(
|
||||
"torch.ops.aten._safe_softmax.default"
|
||||
), # NYI: forward-AD for _safe_softmax
|
||||
skip("nn.functional.scaled_dot_product_attention"),
|
||||
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
||||
xfail(
|
||||
@ -1973,6 +1976,9 @@ class TestOperators(TestCase):
|
||||
xfail(
|
||||
"nn.functional.ctc_loss"
|
||||
), # ForwardAD not implemented and no decomposition
|
||||
xfail(
|
||||
"torch.ops.aten._safe_softmax.default"
|
||||
), # ForwardAD not implemented
|
||||
xfail("nn.functional.dropout2d"), # calls random op
|
||||
xfail("nn.functional.dropout3d"), # calls random op
|
||||
xfail("nn.functional.dropout"), # calls random op
|
||||
|
@ -374,6 +374,7 @@ inductor_override_kwargs = {
|
||||
"rtol": 0.02,
|
||||
},
|
||||
("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
|
||||
("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02},
|
||||
("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02},
|
||||
("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
|
||||
("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
|
||||
|
@ -223,6 +223,7 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
(torch.float16, torch.ops.aten.mv.default): 1e-5,
|
||||
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
|
||||
(torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
|
||||
(torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
|
||||
}
|
||||
if ref.is_floating_point():
|
||||
orig_diff = (orig - ref).abs().max()
|
||||
|
@ -2854,7 +2854,10 @@
|
||||
- name: _nested_get_values(Tensor(a) self) -> Tensor(a)
|
||||
self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
|
||||
|
||||
# Transformers
|
||||
# Transformer
|
||||
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
|
||||
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
|
||||
|
||||
- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
|
||||
output_differentiability: [True, False, False, False]
|
||||
query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
|
||||
|
@ -408,6 +408,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
||||
aten.rrelu_with_noise,
|
||||
aten.rrelu_with_noise_,
|
||||
aten.rsub,
|
||||
aten._safe_softmax,
|
||||
aten._scaled_dot_product_flash_attention_for_cpu.default,
|
||||
aten.select_backward,
|
||||
aten.select_scatter,
|
||||
|
@ -421,6 +421,15 @@ def mse_loss_backward(
|
||||
return norm * (input - target) * grad_output
|
||||
|
||||
|
||||
@register_decomposition(aten._safe_softmax)
|
||||
def safe_softmax(self, dim, dtype=None):
|
||||
out = torch.softmax(self, dim=dim, dtype=dtype)
|
||||
masked = self.eq(float("-inf"))
|
||||
masked_rows = torch.all(masked, dim=dim, keepdim=True)
|
||||
zeros = torch.zeros_like(out)
|
||||
return torch.where(masked_rows, zeros, out)
|
||||
|
||||
|
||||
@register_decomposition(aten.smooth_l1_loss)
|
||||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
|
@ -457,10 +457,11 @@ def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
|
||||
|
||||
|
||||
@register_op_strategy(
|
||||
[aten._log_softmax.default, aten._softmax.default], schema_info=RuntimeSchemaInfo(1)
|
||||
[aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
|
||||
schema_info=RuntimeSchemaInfo(1),
|
||||
)
|
||||
def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
|
||||
input_strategy, softmax_dim, _ = op_schema.args_schema
|
||||
input_strategy, softmax_dim, *_ = op_schema.args_schema
|
||||
input_strategy = cast(OpStrategy, input_strategy)
|
||||
softmax_dim = cast(int, softmax_dim)
|
||||
softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)
|
||||
|
@ -711,7 +711,6 @@ def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
|
||||
yield SampleInput(lhs, args=(lhs.clone().detach_(),))
|
||||
|
||||
|
||||
|
||||
def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
||||
@ -4408,6 +4407,70 @@ def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
# Test case for no optional kwargs
|
||||
yield SampleInput(make_arg((1, 2, 3)), kwargs={})
|
||||
|
||||
def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
|
||||
|
||||
def make_bool_mask(*shape):
|
||||
return torch.randint(0, 2, shape, device=device, dtype=torch.bool)
|
||||
|
||||
def mask_two_rows(rows, cols):
|
||||
mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device)
|
||||
mask_two_rows[rows - 1] = False
|
||||
mask_two_rows[rows - 3] = False
|
||||
return mask_two_rows
|
||||
|
||||
def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor:
|
||||
return torch.where(~mask, float('-inf'), 0.0)
|
||||
|
||||
def with_requires_grad(tensor):
|
||||
return tensor.requires_grad_(requires_grad)
|
||||
|
||||
def generate_input_from_mask(mask_shape, dim):
|
||||
mask = make_bool_mask(*mask_shape)
|
||||
input_tensor = make_arg(mask_shape)
|
||||
masked_input = input_tensor + convert_to_float_mask(mask)
|
||||
return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim})
|
||||
|
||||
samples = [
|
||||
# Basic 3D tensor with mask
|
||||
generate_input_from_mask((2, 3, 4), dim=1),
|
||||
# 2D tensor with mask, testing different dim
|
||||
generate_input_from_mask((5, 5), dim=0),
|
||||
# 4D tensor, testing with a different dim
|
||||
generate_input_from_mask((2, 3, 4, 5), dim=2),
|
||||
# Edge case: 1D tensor
|
||||
generate_input_from_mask((10,), dim=0),
|
||||
# Edge case: tensor with one dimension of size 1
|
||||
generate_input_from_mask((1, 5, 5), dim=1),
|
||||
# Testing with all elements masked
|
||||
SampleInput(
|
||||
with_requires_grad(
|
||||
make_arg((3, 3))
|
||||
+ convert_to_float_mask(
|
||||
torch.zeros((3, 3), dtype=torch.bool, device=device)
|
||||
)
|
||||
),
|
||||
kwargs={"dim": 1},
|
||||
),
|
||||
# Testing with no elements masked
|
||||
SampleInput(
|
||||
with_requires_grad(
|
||||
make_arg((3, 3))
|
||||
+ convert_to_float_mask(
|
||||
torch.ones((3, 3), dtype=torch.bool, device=device)
|
||||
)
|
||||
),
|
||||
kwargs={"dim": 1},
|
||||
),
|
||||
# Testing with two rows masked
|
||||
SampleInput(
|
||||
with_requires_grad(
|
||||
make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3))
|
||||
),
|
||||
kwargs={"dim": 1},
|
||||
),
|
||||
]
|
||||
yield from samples
|
||||
|
||||
def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
@ -16131,6 +16194,24 @@ op_db: List[OpInfo] = [
|
||||
dtypes=(torch.float8_e4m3fn,)),
|
||||
)
|
||||
),
|
||||
OpInfo(
|
||||
'torch.ops.aten._safe_softmax.default',
|
||||
dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
|
||||
sample_inputs_func=sample_inputs_safe_softmax,
|
||||
assert_jit_shape_analysis=True,
|
||||
assert_autodiffed=True,
|
||||
supports_forward_ad=False,
|
||||
supports_fwgrad_bwgrad=False,
|
||||
supports_out=False,
|
||||
supports_cow_input_no_materialize_backward=False,
|
||||
decorators=[],
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
|
||||
DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
|
||||
),
|
||||
),
|
||||
OpInfo(
|
||||
'nn.functional.scaled_dot_product_attention',
|
||||
op=lambda *args, **kwargs:
|
||||
|
Reference in New Issue
Block a user