sdp::SDPBackend::flash_attention support PrivateUse1 (#126392)

Fixes https://github.com/pytorch/pytorch/issues/124271

cc @cpuhrsch @drisspg @albanD @soulitzer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126392
Approved by: https://github.com/drisspg
Author: FEI
Date: 2024-06-28 17:48:40 +00:00
Committed by: PyTorch MergeBot
Parent: 26d633b721
Commit: 59e4e92556
12 changed files with 244 additions and 3 deletions
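Taken together, the changes in this commit let a PrivateUse1 backend route torch.nn.functional.scaled_dot_product_attention through its own fused kernel. A minimal sketch of the intended flow, assuming a PrivateUse1 C++ extension (such as the test extension updated below) has already been loaded and the backend renamed to a hypothetical "foo" device:

# Sketch only: assumes an extension has registered _fused_sdp_choice and the
# new overrideable ops for PrivateUse1 and the backend was renamed to "foo";
# the toy kernels in the test extension below just return empty tensors.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend

shape = (4, 2, 256, 128)  # batch, heads, seq_len, head_dim
q = torch.rand(shape, dtype=torch.float16).to("foo")
k = torch.rand(shape, dtype=torch.float16).to("foo")
v = torch.rand(shape, dtype=torch.float16).to("foo")

# The backend's _fused_sdp_choice override reports the new backend ...
assert torch._fused_sdp_choice(q, k, v) == SDPBackend.OVERRIDEABLE.value
# ... so SDPA dispatches to _scaled_dot_product_fused_attention_overrideable.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)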

View File

@@ -153,6 +153,13 @@ void Context::setSDPUseCuDNN(bool e) {
enabled_cudnnSDP = e;
}
void Context::setSDPUseOverrideable(bool e) {
enabled_overrideable = e;
}
bool Context::userEnabledOverrideableSDP() const {
return enabled_overrideable;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
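The new enabled_overrideable flag mirrors the existing math/flash/mem-efficient/cuDNN toggles and is surfaced to Python by the bindings added later in this commit; a short sketch of flipping it:

# Sketch: _set_sdp_use_overrideable / _get_overrideable_sdp_enabled (added
# in the Python bindings below) read and write Context::enabled_overrideable.
import torch

torch._C._set_sdp_use_overrideable(True)
assert torch._C._get_overrideable_sdp_enabled()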

View File

@@ -216,6 +216,9 @@ class TORCH_API Context {
void setSDPUseCuDNN(bool);
bool userEnabledCuDNNSDP() const;
void setSDPUseOverrideable(bool);
bool userEnabledOverrideableSDP() const;
at::LinalgBackend linalgPreferredBackend() const;
void setLinalgPreferredBackend(at::LinalgBackend);
@@ -368,6 +371,7 @@ class TORCH_API Context {
bool enabled_mem_efficientSDP = true;
bool enabled_mathSDP = true;
bool enabled_cudnnSDP = false;
bool enabled_overrideable = true;
#ifdef USE_ROCM
bool benchmark_cudnn = true;
#else

View File

@@ -14762,6 +14762,11 @@
CPU: _scaled_dot_product_flash_attention_cpu
tags: nondeterministic_seeded
- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
tags: nondeterministic_seeded
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
device_check: NoCheck
variants: function
@@ -14775,6 +14780,12 @@
dispatch:
CPU: _scaled_dot_product_flash_attention_cpu_backward
- func: _scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias)
device_check: NoCheck
variants: function
dispatch:
CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
dispatch:
CUDA: _scaled_dot_product_efficient_attention_cuda
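The CompositeExplicitAutograd kernels registered for these two schemas are placeholders that only raise (see the attention implementation below); a sketch of what calling the new op directly looks like on a backend that has not overridden it (shapes here are arbitrary):

# Sketch: without a TORCH_LIBRARY_IMPL override, the default kernel raises
# via TORCH_CHECK_NOT_IMPLEMENTED, which surfaces in Python as an exception.
import torch

q = k = v = torch.rand(1, 1, 8, 16)
try:
    torch.ops.aten._scaled_dot_product_fused_attention_overrideable(q, k, v, None)
except (NotImplementedError, RuntimeError) as err:
    print(err)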

View File

@@ -43,6 +43,10 @@
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_native.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_transform_bias_rescale_qkv.h>
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
@@ -683,6 +687,11 @@ Tensor scaled_dot_product_attention(
query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
return std::get<0>(out_and_lse);
}
case sdp::SDPBackend::overrideable: {
auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
return std::get<0>(out_lse_softmax);
}
case sdp::SDPBackend::math:
return std::get<0>(at::_scaled_dot_product_attention_math(
query_,
@@ -838,6 +847,46 @@ _scaled_dot_product_flash_attention_cpu_backward(
return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable(
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const c10::optional<at::Tensor> & attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for PrivateUse1 backends; please use TORCH_LIBRARY_IMPL to override this function.");
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor & grad_out,
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const at::Tensor & attn_bias,
std::array<bool,4> grad_input_mask,
const at::Tensor & out,
const at::Tensor & logsumexp,
const at::Tensor & cum_seq_q,
const at::Tensor & cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor & philox_seed,
const at::Tensor & philox_offset,
std::optional<double> scale) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable_backward not implemented. This is an operator for PrivateUse1 backends; please use TORCH_LIBRARY_IMPL to override this function.");
return std::tuple<Tensor, Tensor, Tensor, Tensor>(
at::empty_like(query),
at::empty_like(key),
at::empty_like(value),
at::empty_like(attn_bias));
}
Tensor triton_multi_head_attention(
const Tensor& query,
const Tensor& key,

View File

@@ -22,13 +22,14 @@
namespace sdp {
constexpr int32_t num_backends = 5;
enum class SDPBackend {
error = -1,
math = 0,
flash_attention = 1,
efficient_attention = 2,
cudnn_attention = 3,
overrideable = 4
};
// Note that if this changed make sure to update
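Because the Python-facing SDPBackend enum is generated from this C++ enum (see the pybind11 registration at the end of this commit), the new member is visible from Python as well; a one-line sketch, assuming the usual torch.nn.attention re-export:

# Sketch: OVERRIDEABLE maps to the C++ enumerator overrideable = 4.
from torch.nn.attention import SDPBackend
assert SDPBackend.OVERRIDEABLE.value == 4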

View File

@@ -21,6 +21,8 @@
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>
static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
@@ -125,12 +127,18 @@ void quantize_tensor_per_tensor_affine_privateuse1(
// do nothing
}
int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
auto backend = sdp::SDPBackend::overrideable;
return static_cast<int64_t>(backend);
}
} // namespace
namespace at::native {
REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);
} // namespace at::native
struct CustomBackendMetadata : public c10::BackendMeta {
@@ -458,6 +466,59 @@ const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
return self;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const c10::optional<at::Tensor> & attn_bias,
double dropout_p,
bool is_causal,
bool return_debug_mask,
std::optional<double> scale) {
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_q = query.size(2);
const int64_t max_seqlen_kv = key.size(2);
auto opts = query.options();
auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
opts.dtype(at::kFloat));
auto philox_seed = at::empty({}, at::dtype(at::kLong));
auto philox_offset = at::empty({}, at::dtype(at::kLong));
return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(
const at::Tensor & grad_out,
const at::Tensor & query,
const at::Tensor & key,
const at::Tensor & value,
const at::Tensor & attn_bias,
std::array<bool,4> grad_input_mask,
const at::Tensor & out,
const at::Tensor & logsumexp,
const at::Tensor & cum_seq_q,
const at::Tensor & cum_seq_k,
int64_t max_q,
int64_t max_k,
double dropout_p,
bool is_causal,
const at::Tensor & philox_seed,
const at::Tensor & philox_offset,
std::optional<double> scale) {
return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
at::empty_like(query),
at::empty_like(key),
at::empty_like(value),
at::empty_like(attn_bias));
}
// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
@@ -482,6 +543,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("resize_", &custom_resize_);
m.impl("as_strided", at::native::as_strided_tensorimpl);
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
}
void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {

View File

@@ -508,6 +508,8 @@ aten::_scaled_dot_product_efficient_attention_backward
aten::_scaled_dot_product_flash_attention
aten::_scaled_dot_product_flash_attention_backward
aten::_scaled_dot_product_flash_attention_for_cpu_backward
aten::_scaled_dot_product_fused_attention_overrideable
aten::_scaled_dot_product_fused_attention_overrideable_backward
aten::_scaled_mm
aten::_scaled_mm.out
aten::_segment_reduce_backward

View File

@@ -26,6 +26,7 @@ all_operators_with_namedtuple_return = {
all_operators_with_namedtuple_return_skip_list = {
'_scaled_dot_product_flash_attention',
'_scaled_dot_product_fused_attention_overrideable',
'_scaled_dot_product_flash_attention_for_cpu',
'_scaled_dot_product_efficient_attention',
'_scaled_dot_product_cudnn_attention',

View File

@@ -18,6 +18,7 @@ import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
@@ -47,6 +48,11 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_CUDNN_ATTENTION
)
from test_cpp_extensions_open_device_registration import (
remove_build_path,
generate_faked_module
)
if TEST_FAIRSEQ:
import fairseq.models.transformer as fairseq_transformer
@@ -3552,6 +3558,68 @@ class TestAttnBias(NNTestCase):
with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
class TestSDPAPrivateUse1Only(NNTestCase):
@classmethod
def setUpClass(cls):
remove_build_path()
cls.module = torch.utils.cpp_extension.load(
name="custom_device_extension",
sources=[
"cpp_extensions/open_registration_extension.cpp",
],
extra_include_paths=["cpp_extensions"],
extra_cflags=["-g"],
verbose=True,
)
# register torch.foo module and foo device to torch
torch.utils.rename_privateuse1_backend("foo")
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
torch._register_device_module("foo", generate_faked_module())
@skipIfTorchDynamo()
def test_fused_sdp_choice_privateuseone(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
assert torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1) == SDPBackend.OVERRIDEABLE.value
def test_scaled_dot_product_fused_attention_overrideable(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
actual = torch.nn.functional.scaled_dot_product_attention(
q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0)
def test_scaled_dot_product_fused_attention_overrideable_backward(self):
batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16, requires_grad=True)
shape = (batch_size, num_heads, seq_len, head_dim)
q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
q_privateuse1 = q_cpu.to("foo")
k_privateuse1 = k_cpu.to("foo")
v_privateuse1 = v_cpu.to("foo")
attn_mask_privateuse1 = attn_mask.to("foo")
output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = \
torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1)
rand_upward = torch.rand(shape, device="cpu", dtype=torch.float16, requires_grad=False)
rand_upward_privateuse1 = rand_upward.to("foo")
grad_input_mask = [True, True, True, True]
grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
rand_upward_privateuse1, q_privateuse1, k_privateuse1, v_privateuse1, attn_mask_privateuse1,
grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p=0.0,
is_causal=False, philox_seed=philox_seed, philox_offset=philox_offset)
if NOTEST_CPU:
device_types = ("cuda", )
else:

View File

@@ -2856,6 +2856,10 @@
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))

View File

@@ -1152,6 +1152,8 @@ def _set_sdp_use_mem_efficient(
) -> None: ... # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_overrideable_sdp_enabled() -> _bool: ... # THPModule_userEnabledOverrideableSDP
def _set_sdp_use_overrideable(arg: _bool) -> None: ... # THPModule_setSDPUseOverrideable
def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledCuDNNSDP
def _set_sdp_use_cudnn(arg: _bool) -> None: ... # THPModule_setSDPUseCuDNN
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn

View File

@@ -740,6 +740,25 @@ PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
PyBool_Check(arg),
"set_sdp_use_overrideable expects a bool, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setSDPUseOverrideable(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledOverrideableSDP(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().userEnabledOverrideableSDP())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
@@ -1345,6 +1364,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
METH_NOARGS,
nullptr},
{"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
{"_get_overrideable_sdp_enabled",
THPModule_userEnabledOverrideableSDP,
METH_NOARGS,
nullptr},
{"_set_sdp_use_overrideable",
THPModule_setSDPUseOverrideable,
METH_O,
nullptr},
{"_get_cudnn_sdp_enabled",
THPModule_userEnabledCuDNNSDP,
METH_NOARGS,
@@ -1930,7 +1957,8 @@ Call this whenever a new thread is created in order to propagate values from
.value("MATH", sdp::SDPBackend::math)
.value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)
.value("EFFICIENT_ATTENTION", sdp::SDPBackend::efficient_attention)
.value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention)
.value("OVERRIDEABLE", sdp::SDPBackend::overrideable);
py_module.def(
"_can_use_flash_attention",