From 938f37b745dce8fcfc4602702a96ca65cc7f0825 Mon Sep 17 00:00:00 2001 From: chilli Date: Wed, 21 Aug 2024 10:41:38 -0700 Subject: [PATCH] Added batching rule for sdpa_math, sdpa_efficient_attention forward, cudnn, and flash attention (#133964) Fixes https://github.com/pytorch/pytorch/issues/117016, https://github.com/pytorch/pytorch/issues/102457, https://github.com/pytorch/pytorch/issues/110525, https://github.com/pytorch/pytorch/issues/108065, Pull Request resolved: https://github.com/pytorch/pytorch/pull/133964 Approved by: https://github.com/Skylion007 --- .../ATen/functorch/BatchRulesConvolution.cpp | 75 --------- .../functorch/BatchRulesDecompositions.cpp | 1 + .../functorch/BatchRulesLinearAlgebra.cpp | 143 ++++++++++++++++++ .../ATen/functorch/BatchRulesReduceOps.cpp | 1 + test/functorch/test_vmap.py | 75 ++++++++- .../_internal/common_methods_invocations.py | 3 - torchgen/gen_vmap_plumbing.py | 9 +- 7 files changed, 227 insertions(+), 80 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp index c3db153cdcd3..3cf00f33def5 100644 --- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp +++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp @@ -138,81 +138,6 @@ static Tensor _convolution_decomp( input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_); } -// TODO: delete the following after confirming performance -// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) { -// if (bdim == 0) { -// return value.size(1) == 1; -// } -// return value.size(0) == 1; -// } -// -// std::tuple cudnn_conv_per_sample_grad_rule( -// const Tensor& self, std::optional self_bdim, -// const Tensor& grad_output, std::optional grad_output_bdim, -// const Tensor& weight, std::optional weight_bdim, -// IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, -// bool deterministic, bool allow_tf32, std::array output_mask) { -// TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim); -// // TODO: No clue if this works if the first non-batch dim isn't size 1 -// TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim)); -// TORCH_INTERNAL_ASSERT(self.dim() == 5); -// -// auto bdim_size = self.size(*self_bdim); -// auto self_ = reshape_dim_into(*self_bdim, 0, self); -// auto in_channels = self_.size(1); -// auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output); -// -// auto grad_self = at::cudnn_convolution_backward_input( -// self_.sizes(), grad_output_, weight, -// padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); -// grad_self = reshape_dim_outof(0, bdim_size, grad_self); -// -// // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py -// auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride); -// auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)}); -// auto grad_sample = at::einsum("noq,npq->nop", {B, A}); -// grad_sample = grad_sample.view({ -// bdim_size, groups, -1, groups, in_channels / groups, -// weight.size(2) * weight.size(3) }); -// grad_sample = at::einsum("ngrg...->ngr...", {grad_sample}); -// grad_sample = grad_sample.reshape( -// {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)}); -// -// return std::make_tuple(grad_self, 0, grad_sample, 0); -// } -// -// std::tuple cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, 
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
-//   auto maybe_layer = maybeCurrentDynamicLayer();
-//   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
-//   int64_t cur_level = maybe_layer->layerId();
-//
-//   Tensor self_value;
-//   std::optional<int64_t> self_bdim;
-//   std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
-//   Tensor grad_output_value;
-//   std::optional<int64_t> grad_output_bdim;
-//   std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
-//   Tensor weight_value;
-//   std::optional<int64_t> weight_bdim;
-//   std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
-//
-//   if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) {
-//     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-//     auto result = cudnn_conv_per_sample_grad_rule(
-//         self_value, self_bdim,
-//         grad_output_value, grad_output_bdim,
-//         weight_value, weight_bdim,
-//         padding, stride, dilation, groups,
-//         benchmark, deterministic, allow_tf32, output_mask);
-//     return std::make_tuple(
-//         makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
-//         makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
-//   }
-//
-//   static auto op = c10::Dispatcher::singleton()
-//     .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
-//   return slow_fallback(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
-
 static Tensor compute_grad_bias(
     const Tensor& grad_output_, std::array<bool, 3> output_mask) {
   if (!output_mask[2]) {
diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index 09b0d253ec4f..5739e88d5ddc 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -261,6 +261,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
   OP_DECOMPOSE(special_xlogy);
   OP_DECOMPOSE2(special_xlogy, other_scalar);
   OP_DECOMPOSE2(special_xlogy, self_scalar);
+  OP_DECOMPOSE(_scaled_dot_product_attention_math);
 
   m.impl("split.sizes", native::split_symint);
diff --git a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
index 8d0203da704e..6047a6eddb65 100644
--- a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
+++ b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
@@ -485,6 +485,145 @@ pinv_batch_rule(
     const std::optional<int64_t> rtol_bdim, bool hermitian) {
   return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv");
 }
+
+std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
+_scaled_dot_product_flash_attention_batch_rule(
+    const Tensor& query, std::optional<int64_t> query_bdim,
+    const Tensor& key, std::optional<int64_t> key_bdim,
+    const Tensor& value, std::optional<int64_t> value_bdim,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<double> scale
+) {
+  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
+  auto query_ = moveBatchDimToFront(query, query_bdim);
+  auto key_ = moveBatchDimToFront(key, key_bdim);
+  auto value_ = moveBatchDimToFront(value, value_bdim);
+  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
+  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
+  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
+  query_ = query_.flatten(0, 1);
+  key_ = key_.flatten(0, 1);
+  value_ = value_.flatten(0, 1);
+
+  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_flash_attention(
+      query_, key_, value_, dropout_p, is_causal, return_debug_mask, scale);
+
+  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
+  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
+  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
+  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
+  // res6 and res7 (philox seed and offset) are always non-batched
+  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;
+
+  return std::make_tuple(
+      res0_, 0,
+      res1_, 0,
+      res2, std::nullopt,
+      res3, std::nullopt,
+      res4,
+      res5,
+      res6, std::nullopt,
+      res7, std::nullopt,
+      res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
+  );
+}
+
+fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
+    const Tensor& query, optional<int64_t> query_bdim,
+    const Tensor& key, optional<int64_t> key_bdim,
+    const Tensor& value, optional<int64_t> value_bdim,
+    const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim,
+    bool compute_log_sumexp,
+    double dropout_p,
+    bool is_causal,
+    c10::optional<double> scale
+) {
+  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
+  auto query_ = moveBatchDimToFront(query, query_bdim);
+  auto key_ = moveBatchDimToFront(key, key_bdim);
+  auto value_ = moveBatchDimToFront(value, value_bdim);
+  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
+  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
+  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
+
+  query_ = query_.flatten(0, 1);
+  key_ = key_.flatten(0, 1);
+  value_ = value_.flatten(0, 1);
+
+  std::optional<Tensor> attn_bias_;
+  if (attn_bias.has_value() && attn_bias->defined()) {
+    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
+  }
+  const auto [res0, res1, res2, res3] = at::_scaled_dot_product_efficient_attention(
+      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, scale);
+  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
+  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
+  // philox seed is always non-batched
+  return std::make_tuple(res0_, 0, res1_, 0, res2, std::nullopt, res3, std::nullopt);
+}
+
+// Please unify SDPA APIs!!!
+std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
+_scaled_dot_product_cudnn_attention_batch_rule(
+    const Tensor& query, std::optional<int64_t> query_bdim,
+    const Tensor& key, std::optional<int64_t> key_bdim,
+    const Tensor& value, std::optional<int64_t> value_bdim,
+    const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim,
+    bool compute_log_sumexp,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    c10::optional<double> scale
+) {
+  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
+  auto query_ = moveBatchDimToFront(query, query_bdim);
+  auto key_ = moveBatchDimToFront(key, key_bdim);
+  auto value_ = moveBatchDimToFront(value, value_bdim);
+  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
+  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
+  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
+  query_ = query_.flatten(0, 1);
+  key_ = key_.flatten(0, 1);
+  value_ = value_.flatten(0, 1);
+
+  std::optional<Tensor> attn_bias_;
+  if (attn_bias.has_value() && attn_bias->defined()) {
+    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
+  }
+
+  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_cudnn_attention(
+      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale);
+
+  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
+  Tensor res1_;
+  std::optional<int64_t> res1_bdim;
+  if (compute_log_sumexp) {
+    res1_ = reshape_dim_outof(0, batch_size, res1);
+    res1_bdim = 0;
+  } else {
+    res1_ = res1;
+    res1_bdim = std::nullopt;
+  }
+  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
+  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
+  // res6 and res7 (philox seed and offset) are always non-batched
+  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;
+
+  return std::make_tuple(
+      res0_, 0,
+      res1_, res1_bdim,
+      res2, std::nullopt,
+      res3, std::nullopt,
+      res4,
+      res5,
+      res6, std::nullopt,
+      res7, std::nullopt,
+      res8_, return_debug_mask ?
std::optional(0) : std::nullopt + ); +} + } #define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\ @@ -612,6 +751,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule); VMAP_SUPPORT(linalg_cross, cross_batch_rule); VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule); + VMAP_SUPPORT(_scaled_dot_product_efficient_attention, _scaled_dot_product_efficient_attention_batch_rule); + + VMAP_SUPPORT(_scaled_dot_product_flash_attention, _scaled_dot_product_flash_attention_batch_rule); + VMAP_SUPPORT(_scaled_dot_product_cudnn_attention, _scaled_dot_product_cudnn_attention_batch_rule); VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule); diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index 6c2a9f984714..8385660be0b3 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -492,6 +492,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { REDUCTION_WITH_KEEPDIM_ARG(prod.dim_int); REDUCTION_BOXED_ARGS(std.correction, 1, KEEPDIM_CASE_VARIABLE, 3); REDUCTION_NO_KEEPDIM_ARG(_softmax); + REDUCTION_NO_KEEPDIM_ARG(_safe_softmax); REDUCTION_NO_KEEPDIM_ARG(sort); REDUCTION_BOXED_ARGS(sort.stable, 2, KEEPDIM_CASE_TRUE, -1); REDUCTION_BOXED_ARGS(std_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6698dc49d9e5..a8358e627962 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -44,8 +44,14 @@ from torch import Tensor from torch._C._functorch import reshape_dim_into, reshape_dim_outof from torch._functorch.make_functional import functional_init_with_buffers from torch._functorch.vmap import restore_vmap +from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing._internal.autograd_function_db import autograd_function_db -from torch.testing._internal.common_cuda import with_tf32_off +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_CUDNN_ATTENTION, + PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + with_tf32_off, +) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, @@ -72,6 +78,19 @@ from torch.testing._internal.custom_op_db import custom_op_db from torch.utils import _pytree as pytree +def get_platform_specific_sdpa(): + ret = [SDPBackend.MATH] + if PLATFORM_SUPPORTS_FLASH_ATTENTION: + ret.append(SDPBackend.FLASH_ATTENTION) + if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: + ret.append(SDPBackend.EFFICIENT_ATTENTION) + if PLATFORM_SUPPORTS_CUDNN_ATTENTION: + ret.append(SDPBackend.CUDNN_ATTENTION) + return ret + + +PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa() + FALLBACK_REGEX = "There is a performance drop" @@ -3850,6 +3869,60 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase): x = torch.randn(2, 3, device=device, requires_grad=True) self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) + @parametrize("backend", PLATFORM_SPECIFIC_SDPA) + def test_sdpa(self, device, backend): + if device == "cpu": + raise unittest.SkipTest("This test is only for CUDA for now") + + def T(*args): + return torch.randn(*args, dtype=torch.float16, device=device) + + backend_ctx = sdpa_kernel([backend]) + with backend_ctx: + for batching in [ + (True, True, True), + (True, False, False), + (False, True, True), + ]: + size = [8, 4, 128, 64] + if batching[0]: + query = T(3, *size) + else: + query = 
T(*size) + if batching[1]: + key = T(3, *size) + else: + key = T(*size) + if batching[2]: + value = T(3, *size) + else: + value = T(*size) + in_dims = tuple(0 if b else None for b in batching) + attention = F.scaled_dot_product_attention + + self._vmap_test( + attention, + (query, key, value), + in_dims=in_dims, + ) + # Backwards test doesn't work yet + # self._batched_grad_test( + # lambda query, key, value: F.scaled_dot_product_attention( + # query, key, value + # ), + # (query, key, value), + # ) + + B = 4 + query = torch.rand(4, 32, B, 8, 128, dtype=torch.float16, device=device) + key = torch.rand(4, B, 32, 8, 128, dtype=torch.float16, device=device) + value = torch.rand(4, 32, 8, 128, dtype=torch.float16, device=device) + self._vmap_test( + F.scaled_dot_product_attention, + (query, key, value), + in_dims=(2, 1, None), + ) + @allowVmapFallbackUsage def test_inplace_view(self, device): leaf = torch.randn(4, 5, requires_grad=True) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 43ffd6dbbf4e..bb8486802b91 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -16217,9 +16217,6 @@ op_db: List[OpInfo] = [ decorators=[], skips=( DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'), - DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'), - DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), ), ), OpInfo( diff --git a/torchgen/gen_vmap_plumbing.py b/torchgen/gen_vmap_plumbing.py index 913be7784612..af9af6454eb0 100644 --- a/torchgen/gen_vmap_plumbing.py +++ b/torchgen/gen_vmap_plumbing.py @@ -207,7 +207,14 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str | None: return None if len(returns) == 0: return gen_vmap_plumbing_no_returns(native_function) - if not all(ret.type.is_tensor_like() for ret in returns): + return_symint_overrides = [ + "_scaled_dot_product_flash_attention", + "_scaled_dot_product_cudnn_attention", + ] + if ( + not all(ret.type.is_tensor_like() for ret in returns) + and schema.name.unambiguous_name() not in return_symint_overrides + ): return None # in-place views need special handling if "inplace_view" in native_function.tags:
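
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): with the batch rules
# above registered, torch.vmap can map scaled_dot_product_attention over an
# extra leading dimension instead of hitting the slow fallback. The shapes
# loosely follow the new test_sdpa test; the default dtype/device used here
# is an assumption, and the math backend is enough to run this on CPU.
import torch
import torch.nn.functional as F

B = 3                          # vmapped dimension
size = (8, 4, 128, 64)         # (batch, heads, seq_len, head_dim)
query = torch.randn(B, *size)  # only query carries the vmapped dim
key = torch.randn(*size)
value = torch.randn(*size)

# in_dims mirrors the (True, False, False) batching case in test_sdpa.
out = torch.vmap(F.scaled_dot_product_attention, in_dims=(0, None, None))(query, key, value)
print(out.shape)  # torch.Size([3, 8, 4, 128, 64])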
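
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch) of the pattern the C++ batch
# rules above implement: fold the vmapped dim into SDPA's existing batch dim,
# call the kernel once, then split the dim back out. The helper name
# sdpa_batch_rule_sketch is hypothetical.
import torch
import torch.nn.functional as F

def sdpa_batch_rule_sketch(query, key, value, bdim_size):
    # Inputs already carry the vmapped dim at the front:
    # [bdim_size, batch, heads, seq_len, head_dim].
    q = query.flatten(0, 1)    # analogous to query_.flatten(0, 1)
    k = key.flatten(0, 1)
    v = value.flatten(0, 1)
    out = F.scaled_dot_product_attention(q, k, v)
    # Analogous to reshape_dim_outof(0, batch_size, res0) in the batch rules.
    return out.unflatten(0, (bdim_size, -1))

q = torch.randn(3, 8, 4, 128, 64)
k = torch.randn(3, 8, 4, 128, 64)
v = torch.randn(3, 8, 4, 128, 64)
print(sdpa_batch_rule_sketch(q, k, v, 3).shape)  # torch.Size([3, 8, 4, 128, 64])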