Add a private _safe_softmax (#131060)

# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617

Can be paraphrased as follows:

When passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around using softmax function in sdpa:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel

_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.

Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?

The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this

```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.

## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
This commit is contained in:
drisspg
2024-08-06 13:01:38 -07:00
committed by PyTorch MergeBot
parent 1f66487c69
commit 1434e0b121
11 changed files with 127 additions and 9 deletions

View File

@ -14724,6 +14724,11 @@
NestedTensorCUDA: NestedTensor_softmax_dropout_cuda
tags: nondeterministic_seeded
- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
dispatch:
CompositeExplicitAutograd: _safe_softmax
NestedTensorCPU, NestedTensorCUDA: _safe_softmax
# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
variants: function

View File

@ -14,7 +14,6 @@
#include <ATen/native/nested/NestedTensorUtils.h>
#include <tuple>
namespace at {
namespace native {
@ -30,6 +29,7 @@ Tensor& NestedTensor_abs_(Tensor& self) {
return self;
}
Tensor NestedTensor_sgn(const Tensor& self) {
return map_nt(self, at::sgn);
}

View File

@ -19,7 +19,6 @@
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <type_traits>
#include <limits>
#include <utility>
@ -70,6 +69,9 @@
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/_safe_softmax.h>
#include <ATen/ops/_safe_softmax_native.h>
#include <ATen/ops/all.h>
#endif
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
@ -529,7 +531,6 @@ std::optional<Tensor> convert_boolean_attn_mask(const std::optional<Tensor>& att
// Convert boolean mask to additive mask; need to invert mask to indicate what
// to mask *out*.
if (attn_mask->dtype() == at::kBool) {
// TODO Use the max type of the input and output
return at::where(attn_mask->logical_not(), -std::numeric_limits<double>::infinity(), at::scalar_tensor(0.0, at::TensorOptions().dtype(dtype).device(attn_mask->device())));
}
// Otherwise, attn_mask represents an additive attention tensor
@ -641,6 +642,15 @@ std::tuple<at::Tensor, at::Tensor> pre_process_group_query_attention_input(
} // namespace
Tensor _safe_softmax(
const Tensor& self,
int64_t dim,
std::optional<ScalarType> dtype) {
auto out = at::softmax(self, dim, dtype);
const auto masked = self.eq(-std::numeric_limits<float>::infinity());
const auto masked_rows = all(masked, dim, true);
return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
}
// Computes scaled dot product attention on query, key and value tensors, using
// an optional attention mask if passed, and applying dropout if a probability
// greater than 0.0 is specified.

View File

@ -1791,6 +1791,9 @@ class TestOperators(TestCase):
), # NYI: forward-AD for soft_margin_loss_backward
xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
xfail(
"torch.ops.aten._safe_softmax.default"
), # NYI: forward-AD for _safe_softmax
skip("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail(
@ -1973,6 +1976,9 @@ class TestOperators(TestCase):
xfail(
"nn.functional.ctc_loss"
), # ForwardAD not implemented and no decomposition
xfail(
"torch.ops.aten._safe_softmax.default"
), # ForwardAD not implemented
xfail("nn.functional.dropout2d"), # calls random op
xfail("nn.functional.dropout3d"), # calls random op
xfail("nn.functional.dropout"), # calls random op

View File

@ -374,6 +374,7 @@ inductor_override_kwargs = {
"rtol": 0.02,
},
("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02},
("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02},
("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},

View File

@ -223,6 +223,7 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
(torch.float16, torch.ops.aten.mv.default): 1e-5,
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
(torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
(torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
}
if ref.is_floating_point():
orig_diff = (orig - ref).abs().max()

View File

@ -2854,7 +2854,10 @@
- name: _nested_get_values(Tensor(a) self) -> Tensor(a)
self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional<Tensor>(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
# Transformers
# Transformer
- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
self: _softmax_backward_data(grad, result, dim, self.scalar_type())
- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
output_differentiability: [True, False, False, False]
query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)

View File

@ -408,6 +408,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.rrelu_with_noise,
aten.rrelu_with_noise_,
aten.rsub,
aten._safe_softmax,
aten._scaled_dot_product_flash_attention_for_cpu.default,
aten.select_backward,
aten.select_scatter,

View File

@ -421,6 +421,15 @@ def mse_loss_backward(
return norm * (input - target) * grad_output
@register_decomposition(aten._safe_softmax)
def safe_softmax(self, dim, dtype=None):
out = torch.softmax(self, dim=dim, dtype=dtype)
masked = self.eq(float("-inf"))
masked_rows = torch.all(masked, dim=dim, keepdim=True)
zeros = torch.zeros_like(out)
return torch.where(masked_rows, zeros, out)
@register_decomposition(aten.smooth_l1_loss)
@out_wrapper()
@pw_cast_for_opmath

View File

@ -457,10 +457,11 @@ def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
@register_op_strategy(
[aten._log_softmax.default, aten._softmax.default], schema_info=RuntimeSchemaInfo(1)
[aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
schema_info=RuntimeSchemaInfo(1),
)
def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
input_strategy, softmax_dim, _ = op_schema.args_schema
input_strategy, softmax_dim, *_ = op_schema.args_schema
input_strategy = cast(OpStrategy, input_strategy)
softmax_dim = cast(int, softmax_dim)
softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)

View File

@ -711,7 +711,6 @@ def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
yield SampleInput(lhs, args=(lhs.clone().detach_(),))
def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -4408,6 +4407,70 @@ def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
# Test case for no optional kwargs
yield SampleInput(make_arg((1, 2, 3)), kwargs={})
def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
def make_bool_mask(*shape):
return torch.randint(0, 2, shape, device=device, dtype=torch.bool)
def mask_two_rows(rows, cols):
mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device)
mask_two_rows[rows - 1] = False
mask_two_rows[rows - 3] = False
return mask_two_rows
def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor:
return torch.where(~mask, float('-inf'), 0.0)
def with_requires_grad(tensor):
return tensor.requires_grad_(requires_grad)
def generate_input_from_mask(mask_shape, dim):
mask = make_bool_mask(*mask_shape)
input_tensor = make_arg(mask_shape)
masked_input = input_tensor + convert_to_float_mask(mask)
return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim})
samples = [
# Basic 3D tensor with mask
generate_input_from_mask((2, 3, 4), dim=1),
# 2D tensor with mask, testing different dim
generate_input_from_mask((5, 5), dim=0),
# 4D tensor, testing with a different dim
generate_input_from_mask((2, 3, 4, 5), dim=2),
# Edge case: 1D tensor
generate_input_from_mask((10,), dim=0),
# Edge case: tensor with one dimension of size 1
generate_input_from_mask((1, 5, 5), dim=1),
# Testing with all elements masked
SampleInput(
with_requires_grad(
make_arg((3, 3))
+ convert_to_float_mask(
torch.zeros((3, 3), dtype=torch.bool, device=device)
)
),
kwargs={"dim": 1},
),
# Testing with no elements masked
SampleInput(
with_requires_grad(
make_arg((3, 3))
+ convert_to_float_mask(
torch.ones((3, 3), dtype=torch.bool, device=device)
)
),
kwargs={"dim": 1},
),
# Testing with two rows masked
SampleInput(
with_requires_grad(
make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3))
),
kwargs={"dim": 1},
),
]
yield from samples
def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@ -16131,6 +16194,24 @@ op_db: List[OpInfo] = [
dtypes=(torch.float8_e4m3fn,)),
)
),
OpInfo(
'torch.ops.aten._safe_softmax.default',
dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
sample_inputs_func=sample_inputs_safe_softmax,
assert_jit_shape_analysis=True,
assert_autodiffed=True,
supports_forward_ad=False,
supports_fwgrad_bwgrad=False,
supports_out=False,
supports_cow_input_no_materialize_backward=False,
decorators=[],
skips=(
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
),
),
OpInfo(
'nn.functional.scaled_dot_product_attention',
op=lambda *args, **kwargs: