diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index afe3b3814ea8..5ee8617fee4f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14724,6 +14724,11 @@
     NestedTensorCUDA: NestedTensor_softmax_dropout_cuda
   tags: nondeterministic_seeded
 
+- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _safe_softmax
+    NestedTensorCPU, NestedTensorCUDA: _safe_softmax
+
 # Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
 - func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
   variants: function
diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
index dc31b2c0de24..df8d7c193e2c 100644
--- a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
@@ -14,7 +14,6 @@
 #include 
 #include 
 
-
 namespace at {
 namespace native {
 
@@ -30,6 +29,7 @@ Tensor& NestedTensor_abs_(Tensor& self) {
   return self;
 }
 
+
 Tensor NestedTensor_sgn(const Tensor& self) {
   return map_nt(self, at::sgn);
 }
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 74d203773389..5eb636a17f11 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -19,7 +19,6 @@
 #include 
 #include 
-#include 
 #include 
 #include 
 
@@ -70,6 +69,9 @@
 #include 
 #include 
 #include 
+#include 
+#include 
+#include 
 #endif
 
 #include 
@@ -529,7 +531,6 @@ std::optional convert_boolean_attn_mask(const std::optional& att
   // Convert boolean mask to additive mask; need to invert mask to indicate what
   // to mask *out*.
   if (attn_mask->dtype() == at::kBool) {
-    // TODO Use the max type of the input and output
     return at::where(attn_mask->logical_not(), -std::numeric_limits::infinity(), at::scalar_tensor(0.0, at::TensorOptions().dtype(dtype).device(attn_mask->device())));
   }
   // Otherwise, attn_mask represents an additive attention tensor
@@ -641,6 +642,15 @@ std::tuple pre_process_group_query_attention_input(
 
 } // namespace
 
+Tensor _safe_softmax(
+    const Tensor& self,
+    int64_t dim,
+    std::optional dtype) {
+  auto out = at::softmax(self, dim, dtype);
+  const auto masked = self.eq(-std::numeric_limits::infinity());
+  const auto masked_rows = all(masked, dim, true);
+  return at::where(masked_rows, at::scalar_tensor(0.0, at::TensorOptions().dtype(out.dtype()).device(out.device())), out);
+}
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.
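For reference, _safe_softmax follows the same contract as at::softmax except that rows which are fully masked (every entry along dim equal to -inf) come back as zeros instead of NaN, which is what makes it safe to apply to attention score rows that are masked out entirely. Below is a minimal Python sketch of those semantics, mirroring the CompositeExplicitAutograd kernel above and the decomposition added in torch/_decomp/decompositions.py later in this patch; the helper name safe_softmax_reference and the example values are illustrative only and are not part of the patch.

import torch

def safe_softmax_reference(x: torch.Tensor, dim: int, dtype=None) -> torch.Tensor:
    # Plain softmax; a row that is all -inf would come out as NaN here.
    out = torch.softmax(x, dim=dim, dtype=dtype)
    # Find rows where every entry along `dim` is -inf (fully masked rows).
    masked_rows = (x == float("-inf")).all(dim=dim, keepdim=True)
    # Replace those NaN rows with zeros; all other rows are returned unchanged.
    return torch.where(masked_rows, torch.zeros_like(out), out)

# Example: the second row is fully masked and yields zeros rather than NaN.
scores = torch.tensor([[0.0, 1.0, 2.0],
                       [float("-inf"), float("-inf"), float("-inf")]])
print(safe_softmax_reference(scores, dim=-1))
# tensor([[0.0900, 0.2447, 0.6652],
#         [0.0000, 0.0000, 0.0000]])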
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 03744b7a8ef8..64c3a1706b5b 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1791,6 +1791,9 @@ class TestOperators(TestCase):
         ),  # NYI: forward-AD for soft_margin_loss_backward
         xfail("nn.functional.ctc_loss", ""),  # NYI: forward-AD for _ctc_loss
         xfail("nn.functional.pdist", ""),  # NYI: forward-AD with _pdist_forward
+        xfail(
+            "torch.ops.aten._safe_softmax.default"
+        ),  # NYI: forward-AD for _safe_softmax
         skip("nn.functional.scaled_dot_product_attention"),
         xfail("torch.ops.aten._efficient_attention_forward"),  # outputs ints
         xfail(
@@ -1973,6 +1976,9 @@ class TestOperators(TestCase):
         xfail(
             "nn.functional.ctc_loss"
         ),  # ForwardAD not implemented and no decomposition
+        xfail(
+            "torch.ops.aten._safe_softmax.default"
+        ),  # ForwardAD not implemented
         xfail("nn.functional.dropout2d"),  # calls random op
         xfail("nn.functional.dropout3d"),  # calls random op
         xfail("nn.functional.dropout"),  # calls random op
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index ef72716cecb5..25cfadb37cbf 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -374,6 +374,7 @@ inductor_override_kwargs = {
         "rtol": 0.02,
     },
     ("sinc", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
+    ("torch.ops.aten._safe_softmax.default", "cuda", f16): {"atol": 5e-4, "rtol": 0.02},
     ("softmax", "cpu", f16): {"atol": 1e-4, "rtol": 0.02},
     ("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
     ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 3c90c0f16b70..b3ccf2985169 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -223,6 +223,7 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
         (torch.float16, torch.ops.aten.mv.default): 1e-5,
         (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
         (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
+        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
     }
     if ref.is_floating_point():
         orig_diff = (orig - ref).abs().max()
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index e68758e95374..cbebe3ccd47d 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2854,7 +2854,10 @@
 - name: _nested_get_values(Tensor(a) self) -> Tensor(a)
   self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? c10::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? c10::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
 
-# Transformers
+# Transformer
+- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+
 - name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   output_differentiability: [True, False, False, False]
   query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index fa516194cbad..816500704c84 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -408,6 +408,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
         aten.rrelu_with_noise,
         aten.rrelu_with_noise_,
         aten.rsub,
+        aten._safe_softmax,
         aten._scaled_dot_product_flash_attention_for_cpu.default,
         aten.select_backward,
         aten.select_scatter,
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index c3f2ff203cf6..fde4b1cbc414 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -421,6 +421,15 @@ def mse_loss_backward(
     return norm * (input - target) * grad_output
 
 
+@register_decomposition(aten._safe_softmax)
+def safe_softmax(self, dim, dtype=None):
+    out = torch.softmax(self, dim=dim, dtype=dtype)
+    masked = self.eq(float("-inf"))
+    masked_rows = torch.all(masked, dim=dim, keepdim=True)
+    zeros = torch.zeros_like(out)
+    return torch.where(masked_rows, zeros, out)
+
+
 @register_decomposition(aten.smooth_l1_loss)
 @out_wrapper()
 @pw_cast_for_opmath
@@ -1576,7 +1585,7 @@ def native_group_norm_backward(
     utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
     torch._check(
         input.numel() == N * C * HxW,
-        lambda: f"Expect input to have { N * C * HxW} elements",
+        lambda: f"Expect input to have {N * C * HxW} elements",
     )
     torch._check(
         mean.shape == (N, group),
diff --git a/torch/distributed/_tensor/ops/_math_ops.py b/torch/distributed/_tensor/ops/_math_ops.py
index 9a4f0f1c3341..43ad901a8b59 100644
--- a/torch/distributed/_tensor/ops/_math_ops.py
+++ b/torch/distributed/_tensor/ops/_math_ops.py
@@ -457,10 +457,11 @@ def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
 
 
 @register_op_strategy(
-    [aten._log_softmax.default, aten._softmax.default], schema_info=RuntimeSchemaInfo(1)
+    [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default],
+    schema_info=RuntimeSchemaInfo(1),
 )
 def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
-    input_strategy, softmax_dim, _ = op_schema.args_schema
+    input_strategy, softmax_dim, *_ = op_schema.args_schema
     input_strategy = cast(OpStrategy, input_strategy)
     softmax_dim = cast(int, softmax_dim)
     softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c0323c54cc75..9696fd66156c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -711,7 +711,6 @@ def sample_inputs_equal(op, device, dtype, requires_grad, **kwargs):
         yield SampleInput(lhs, args=(lhs.clone().detach_(),))
 
 
-
 def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -4408,6 +4407,70 @@ def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs):
     # Test case for no optional kwargs
     yield SampleInput(make_arg((1, 2, 3)), kwargs={})
 
+def sample_inputs_safe_softmax(opinfo, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
+
+    def make_bool_mask(*shape):
+        return torch.randint(0, 2, shape, device=device, dtype=torch.bool)
+
+    def mask_two_rows(rows, cols):
+        mask_two_rows = torch.ones((rows, cols), dtype=torch.bool, device=device)
+        mask_two_rows[rows - 1] = False
+        mask_two_rows[rows - 3] = False
+        return mask_two_rows
+
+    def convert_to_float_mask(mask: torch.Tensor) -> torch.Tensor:
+        return torch.where(~mask, float('-inf'), 0.0)
+
+    def with_requires_grad(tensor):
+        return tensor.requires_grad_(requires_grad)
+
+    def generate_input_from_mask(mask_shape, dim):
+        mask = make_bool_mask(*mask_shape)
+        input_tensor = make_arg(mask_shape)
+        masked_input = input_tensor + convert_to_float_mask(mask)
+        return SampleInput(with_requires_grad(masked_input), kwargs={'dim': dim})
+
+    samples = [
+        # Basic 3D tensor with mask
+        generate_input_from_mask((2, 3, 4), dim=1),
+        # 2D tensor with mask, testing different dim
+        generate_input_from_mask((5, 5), dim=0),
+        # 4D tensor, testing with a different dim
+        generate_input_from_mask((2, 3, 4, 5), dim=2),
+        # Edge case: 1D tensor
+        generate_input_from_mask((10,), dim=0),
+        # Edge case: tensor with one dimension of size 1
+        generate_input_from_mask((1, 5, 5), dim=1),
+        # Testing with all elements masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((3, 3))
+                + convert_to_float_mask(
+                    torch.zeros((3, 3), dtype=torch.bool, device=device)
+                )
+            ),
+            kwargs={"dim": 1},
+        ),
+        # Testing with no elements masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((3, 3))
+                + convert_to_float_mask(
+                    torch.ones((3, 3), dtype=torch.bool, device=device)
+                )
+            ),
+            kwargs={"dim": 1},
+        ),
+        # Testing with two rows masked
+        SampleInput(
+            with_requires_grad(
+                make_arg((6, 3)) + convert_to_float_mask(mask_two_rows(6, 3))
+            ),
+            kwargs={"dim": 1},
+        ),
+    ]
+    yield from samples
 
 def sample_inputs_layer_norm(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -11666,7 +11729,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
     in_rank = len(in_shape)
     for d in start_dim, end_dim:
         if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank):
-            raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank-1}], but got {d}")
+            raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank - 1}], but got {d}")
     end_dim = end_dim if end_dim >= 0 else in_rank + end_dim
     start_dim = start_dim if start_dim >= 0 else in_rank + start_dim
     if in_rank == 0:
@@ -16131,6 +16194,24 @@
                     dtypes=(torch.float8_e4m3fn,)),
         )
     ),
+    OpInfo(
+        'torch.ops.aten._safe_softmax.default',
+        dtypes=all_types_and(torch.half, torch.bfloat16, torch.bool),
+        sample_inputs_func=sample_inputs_safe_softmax,
+        assert_jit_shape_analysis=True,
+        assert_autodiffed=True,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=False,
+        supports_out=False,
+        supports_cow_input_no_materialize_backward=False,
+        decorators=[],
+        skips=(
+            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
+            DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
+            DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
+            DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
+        ),
+    ),
     OpInfo(
         'nn.functional.scaled_dot_product_attention',
         op=lambda *args, **kwargs: