Added batching rule for sdpa_math, sdpa_efficient_attention forward, cudnn, and flash attention (#133964)
Fixes https://github.com/pytorch/pytorch/issues/117016, https://github.com/pytorch/pytorch/issues/102457, https://github.com/pytorch/pytorch/issues/110525, https://github.com/pytorch/pytorch/issues/108065
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133964
Approved by: https://github.com/Skylion007
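As an illustration of what the new batch rules enable (a sketch, not part of this diff), vmap over F.scaled_dot_product_attention can now map an extra leading dimension directly, including when only some inputs carry the vmapped dimension; the shapes mirror the (8, 4, 128, 64) sizes used in the new test:

```python
import torch
import torch.nn.functional as F

B, H, S, E = 8, 4, 128, 64
query = torch.randn(3, B, H, S, E)  # carries the vmapped dim (size 3) at dim 0
key = torch.randn(B, H, S, E)       # no vmapped dim; broadcast across the map
value = torch.randn(B, H, S, E)

out = torch.vmap(F.scaled_dot_product_attention, in_dims=(0, None, None))(
    query, key, value
)
print(out.shape)  # torch.Size([3, 8, 4, 128, 64])
```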
@@ -138,81 +138,6 @@ static Tensor _convolution_decomp(
      input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}

// TODO: delete the following after confirming performance
// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
//   if (bdim == 0) {
//     return value.size(1) == 1;
//   }
//   return value.size(0) == 1;
// }
//
// std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
//     const Tensor& self, std::optional<int64_t> self_bdim,
//     const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
//     const Tensor& weight, std::optional<int64_t> weight_bdim,
//     IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
//     bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
//   // TODO: No clue if this works if the first non-batch dim isn't size 1
//   TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
//   TORCH_INTERNAL_ASSERT(self.dim() == 5);
//
//   auto bdim_size = self.size(*self_bdim);
//   auto self_ = reshape_dim_into(*self_bdim, 0, self);
//   auto in_channels = self_.size(1);
//   auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
//
//   auto grad_self = at::cudnn_convolution_backward_input(
//       self_.sizes(), grad_output_, weight,
//       padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
//   grad_self = reshape_dim_outof(0, bdim_size, grad_self);
//
//   // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
//   auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
//   auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
//   auto grad_sample = at::einsum("noq,npq->nop", {B, A});
//   grad_sample = grad_sample.view({
//       bdim_size, groups, -1, groups, in_channels / groups,
//       weight.size(2) * weight.size(3) });
//   grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
//   grad_sample = grad_sample.reshape(
//       {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
//
//   return std::make_tuple(grad_self, 0, grad_sample, 0);
// }
//
// std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
//   auto maybe_layer = maybeCurrentDynamicLayer();
//   TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
//   int64_t cur_level = maybe_layer->layerId();
//
//   Tensor self_value;
//   std::optional<int64_t> self_bdim;
//   std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
//   Tensor grad_output_value;
//   std::optional<int64_t> grad_output_bdim;
//   std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
//   Tensor weight_value;
//   std::optional<int64_t> weight_bdim;
//   std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
//
//   if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) {
//     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
//     auto result = cudnn_conv_per_sample_grad_rule(
//         self_value, self_bdim,
//         grad_output_value, grad_output_bdim,
//         weight_value, weight_bdim,
//         padding, stride, dilation, groups,
//         benchmark, deterministic, allow_tf32, output_mask);
//     return std::make_tuple(
//         makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
//         makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
//   }
//
//   static auto op = c10::Dispatcher::singleton()
//     .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
//   return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });

static Tensor compute_grad_bias(
    const Tensor& grad_output_, std::array<bool, 3> output_mask) {
  if (!output_mask[2]) {
@@ -261,6 +261,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
  OP_DECOMPOSE(special_xlogy);
  OP_DECOMPOSE2(special_xlogy, other_scalar);
  OP_DECOMPOSE2(special_xlogy, self_scalar);
  OP_DECOMPOSE(_scaled_dot_product_attention_math);
  m.impl("split.sizes", native::split_symint);
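For the sdpa_math entry added just above: the math backend decomposes into matmul/softmax primitives that already have batch rules, which is why a plain OP_DECOMPOSE registration is enough for it. A rough reference of that decomposition (a sketch only; dropout and explicit attention masks omitted, not the exact ATen implementation):

```python
import math
import torch

def sdpa_math_reference(q, k, v, is_causal=False, scale=None):
    # q, k, v: (..., seq_len, head_dim); every op used here is already vmap-friendly.
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    attn = (q @ k.transpose(-2, -1)) * scale
    if is_causal:
        causal = torch.ones(q.size(-2), k.size(-2), dtype=torch.bool, device=q.device).tril()
        attn = attn.masked_fill(~causal, float("-inf"))
    return torch.softmax(attn, dim=-1) @ v
```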
@@ -485,6 +485,145 @@ pinv_batch_rule(
    const std::optional<int64_t> rtol_bdim, bool hermitian) {
  return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv");
}

std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_flash_attention_batch_rule(
    const Tensor& query, std::optional<int64_t> query_bdim,
    const Tensor& key, std::optional<int64_t> key_bdim,
    const Tensor& value, std::optional<int64_t> value_bdim,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale
) {
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_flash_attention(
      query_, key_, value_, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
      res0_, 0,
      res1_, 0,
      res2, std::nullopt,
      res3, std::nullopt,
      res4,
      res5,
      res6, std::nullopt,
      res7, std::nullopt,
      res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
    const Tensor& query, optional<int64_t> query_bdim,
    const Tensor& key, optional<int64_t> key_bdim,
    const Tensor& value, optional<int64_t> value_bdim,
    const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    c10::optional<double> scale
) {
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }
  const auto [res0, res1, res2, res3] = at::_scaled_dot_product_efficient_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, scale);
  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // philox seed is always non-batched
  return std::make_tuple(res0_, 0, res1_, 0, res2, std::nullopt, res3, std::nullopt);
}

// Please unify SDPA APIs!!!
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_cudnn_attention_batch_rule(
    const Tensor& query, std::optional<int64_t> query_bdim,
    const Tensor& key, std::optional<int64_t> key_bdim,
    const Tensor& value, std::optional<int64_t> value_bdim,
    const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale
) {
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_cudnn_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  Tensor res1_;
  std::optional<int64_t> res1_bdim;
  if (compute_log_sumexp) {
    res1_ = reshape_dim_outof(0, batch_size, res1);
    res1_bdim = 0;
  } else {
    res1_ = res1;
    res1_bdim = std::nullopt;
  }
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
      res0_, 0,
      res1_, res1_bdim,
      res2, std::nullopt,
      res3, std::nullopt,
      res4,
      res5,
      res6, std::nullopt,
      res7, std::nullopt,
      res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

}

#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
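All three batch rules above follow the same recipe: move the vmapped dim to the front, materialize it on inputs that lack it, fold it into SDPA's batch dim, call the fused op once, and split the leading dim of the batched outputs back out (SymInt and philox outputs pass through unbatched). A conceptual Python rendering of that fold/unfold with hypothetical names, not the real implementation (which uses moveBatchDimToFront, ensure_has_bdim, and reshape_dim_outof on the underlying at::_scaled_dot_product_* ops):

```python
import torch
import torch.nn.functional as F

def sdpa_batch_rule_sketch(query, key, value, vmap_size):
    # Assumes the vmapped dim is already at the front: (vmap, B, H, S, E).
    q = query.flatten(0, 1)  # fold the vmapped dim into SDPA's batch dim -> (vmap * B, H, S, E)
    k = key.flatten(0, 1)
    v = value.flatten(0, 1)
    out = F.scaled_dot_product_attention(q, k, v)
    # Split the vmapped dim back out and report 0 as the output's batch dim.
    return out.unflatten(0, (vmap_size, -1)), 0
```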
@@ -612,6 +751,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
  VMAP_SUPPORT(linalg_cross, cross_batch_rule);
  VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_efficient_attention, _scaled_dot_product_efficient_attention_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_flash_attention, _scaled_dot_product_flash_attention_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_cudnn_attention, _scaled_dot_product_cudnn_attention_batch_rule);
  VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);
@@ -492,6 +492,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  REDUCTION_WITH_KEEPDIM_ARG(prod.dim_int);
  REDUCTION_BOXED_ARGS(std.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
  REDUCTION_NO_KEEPDIM_ARG(_softmax);
  REDUCTION_NO_KEEPDIM_ARG(_safe_softmax);
  REDUCTION_NO_KEEPDIM_ARG(sort);
  REDUCTION_BOXED_ARGS(sort.stable, 2, KEEPDIM_CASE_TRUE, -1);
  REDUCTION_BOXED_ARGS(std_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
@@ -44,8 +44,14 @@ from torch import Tensor
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
from torch._functorch.make_functional import functional_init_with_buffers
from torch._functorch.vmap import restore_vmap
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    with_tf32_off,
)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
@@ -72,6 +78,19 @@ from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils import _pytree as pytree


def get_platform_specific_sdpa():
    ret = [SDPBackend.MATH]
    if PLATFORM_SUPPORTS_FLASH_ATTENTION:
        ret.append(SDPBackend.FLASH_ATTENTION)
    if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
        ret.append(SDPBackend.EFFICIENT_ATTENTION)
    if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
        ret.append(SDPBackend.CUDNN_ATTENTION)
    return ret


PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()

FALLBACK_REGEX = "There is a performance drop"
@@ -3850,6 +3869,60 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))

    @parametrize("backend", PLATFORM_SPECIFIC_SDPA)
    def test_sdpa(self, device, backend):
        if device == "cpu":
            raise unittest.SkipTest("This test is only for CUDA for now")

        def T(*args):
            return torch.randn(*args, dtype=torch.float16, device=device)

        backend_ctx = sdpa_kernel([backend])
        with backend_ctx:
            for batching in [
                (True, True, True),
                (True, False, False),
                (False, True, True),
            ]:
                size = [8, 4, 128, 64]
                if batching[0]:
                    query = T(3, *size)
                else:
                    query = T(*size)
                if batching[1]:
                    key = T(3, *size)
                else:
                    key = T(*size)
                if batching[2]:
                    value = T(3, *size)
                else:
                    value = T(*size)
                in_dims = tuple(0 if b else None for b in batching)
                attention = F.scaled_dot_product_attention

                self._vmap_test(
                    attention,
                    (query, key, value),
                    in_dims=in_dims,
                )
                # Backwards test doesn't work yet
                # self._batched_grad_test(
                #     lambda query, key, value: F.scaled_dot_product_attention(
                #         query, key, value
                #     ),
                #     (query, key, value),
                # )

            B = 4
            query = torch.rand(4, 32, B, 8, 128, dtype=torch.float16, device=device)
            key = torch.rand(4, B, 32, 8, 128, dtype=torch.float16, device=device)
            value = torch.rand(4, 32, 8, 128, dtype=torch.float16, device=device)
            self._vmap_test(
                F.scaled_dot_product_attention,
                (query, key, value),
                in_dims=(2, 1, None),
            )

    @allowVmapFallbackUsage
    def test_inplace_view(self, device):
        leaf = torch.randn(4, 5, requires_grad=True)
@@ -16217,9 +16217,6 @@ op_db: List[OpInfo] = [
        decorators=[],
        skips=(
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
            DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
            DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
        ),
    ),
    OpInfo(
@@ -207,7 +207,14 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    if not all(ret.type.is_tensor_like() for ret in returns):
    return_symint_overrides = [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_cudnn_attention",
    ]
    if (
        not all(ret.type.is_tensor_like() for ret in returns)
        and schema.name.unambiguous_name() not in return_symint_overrides
    ):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags: