Revert "[ROCm] SDPA fix mem fault when dropout is enabled (#154864)"
This reverts commit 3caddd4daa5b1a167663c07219e065e86247ad76. Reverted https://github.com/pytorch/pytorch/pull/154864 on behalf of https://github.com/atalman due to reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154864#issuecomment-3225554119))
@@ -388,16 +388,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
dv_expanded = dv;
}

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
drop_seed = *philox_seed.data_ptr<int64_t>();
drop_offset = *philox_offset.data_ptr<int64_t>();
auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset);

uint64_t* drop_seed, drop_offset;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr,nullptr};
if(is_dropout) {
drop_seed_offset.first = philox_seed[0].data_ptr<uint64_t>();
drop_seed_offset.second = philox_seed[1].data_ptr<uint64_t>();
}

if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};
@@ -177,6 +177,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

const auto sizes = q.sizes();

const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
@@ -225,6 +226,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);


at::Tensor q_padded, k_padded, v_padded;
if (head_size % 8 != 0) {
q_padded = at::pad(temp_q, {0, 8 - head_size % 8});
@@ -237,6 +239,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
v_padded = v;
}


at::Tensor out;
if (out_.has_value()) {
out = out_.value();
@@ -263,6 +266,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;

at::Tensor softmax_lse;
// TODO - check gradient, only training require lse
softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -273,41 +277,46 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte));
}
else {
p = at::empty({ 0 }, opts.dtype(at::kByte));
p = at::empty({ 0 }, opts);
}


uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto rng_state = at::empty({2}, opts.dtype(at::kLong));
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA);
auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64));
auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));


at::Tensor seed_t, offset_t;

if (p_dropout > 0.0) {

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());

// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);

auto philox_args = gen->philox_cuda_state(counter_offset);

std::tie(drop_seed, drop_offset) = at::cuda::philox::unpack(philox_args);


hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
}
else
{
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
auto drop_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);

std::optional<at::Tensor> attn_bias;
if( attn_bias_.has_value())
{
attn_bias = attn_bias_;
}

if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state[0].data_ptr<uint64_t>(),
rng_state[1].data_ptr<uint64_t>());
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};
@@ -323,7 +332,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
auto args =
get_ck_fmha_fwd_args(
has_lse,
has_dropout,
return_dropout_randval,
mask,
batch_size,
seqlen_q,
@@ -349,11 +358,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}

if (seqlenq_ngroups_swapped) {
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
}
} //namespace pytorch_flash
@@ -169,10 +169,6 @@ These backends include:
.. autofunction:: torch.backends.cuda.sdp_kernel
```

```{eval-rst}
.. autofunction:: torch.backends.cuda.is_ck_sdpa_available
```

## torch.backends.cudnn

```{eval-rst}
@@ -49,7 +49,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_CK_SDPA,
tf32_on_and_off,
tf32_enabled,
)
@@ -86,6 +85,7 @@ isSM120Device = torch.cuda.is_available() and torch.cuda.get_device_capability()
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8

TEST_WITH_CK = TEST_WITH_ROCM and torch.backends.cuda.preferred_rocm_fa_library() == torch.backends.cuda._ROCmFABackends['ck']

def _check_equal(
golden: torch.Tensor,
@@ -3577,12 +3577,10 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
@parametrize("sdpa_backend", ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"])
@tf32_enabled()
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float,
dtype: torch.dtype, scale: str, enable_gqa: bool,
n_heads: list[int], sdpa_backend: str):
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: list[int]):
if isSM8XDevice or isSM120Device and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
@@ -3592,14 +3590,8 @@ class TestSDPACudaOnly(NNTestCase):
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return

# ROCm now supports 2 different backends for SDPA that require different set up.
TEST_WITH_CK = False
if TEST_WITH_ROCM:
torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend)
# When no args are given to preferred_rocm_fa_library, it acts as a getter
TEST_WITH_CK = (torch.backends.cuda.preferred_rocm_fa_library() == torch._C._ROCmFABackend.Ck)

if TEST_WITH_CK and dropout_p != 0:
self.skipTest("CK does not support tensor format dropout masks")
if TEST_WITH_CK and head_dim > 128:
self.skipTest("CK does not support head dims over 128")
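As the comments in the hunk above note, `torch.backends.cuda.preferred_rocm_fa_library` acts as a setter when given a backend name and as a getter when called with no arguments. A minimal sketch of driving that selection outside the test, assuming a ROCm build of PyTorch (the printout is illustrative):

```python
import torch

# Sketch: select the CK flash-attention backend on ROCm and read the
# active selection back. The "ck"/"aotriton" strings mirror the keys of
# torch.backends.cuda._ROCmFABackends used elsewhere in this diff.
if torch.version.hip is not None:
    torch.backends.cuda.preferred_rocm_fa_library("ck")       # setter
    active = torch.backends.cuda.preferred_rocm_fa_library()  # getter
    print(active == torch._C._ROCmFABackend.Ck)               # True when CK is active
```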
@@ -3655,24 +3647,15 @@ class TestSDPACudaOnly(NNTestCase):
softmax_mask = self.convert_flash_attn_S_to_softmax(
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]

# This is the default implementation for the mask but we need to match CK if we are using it
dropout_mask = softmax_mask >= 0

# This logic matches how CK calculates the dropout mask.
# This is necessary because CK doesn't support passing in custom dropout masks
# So we use this logic to ensure we are comparing apples to apples.
if TEST_WITH_CK:
dropout_mask = (softmax_mask <= int((1.0 - dropout_p) * 255.0)).to(torch.float32)

# High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
# Low Precision Math Reference
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
query, key, value, dropout_mask=dropout_mask, dropout_p=dropout_p,
is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)[0]
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]

upstream_grad = torch.rand_like(out, requires_grad=False)
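The comments above explain that CK cannot accept an externally supplied dropout mask, so the test reconstructs CK's keep/drop decisions from the byte-valued debug mask by thresholding at `(1 - dropout_p) * 255`. A standalone sketch of that thresholding; the tensor shape and values here are illustrative, not taken from the test:

```python
import torch

# Each attention entry carries an 8-bit value and is kept when that value
# is <= (1 - dropout_p) * 255, so the expected keep rate is ~(1 - dropout_p).
dropout_p = 0.25
encoded = torch.randint(0, 256, (4, 8, 16, 16), dtype=torch.uint8)
ck_dropout_mask = (encoded <= int((1.0 - dropout_p) * 255.0)).to(torch.float32)
print(ck_dropout_mask.mean().item())  # roughly 0.75 for dropout_p = 0.25
```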
@@ -3692,33 +3675,17 @@ class TestSDPACudaOnly(NNTestCase):
'grad_value': 4,
}
if TEST_WITH_ROCM:
if TEST_WITH_CK:
fudge_factors['out'] = 5
fudge_factors['grad_key'] = 145.0
fudge_factors['grad_query'] = 855.0 # ck min = 855.0
fudge_factors['grad_value'] = 6
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 1550.0 # NEW CK MIN
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
else:
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 650.0
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0

fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 650.0
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0

check_out_and_grad(
(out_ref, out_lp_ref, out),
@@ -2231,7 +2231,6 @@ def _is_flash_attention_available() -> _bool: ...
def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _is_ck_sdpa_available() -> _bool: ...

# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...
@@ -15,7 +15,6 @@ __all__ = [
"preferred_linalg_library",
"preferred_blas_library",
"preferred_rocm_fa_library",
"is_ck_sdpa_available",
"cufft_plan_cache",
"matmul",
"SDPAParams",
@@ -333,16 +332,6 @@ SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"


def is_ck_sdpa_available() -> bool:
r"""
.. warning:: This flag is beta and subject to change.

Returns whether composable_kernel may be used as the backend for
scaled-dot-product-attention.
"""
return torch._C._is_ck_sdpa_available()


def flash_sdp_enabled():
r"""
.. warning:: This flag is beta and subject to change.
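The `is_ck_sdpa_available` helper above simply forwards to `torch._C._is_ck_sdpa_available()` (bound in the next hunk), which returns `False` on non-ROCm builds. A minimal usage sketch, assuming a ROCm build with composable_kernel support:

```python
import torch

# Gate CK-specific configuration on availability; on CUDA or CPU-only
# builds is_ck_sdpa_available() is False, so the check is safe everywhere.
if torch.backends.cuda.is_ck_sdpa_available():
    # Prefer the composable_kernel flash-attention backend for SDPA.
    torch.backends.cuda.preferred_rocm_fa_library("ck")
else:
    print("composable_kernel SDPA backend not available in this build")
```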
@@ -2454,14 +2454,6 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().getROCmFAPreferredBackend();
});

py_module.def("_is_ck_sdpa_available", []() {
#ifdef USE_ROCM
return at::globalContext().ckSupported() && at::globalContext().hasCKSDPA();
#else
return false;
#endif
});

py_module.def(
"_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
at::globalContext()._setSMCarveout_EXPERIMENTAL(val);
@@ -66,12 +66,6 @@ def evaluate_platform_supports_flash_attention():
return not IS_WINDOWS and SM80OrLater
return False

def evaluate_platform_supports_ck_sdpa():
if TEST_WITH_ROCM:
return torch.backends.cuda.is_ck_sdpa_available()
else:
return False

def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
@@ -97,8 +91,6 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)

PLATFORM_SUPPORTS_CK_SDPA: bool = LazyVal(lambda: evaluate_platform_supports_ck_sdpa())

def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip: