chilli
2024-08-21 10:41:38 -07:00
committed by PyTorch MergeBot
parent e2ff094008
commit 938f37b745
7 changed files with 227 additions and 80 deletions

View File

@@ -138,81 +138,6 @@ static Tensor _convolution_decomp(
input_r, weight_r, bias_r_opt, stride_, padding_, dilation_, transposed_, output_padding_, groups_);
}
// TODO: delete the following after confirming performance
// bool first_dim_has_size_1(const Tensor& value, int64_t bdim) {
// if (bdim == 0) {
// return value.size(1) == 1;
// }
// return value.size(0) == 1;
// }
//
// std::tuple<Tensor,int64_t,Tensor,int64_t> cudnn_conv_per_sample_grad_rule(
// const Tensor& self, std::optional<int64_t> self_bdim,
// const Tensor& grad_output, std::optional<int64_t> grad_output_bdim,
// const Tensor& weight, std::optional<int64_t> weight_bdim,
// IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark,
// bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
// TORCH_INTERNAL_ASSERT(self_bdim && grad_output_bdim && !weight_bdim);
// // TODO: No clue if this works if the first non-batch dim isn't size 1
// TORCH_INTERNAL_ASSERT(first_dim_has_size_1(self, *self_bdim));
// TORCH_INTERNAL_ASSERT(self.dim() == 5);
//
// auto bdim_size = self.size(*self_bdim);
// auto self_ = reshape_dim_into(*self_bdim, 0, self);
// auto in_channels = self_.size(1);
// auto grad_output_ = reshape_dim_into(*grad_output_bdim, 0, grad_output);
//
// auto grad_self = at::cudnn_convolution_backward_input(
// self_.sizes(), grad_output_, weight,
// padding, stride, dilation, groups, benchmark, deterministic, allow_tf32);
// grad_self = reshape_dim_outof(0, bdim_size, grad_self);
//
// // Copied from https://github.com/pytorch/opacus/blob/master/opacus/grad_sample/conv.py
// auto A = at::im2col(self_, {weight.size(2), weight.size(3)}, dilation, padding, stride);
// auto B = grad_output_.reshape({bdim_size, -1, A.size(-1)});
// auto grad_sample = at::einsum("noq,npq->nop", {B, A});
// grad_sample = grad_sample.view({
// bdim_size, groups, -1, groups, in_channels / groups,
// weight.size(2) * weight.size(3) });
// grad_sample = at::einsum("ngrg...->ngr...", {grad_sample});
// grad_sample = grad_sample.reshape(
// {bdim_size, weight.size(0), weight.size(1), weight.size(2), weight.size(3)});
//
// return std::make_tuple(grad_self, 0, grad_sample, 0);
// }
//
// std::tuple<Tensor,Tensor> cudnn_convolution_backward_plumbing(const Tensor & self, const Tensor & grad_output, const Tensor & weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
// auto maybe_layer = maybeCurrentDynamicLayer();
// TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
// int64_t cur_level = maybe_layer->layerId();
//
// Tensor self_value;
// std::optional<int64_t> self_bdim;
// std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
// Tensor grad_output_value;
// std::optional<int64_t> grad_output_bdim;
// std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output, cur_level);
// Tensor weight_value;
// std::optional<int64_t> weight_bdim;
// std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
//
// if (self_bdim.has_value() && self_value.dim() == 5 && first_dim_has_size_1(self_value, *self_bdim) && grad_output_bdim.has_value() && !weight_bdim.has_value()) {
// c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
// auto result = cudnn_conv_per_sample_grad_rule(
// self_value, self_bdim,
// grad_output_value, grad_output_bdim,
// weight_value, weight_bdim,
// padding, stride, dilation, groups,
// benchmark, deterministic, allow_tf32, output_mask);
// return std::make_tuple(
// makeBatched(std::get<0>(result), std::get<1>(result), cur_level),
// makeBatched(std::get<2>(result), std::get<3>(result), cur_level));
// }
//
// static auto op = c10::Dispatcher::singleton()
// .findSchemaOrThrow("aten::cudnn_convolution_backward", "");
// return slow_fallback<Tensor,Tensor>(op, { self, grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, output_mask });
static Tensor compute_grad_bias(
const Tensor& grad_output_, std::array<bool, 3> output_mask) {
if (!output_mask[2]) {

View File

@@ -261,6 +261,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(special_xlogy);
OP_DECOMPOSE2(special_xlogy, other_scalar);
OP_DECOMPOSE2(special_xlogy, self_scalar);
OP_DECOMPOSE(_scaled_dot_product_attention_math);
m.impl("split.sizes", native::split_symint);

View File

@@ -485,6 +485,145 @@ pinv_batch_rule(
const std::optional<int64_t> rtol_bdim, bool hermitian) {
return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv");
}
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_flash_attention_batch_rule(
const Tensor& query, std::optional<int64_t> query_bdim,
const Tensor& key, std::optional<int64_t> key_bdim,
const Tensor& value, std::optional<int64_t> value_bdim,
double dropout_p,
bool is_causal,
bool return_debug_mask,
c10::optional<double> scale
) {
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);
auto value_ = moveBatchDimToFront(value, value_bdim);
query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
query_ = query_.flatten(0, 1);
key_ = key_.flatten(0, 1);
value_ = value_.flatten(0, 1);
const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_flash_attention(
query_, key_, value_, dropout_p, is_causal, return_debug_mask, scale);
const auto res0_ = reshape_dim_outof(0, batch_size, res0);
const auto res1_ = reshape_dim_outof(0, batch_size, res1);
// res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
// res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
// res6 and res7 (philox seed and offset) are always non-batched
const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;
return std::make_tuple(
res0_, 0,
res1_, 0,
res2, std::nullopt,
res3, std::nullopt,
res4,
res5,
res6, std::nullopt,
res7, std::nullopt,
res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
);
}
fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
const Tensor& query, optional<int64_t> query_bdim,
const Tensor& key, optional<int64_t> key_bdim,
const Tensor& value, optional<int64_t> value_bdim,
const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim,
bool compute_log_sumexp,
double dropout_p,
bool is_causal,
c10::optional<double> scale
) {
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);
auto value_ = moveBatchDimToFront(value, value_bdim);
query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
query_ = query_.flatten(0, 1);
key_ = key_.flatten(0, 1);
value_ = value_.flatten(0, 1);
std::optional<Tensor> attn_bias_;
if (attn_bias.has_value() && attn_bias->defined()) {
attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
}
const auto [res0, res1, res2, res3] = at::_scaled_dot_product_efficient_attention(
query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, scale);
const auto res0_ = reshape_dim_outof(0, batch_size, res0);
const auto res1_ = reshape_dim_outof(0, batch_size, res1);
// philox seed is always non-batched
return std::make_tuple(res0_, 0, res1_, 0, res2, std::nullopt, res3, std::nullopt);
}
// Please unify SDPA APIs!!!
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_cudnn_attention_batch_rule(
const Tensor& query, std::optional<int64_t> query_bdim,
const Tensor& key, std::optional<int64_t> key_bdim,
const Tensor& value, std::optional<int64_t> value_bdim,
const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim,
bool compute_log_sumexp,
double dropout_p,
bool is_causal,
bool return_debug_mask,
c10::optional<double> scale
) {
auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
auto query_ = moveBatchDimToFront(query, query_bdim);
auto key_ = moveBatchDimToFront(key, key_bdim);
auto value_ = moveBatchDimToFront(value, value_bdim);
query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
query_ = query_.flatten(0, 1);
key_ = key_.flatten(0, 1);
value_ = value_.flatten(0, 1);
std::optional<Tensor> attn_bias_;
if (attn_bias.has_value() && attn_bias->defined()) {
attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
}
const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_cudnn_attention(
query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale);
const auto res0_ = reshape_dim_outof(0, batch_size, res0);
Tensor res1_;
std::optional<int64_t> res1_bdim;
if (compute_log_sumexp) {
res1_ = reshape_dim_outof(0, batch_size, res1);
res1_bdim = 0;
} else {
res1_ = res1;
res1_bdim = std::nullopt;
}
// res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
// res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
// res6 and res7 (philox seed and offset) are always non-batched
const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;
return std::make_tuple(
res0_, 0,
res1_, res1_bdim,
res2, std::nullopt,
res3, std::nullopt,
res4,
res5,
res6, std::nullopt,
res7, std::nullopt,
res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
);
}
}
#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
@@ -612,6 +751,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
VMAP_SUPPORT(linalg_cross, cross_batch_rule);
VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule);
VMAP_SUPPORT(_scaled_dot_product_efficient_attention, _scaled_dot_product_efficient_attention_batch_rule);
VMAP_SUPPORT(_scaled_dot_product_flash_attention, _scaled_dot_product_flash_attention_batch_rule);
VMAP_SUPPORT(_scaled_dot_product_cudnn_attention, _scaled_dot_product_cudnn_attention_batch_rule);
VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);
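All three rules above follow the same pattern: move the vmap dimension to the front, fold it into the kernel's batch dimension with flatten(0, 1), call the underlying attention op once, then split the batch dimension back out of the outputs with reshape_dim_outof. A minimal Python sketch of that strategy at the public-API level, assuming 5-D inputs whose leading dimension is the vmap dimension; sdpa_batch_rule_sketch is illustrative, not the actual C++ implementation.

import torch
import torch.nn.functional as F

def sdpa_batch_rule_sketch(query, key, value, bdim_size):
    # Inputs: (bdim, batch, heads, seq, head_dim). Fold the vmap dim into the
    # batch dim, run the kernel once, then unfold the output's leading dim.
    q, k, v = (t.flatten(0, 1) for t in (query, key, value))
    out = F.scaled_dot_product_attention(q, k, v)
    return out.unflatten(0, (bdim_size, -1))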

View File

@@ -492,6 +492,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
REDUCTION_WITH_KEEPDIM_ARG(prod.dim_int);
REDUCTION_BOXED_ARGS(std.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
REDUCTION_NO_KEEPDIM_ARG(_softmax);
REDUCTION_NO_KEEPDIM_ARG(_safe_softmax);
REDUCTION_NO_KEEPDIM_ARG(sort);
REDUCTION_BOXED_ARGS(sort.stable, 2, KEEPDIM_CASE_TRUE, -1);
REDUCTION_BOXED_ARGS(std_mean.correction, 1, KEEPDIM_CASE_VARIABLE, 3);
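Registering _safe_softmax through REDUCTION_NO_KEEPDIM_ARG reuses the generic reduction batch-rule machinery, since the op takes a single dim argument like _softmax. A hedged sketch of its semantics, assuming it behaves like softmax except that rows that are entirely -inf (fully masked attention rows) produce zeros rather than NaNs; safe_softmax_sketch is illustrative only, not the actual kernel.

import torch

def safe_softmax_sketch(scores, dim=-1):
    # Regular softmax, except fully masked rows (all -inf) become zeros, not NaN.
    out = torch.softmax(scores, dim=dim)
    fully_masked = scores.isneginf().all(dim=dim, keepdim=True)
    return torch.where(fully_masked, torch.zeros_like(out), out)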

View File

@@ -44,8 +44,14 @@ from torch import Tensor
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
from torch._functorch.make_functional import functional_init_with_buffers
from torch._functorch.vmap import restore_vmap
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
with_tf32_off,
)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
@@ -72,6 +78,19 @@ from torch.testing._internal.custom_op_db import custom_op_db
from torch.utils import _pytree as pytree
def get_platform_specific_sdpa():
ret = [SDPBackend.MATH]
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
ret.append(SDPBackend.FLASH_ATTENTION)
if PLATFORM_SUPPORTS_MEM_EFF_ATTENTION:
ret.append(SDPBackend.EFFICIENT_ATTENTION)
if PLATFORM_SUPPORTS_CUDNN_ATTENTION:
ret.append(SDPBackend.CUDNN_ATTENTION)
return ret
PLATFORM_SPECIFIC_SDPA = get_platform_specific_sdpa()
FALLBACK_REGEX = "There is a performance drop"
@@ -3850,6 +3869,60 @@ class TestVmapBatchedGradient(Namespace.TestVmapBase):
x = torch.randn(2, 3, device=device, requires_grad=True)
self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
@parametrize("backend", PLATFORM_SPECIFIC_SDPA)
def test_sdpa(self, device, backend):
if device == "cpu":
raise unittest.SkipTest("This test is only for CUDA for now")
def T(*args):
return torch.randn(*args, dtype=torch.float16, device=device)
backend_ctx = sdpa_kernel([backend])
with backend_ctx:
for batching in [
(True, True, True),
(True, False, False),
(False, True, True),
]:
size = [8, 4, 128, 64]
if batching[0]:
query = T(3, *size)
else:
query = T(*size)
if batching[1]:
key = T(3, *size)
else:
key = T(*size)
if batching[2]:
value = T(3, *size)
else:
value = T(*size)
in_dims = tuple(0 if b else None for b in batching)
attention = F.scaled_dot_product_attention
self._vmap_test(
attention,
(query, key, value),
in_dims=in_dims,
)
# Backwards test doesn't work yet
# self._batched_grad_test(
# lambda query, key, value: F.scaled_dot_product_attention(
# query, key, value
# ),
# (query, key, value),
# )
B = 4
query = torch.rand(4, 32, B, 8, 128, dtype=torch.float16, device=device)
key = torch.rand(4, B, 32, 8, 128, dtype=torch.float16, device=device)
value = torch.rand(4, 32, 8, 128, dtype=torch.float16, device=device)
self._vmap_test(
F.scaled_dot_product_attention,
(query, key, value),
in_dims=(2, 1, None),
)
@allowVmapFallbackUsage
def test_inplace_view(self, device):
leaf = torch.randn(4, 5, requires_grad=True)
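For reference, the batched behavior exercised by test_sdpa above can also be checked directly against a per-slice loop, since vmap semantics guarantee that the mapped call matches stacking per-example results. A standalone sketch, assuming a CUDA device with fp16 attention support; it is not part of the test file.

import torch
import torch.nn.functional as F
from torch.func import vmap

device = "cuda"
# Each input carries an extra leading dim of size 3 that vmap maps over.
query, key, value = (
    torch.randn(3, 8, 4, 128, 64, dtype=torch.float16, device=device)
    for _ in range(3)
)
out = vmap(F.scaled_dot_product_attention)(query, key, value)
expected = torch.stack(
    [F.scaled_dot_product_attention(query[i], key[i], value[i]) for i in range(3)]
)
torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)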

View File

@@ -16217,9 +16217,6 @@ op_db: List[OpInfo] = [
decorators=[],
skips=(
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapjvpall_has_batch_rule'),
DecorateInfo(unittest.expectedFailure, 'TestOperators', 'test_vmapvjp_has_batch_rule'),
DecorateInfo(unittest.expectedFailure, "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
),
),
OpInfo(

View File

@@ -207,7 +207,14 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
return None
if len(returns) == 0:
return gen_vmap_plumbing_no_returns(native_function)
if not all(ret.type.is_tensor_like() for ret in returns):
return_symint_overrides = [
"_scaled_dot_product_flash_attention",
"_scaled_dot_product_cudnn_attention",
]
if (
not all(ret.type.is_tensor_like() for ret in returns)
and schema.name.unambiguous_name() not in return_symint_overrides
):
return None
# in-place views need special handling
if "inplace_view" in native_function.tags: