sdp::SDPBackend::flash_attention support PrivateUse1 (#126392)
Fixes https://github.com/pytorch/pytorch/issues/124271

cc @cpuhrsch @drisspg @albanD @soulitzer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126392
Approved by: https://github.com/drisspg
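At a glance, the new overrideable SDP backend lets an out-of-tree (PrivateUse1) device take over scaled_dot_product_attention. Below is a condensed sketch of the registration pattern, distilled from the open-registration test extension changes further down in this diff; the function names and the placeholder kernel body are illustrative, only the aten schema names come from the PR.

#include <ATen/ATen.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/library.h>

// Tell the SDPA dispatcher that this backend wants the overrideable path.
int64_t fused_sdp_choice_privateuse1(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const c10::optional<at::Tensor>& attn_mask, double dropout_p,
    bool is_causal, c10::optional<double> scale) {
  return static_cast<int64_t>(sdp::SDPBackend::overrideable);
}

// Backend kernel for the new op; a real backend would compute attention here.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt,
           at::Tensor, at::Tensor, at::Tensor>
my_fused_attention_overrideable(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const c10::optional<at::Tensor>& attn_bias, double dropout_p, bool is_causal,
    bool return_debug_mask, std::optional<double> scale) {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "placeholder kernel for illustration only");
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_fused_sdp_choice", &fused_sdp_choice_privateuse1);
  m.impl("_scaled_dot_product_fused_attention_overrideable",
         &my_fused_attention_overrideable);
}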
@@ -153,6 +153,13 @@ void Context::setSDPUseCuDNN(bool e) {
  enabled_cudnnSDP = e;
}

void Context::setSDPUseOverrideable(bool e) {
  enabled_overrideable = e;
}

bool Context::userEnabledOverrideableSDP() const {
  return enabled_overrideable;
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
@@ -216,6 +216,9 @@ class TORCH_API Context {
  void setSDPUseCuDNN(bool);
  bool userEnabledCuDNNSDP() const;

  void setSDPUseOverrideable(bool);
  bool userEnabledOverrideableSDP() const;

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

@@ -368,6 +371,7 @@ class TORCH_API Context {
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
  bool enabled_cudnnSDP = false;
  bool enabled_overrideable = true;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
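Alongside the new Context members above, the flag can be flipped and queried from C++ like the other SDP toggles. A minimal sketch (the wrapper function name is only for illustration):

#include <ATen/Context.h>

void toggle_overrideable_sdp() {
  at::globalContext().setSDPUseOverrideable(true);                  // defaults to true
  bool enabled = at::globalContext().userEnabledOverrideableSDP();  // query the flag
  (void)enabled;
}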
@@ -14762,6 +14762,11 @@
    CPU: _scaled_dot_product_flash_attention_cpu
  tags: nondeterministic_seeded

- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  dispatch:
    CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
  device_check: NoCheck
  variants: function

@@ -14775,6 +14780,12 @@
  dispatch:
    CPU: _scaled_dot_product_flash_attention_cpu_backward

- func: _scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias)
  device_check: NoCheck
  variants: function
  dispatch:
    CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward

- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
  dispatch:
    CUDA: _scaled_dot_product_efficient_attention_cuda
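For orientation, a hedged sketch of calling the op declared above directly from C++ (it mirrors the call site added in attention.cpp below); the wrapper function and tensor arguments are illustrative. On a backend that has not overridden the op, the default implementation added by this PR raises a not-implemented error.

#include <ATen/ATen.h>

void call_overrideable_sdpa(const at::Tensor& q, const at::Tensor& k, const at::Tensor& v) {
  // Positional arguments follow the schema: query, key, value, attn_bias,
  // dropout_p, is_causal, return_debug_mask, scale.
  auto out_tuple = at::_scaled_dot_product_fused_attention_overrideable(
      q, k, v, /*attn_bias=*/{}, /*dropout_p=*/0.0, /*is_causal=*/false,
      /*return_debug_mask=*/false, /*scale=*/std::nullopt);
  at::Tensor output = std::get<0>(out_tuple);  // attention output
  (void)output;
}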
@@ -43,6 +43,10 @@
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h>
#include <ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_native.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h>
#include <ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h>
#include <ATen/ops/_softmax.h>
#include <ATen/ops/_transform_bias_rescale_qkv.h>
#include <ATen/ops/_transform_bias_rescale_qkv_native.h>
@@ -683,6 +687,11 @@ Tensor scaled_dot_product_attention(
          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::overrideable: {
      auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable(
          query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
@@ -838,6 +847,46 @@ _scaled_dot_product_flash_attention_cpu_backward(
  return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable(
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const c10::optional<at::Tensor> & attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable not implemented. This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_scaled_dot_product_fused_attention_overrideable_backward(
    const at::Tensor & grad_out,
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const at::Tensor & attn_bias,
    std::array<bool,4> grad_input_mask,
    const at::Tensor & out,
    const at::Tensor & logsumexp,
    const at::Tensor & cum_seq_q,
    const at::Tensor & cum_seq_k,
    int64_t max_q,
    int64_t max_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor & philox_seed,
    const at::Tensor & philox_offset,
    std::optional<double> scale) {
  TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable_backward not implemented: This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function ");
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(
      at::empty_like(query),
      at::empty_like(key),
      at::empty_like(value),
      at::empty_like(attn_bias));
}

Tensor triton_multi_head_attention(
    const Tensor& query,
    const Tensor& key,
@@ -22,13 +22,14 @@

 namespace sdp {

-constexpr int32_t num_backends = 4;
+constexpr int32_t num_backends = 5;
 enum class SDPBackend {
   error = -1,
   math = 0,
   flash_attention = 1,
   efficient_attention = 2,
-  cudnn_attention = 3
+  cudnn_attention = 3,
+  overrideable = 4
 };

 // Note that if this changed make sure to update
@@ -21,6 +21,8 @@
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>

static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
@@ -125,12 +127,18 @@ void quantize_tensor_per_tensor_affine_privateuse1(
  // do nothing
}

int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
    const c10::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, c10::optional<double> scale){
  auto backend = sdp::SDPBackend::overrideable;
  return static_cast<int64_t>(backend);
}
} // namespace

namespace at::native {

REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);

} // namespace at::native
struct CustomBackendMetadata : public c10::BackendMeta {
@@ -458,6 +466,59 @@ const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
  return self;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const c10::optional<at::Tensor> & attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  const int64_t batch_size = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const int64_t max_seqlen_q = query.size(2);
  const int64_t max_seqlen_kv = key.size(2);

  auto opts = query.options();
  auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
  auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
                                   opts.dtype(at::kFloat));
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
  auto philox_offset = at::empty({}, at::dtype(at::kLong));

  return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(
    const at::Tensor & grad_out,
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const at::Tensor & attn_bias,
    std::array<bool,4> grad_input_mask,
    const at::Tensor & out,
    const at::Tensor & logsumexp,
    const at::Tensor & cum_seq_q,
    const at::Tensor & cum_seq_k,
    int64_t max_q,
    int64_t max_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor & philox_seed,
    const at::Tensor & philox_offset,
    std::optional<double> scale) {
  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
      at::empty_like(query),
      at::empty_like(key),
      at::empty_like(value),
      at::empty_like(attn_bias));
}

// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
@@ -482,6 +543,9 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("resize_", &custom_resize_);
  m.impl("as_strided", at::native::as_strided_tensorimpl);
  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
  m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
  m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
  m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
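With the impls above registered, a plain SDPA call on tensors living on the PrivateUse1 device is routed through _fused_sdp_choice to the overrideable backend and lands in the custom kernel. A usage sketch only (not part of the PR): it assumes an extension like the one above is loaded so that tensors can actually be copied to the PrivateUse1 device.

#include <ATen/ATen.h>

void run_sdpa_on_privateuse1() {
  at::Device dev(c10::DeviceType::PrivateUse1, 0);
  // Create on CPU, then copy to the custom device (requires the extension's copy kernel).
  auto q = at::rand({4, 2, 256, 128}, at::dtype(at::kHalf)).to(dev);
  auto k = at::rand({4, 2, 256, 128}, at::dtype(at::kHalf)).to(dev);
  auto v = at::rand({4, 2, 256, 128}, at::dtype(at::kHalf)).to(dev);
  // _fused_sdp_choice picks sdp::SDPBackend::overrideable, so this dispatches to
  // custom_scaled_dot_product_fused_attention_overrideable.
  auto out = at::scaled_dot_product_attention(q, k, v);
  (void)out;
}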
@@ -508,6 +508,8 @@ aten::_scaled_dot_product_efficient_attention_backward
aten::_scaled_dot_product_flash_attention
aten::_scaled_dot_product_flash_attention_backward
aten::_scaled_dot_product_flash_attention_for_cpu_backward
aten::_scaled_dot_product_fused_attention_overrideable
aten::_scaled_dot_product_fused_attention_overrideable_backward
aten::_scaled_mm
aten::_scaled_mm.out
aten::_segment_reduce_backward
@@ -26,6 +26,7 @@ all_operators_with_namedtuple_return = {

all_operators_with_namedtuple_return_skip_list = {
    '_scaled_dot_product_flash_attention',
    '_scaled_dot_product_fused_attention_overrideable',
    '_scaled_dot_product_flash_attention_for_cpu',
    '_scaled_dot_product_efficient_attention',
    '_scaled_dot_product_cudnn_attention',
@@ -18,6 +18,7 @@ import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU
from typing import List, Tuple, Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    TEST_WITH_ROCM,
@@ -47,6 +48,11 @@ from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
)

from test_cpp_extensions_open_device_registration import (
    remove_build_path,
    generate_faked_module
)

if TEST_FAIRSEQ:
    import fairseq.models.transformer as fairseq_transformer
@@ -3552,6 +3558,68 @@ class TestAttnBias(NNTestCase):
        with self.assertRaisesRegex(ValueError, "CausalBias should not be used with causal=True"):
            scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)

class TestSDPAPrivateUse1Only(NNTestCase):
    @classmethod
    def setUpClass(cls):
        remove_build_path()
        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                "cpp_extensions/open_registration_extension.cpp",
            ],
            extra_include_paths=["cpp_extensions"],
            extra_cflags=["-g"],
            verbose=True,
        )
        # register torch.foo module and foo device to torch
        torch.utils.rename_privateuse1_backend("foo")
        torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
        torch._register_device_module("foo", generate_faked_module())

    @skipIfTorchDynamo()
    def test_fused_sdp_choice_privateuseone(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("foo")
        k_privateuse1 = k_cpu.to("foo")
        v_privateuse1 = v_cpu.to("foo")
        assert torch._fused_sdp_choice(q_privateuse1, k_privateuse1, v_privateuse1) == SDPBackend.OVERRIDEABLE.value

    def test_scaled_dot_product_fused_attention_overrideable(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16)
        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        q_privateuse1 = q_cpu.to("foo")
        k_privateuse1 = k_cpu.to("foo")
        v_privateuse1 = v_cpu.to("foo")
        actual = torch.nn.functional.scaled_dot_product_attention(
            q_privateuse1, k_privateuse1, v_privateuse1, attn_mask=None, dropout_p=0.0)

    def test_scaled_dot_product_fused_attention_overrideable_backward(self):
        batch_size, seq_len, num_heads, head_dim = 4, 256, 2, 128
        make_tensor = partial(torch.rand, device="cpu", dtype=torch.float16, requires_grad=True)
        shape = (batch_size, num_heads, seq_len, head_dim)
        q_cpu, k_cpu, v_cpu = make_tensor(shape), make_tensor(shape), make_tensor(shape)
        attn_mask = make_tensor((batch_size, num_heads, seq_len, seq_len))
        q_privateuse1 = q_cpu.to("foo")
        k_privateuse1 = k_cpu.to("foo")
        v_privateuse1 = v_cpu.to("foo")
        attn_mask_privateuse1 = attn_mask.to("foo")
        output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = \
            torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
                q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1)

        rand_upward = torch.rand(shape, device="cpu", dtype=torch.float16, requires_grad=False)
        rand_upward_privateuse1 = rand_upward.to("foo")
        grad_input_mask = [True, True, True, True]
        grad_q, grad_k, grad_v, grad_attn_mask = torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
            rand_upward_privateuse1, q_privateuse1, k_privateuse1, v_privateuse1, attn_mask_privateuse1,
            grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p=0.0,
            is_causal=False, philox_seed=philox_seed, philox_offset=philox_offset)

if NOTEST_CPU:
    device_types = ("cuda", )
else:
@@ -2856,6 +2856,10 @@
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
  self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
@@ -1152,6 +1152,8 @@ def _set_sdp_use_mem_efficient(
) -> None: ... # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_overrideable_sdp_enabled() -> _bool: ... # THPModule_userEnabledOverrideableSDP
def _set_sdp_use_overrideable(arg: _bool) -> None: ... # THPModule_setSDPUseOverrideable
def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_cudnn(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
@@ -740,6 +740,25 @@ PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_overrideable expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseOverrideable(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledOverrideableSDP(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().userEnabledOverrideableSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
@@ -1345,6 +1364,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    {"_get_overrideable_sdp_enabled",
     THPModule_userEnabledOverrideableSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_overrideable",
     THPModule_setSDPUseOverrideable,
     METH_O,
     nullptr},
    {"_get_cudnn_sdp_enabled",
     THPModule_userEnabledCuDNNSDP,
     METH_NOARGS,
@@ -1930,7 +1957,8 @@ Call this whenever a new thread is created in order to propagate values from
       .value("MATH", sdp::SDPBackend::math)
       .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)
       .value("EFFICIENT_ATTENTION", sdp::SDPBackend::efficient_attention)
-      .value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention);
+      .value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention)
+      .value("OVERRIDEABLE", sdp::SDPBackend::overrideable);

   py_module.def(
       "_can_use_flash_attention",