Revert "[ROCm] SDPA fix mem fault when dropout is enabled (#154864)"

This reverts commit 3caddd4daa5b1a167663c07219e065e86247ad76.

Reverted https://github.com/pytorch/pytorch/pull/154864 on behalf of https://github.com/atalman due to reverted internally ([comment](https://github.com/pytorch/pytorch/pull/154864#issuecomment-3225554119))
PyTorch MergeBot
2025-08-26 20:03:57 +00:00
parent caf98fde0d
commit 9f6e1b8730
8 changed files with 48 additions and 108 deletions

View File

@@ -388,16 +388,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
dv_expanded = dv;
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
uint64_t drop_seed = 1, drop_offset = 0;
drop_seed = *philox_seed.data_ptr<int64_t>();
drop_offset = *philox_offset.data_ptr<int64_t>();
auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset);
uint64_t* drop_seed, drop_offset;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr,nullptr};
if(is_dropout) {
drop_seed_offset.first = philox_seed[0].data_ptr<uint64_t>();
drop_seed_offset.second = philox_seed[1].data_ptr<uint64_t>();
}
if (seqlen_q > 0) {
ck_tile::stream_config stream_config{stream};

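The hunk above switches between two ways of handing the dropout Philox seed/offset to the CK backward kernel: dereferencing the seed/offset tensors on the host and passing addresses of local variables, versus passing raw device pointers into the tensor storage. Below is a minimal PyTorch-level analogue of that distinction, not the extension code itself; the `rng_state` tensor and its values are hypothetical, and a CUDA/ROCm device is assumed to be present.

```python
import torch

# Hypothetical 2-element seed/offset state on the GPU, standing in for the
# philox_seed / philox_offset tensors in the hunk above.
rng_state = torch.tensor([42, 0], dtype=torch.int64, device="cuda")

# Host-dereference style: copy the values to the CPU and work with host-side
# scalars (the kernel would then be given addresses of local variables).
host_seed = rng_state[0].item()
host_offset = rng_state[1].item()

# Device-pointer style: hand over raw pointers into the tensor's storage so
# nothing is dereferenced on the host.
seed_ptr = rng_state[0].data_ptr()    # address of element 0
offset_ptr = rng_state[1].data_ptr()  # address of element 1

print(host_seed, host_offset, hex(seed_ptr), hex(offset_ptr))
```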
View File

@@ -177,6 +177,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
const auto sizes = q.sizes();
const int batch_size = sizes[0];
int seqlen_q = sizes[1];
int num_heads = sizes[2];
@@ -225,6 +226,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
at::Tensor q_padded, k_padded, v_padded;
if (head_size % 8 != 0) {
q_padded = at::pad(temp_q, {0, 8 - head_size % 8});
@@ -237,6 +239,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
v_padded = v;
}
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
@@ -263,6 +266,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
auto opts = q.options();
bool has_lse = true;
bool has_dropout = p_dropout > 0.0f;
at::Tensor softmax_lse;
// TODO - check gradient, only training require lse
softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -273,41 +277,46 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte));
}
else {
p = at::empty({ 0 }, opts.dtype(at::kByte));
p = at::empty({ 0 }, opts);
}
uint64_t drop_seed = 1, drop_offset = 0;
int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
auto rng_state = at::empty({2}, opts.dtype(at::kLong));
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA);
auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64));
auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
at::Tensor seed_t, offset_t;
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
std::tie(drop_seed, drop_offset) = at::cuda::philox::unpack(philox_args);
hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
}
else
{
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
auto drop_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
std::optional<at::Tensor> attn_bias;
if( attn_bias_.has_value())
{
attn_bias = attn_bias_;
}
if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state[0].data_ptr<uint64_t>(),
rng_state[1].data_ptr<uint64_t>());
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
ck_tile::stream_config stream_config{stream};
@@ -323,7 +332,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
auto args =
get_ck_fmha_fwd_args(
has_lse,
has_dropout,
return_dropout_randval,
mask,
batch_size,
seqlen_q,
@@ -349,11 +358,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
if (seqlenq_ngroups_swapped) {
out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
}
} //namespace pytorch_flash

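In the forward hunks above, the dropout RNG state is carried as a two-element tensor (`rng_state`), from which scalar `seed_t` / `offset_t` tensors are produced for the return tuple. Here is a minimal sketch of that packing, with hypothetical seed/offset, `batch_size`, `num_heads`, and warp-size values standing in for whatever the Philox unpacking and `ck_tile::get_warp_size()` would actually produce.

```python
import torch

# Hypothetical Philox seed/offset, standing in for the values unpacked from
# the generator's philox state.
drop_seed, drop_offset = 12345, 0

# Illustrative counter reservation in the spirit of
# batch_size * num_heads * warp_size from the hunk above.
batch_size, num_heads, warp_size = 4, 16, 64
counter_offset = batch_size * num_heads * warp_size

# Pack both values into one 2-element state tensor so downstream code can
# carry a single object, then recover scalar tensors for the return tuple.
rng_state = torch.tensor([drop_seed, drop_offset], dtype=torch.int64)
seed_t, offset_t = rng_state[0].clone(), rng_state[1].clone()

assert seed_t.item() == drop_seed and offset_t.item() == drop_offset
print(counter_offset, rng_state.tolist())
```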
View File

@@ -169,10 +169,6 @@ These backends include:
.. autofunction:: torch.backends.cuda.sdp_kernel
```
```{eval-rst}
.. autofunction:: torch.backends.cuda.is_ck_sdpa_available
```
## torch.backends.cudnn
```{eval-rst}

View File

@@ -49,7 +49,6 @@ from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
PLATFORM_SUPPORTS_FUSED_ATTENTION,
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
PLATFORM_SUPPORTS_CK_SDPA,
tf32_on_and_off,
tf32_enabled,
)
@@ -86,6 +85,7 @@ isSM120Device = torch.cuda.is_available() and torch.cuda.get_device_capability()
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
TEST_WITH_CK = TEST_WITH_ROCM and torch.backends.cuda.preferred_rocm_fa_library() == torch.backends.cuda._ROCmFABackends['ck']
def _check_equal(
golden: torch.Tensor,
@@ -3577,12 +3577,10 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
@parametrize("sdpa_backend", ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"])
@tf32_enabled()
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
head_dim: int, is_causal: bool, dropout_p: float,
dtype: torch.dtype, scale: str, enable_gqa: bool,
n_heads: list[int], sdpa_backend: str):
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
scale: str, enable_gqa: bool, n_heads: list[int]):
if isSM8XDevice or isSM120Device and head_dim in range(193, 256 + 1):
self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
if is_causal and seq_len_q != seq_len_k:
@@ -3592,14 +3590,8 @@ class TestSDPACudaOnly(NNTestCase):
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return
# ROCm now supports 2 different backends for SDPA that require different set up.
TEST_WITH_CK = False
if TEST_WITH_ROCM:
torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend)
# When no args are given to preferred_rocm_fa_library, it acts as a getter
TEST_WITH_CK = (torch.backends.cuda.preferred_rocm_fa_library() == torch._C._ROCmFABackend.Ck)
if TEST_WITH_CK and dropout_p != 0:
self.skipTest("CK does not support tensor format dropout masks")
if TEST_WITH_CK and head_dim > 128:
self.skipTest("CK does not support head dims over 128")
@@ -3655,24 +3647,15 @@ class TestSDPACudaOnly(NNTestCase):
softmax_mask = self.convert_flash_attn_S_to_softmax(
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
# This is the default implementation for the mask but we need to match CK if we are using it
dropout_mask = softmax_mask >= 0
# This logic matches how CK calculates the dropout mask.
# This is necessary because CK doesn't support passing in custom dropout masks
# So we use this logic to ensure we are comparing apples to apples.
if TEST_WITH_CK:
dropout_mask = (softmax_mask <= int((1.0 - dropout_p) * 255.0)).to(torch.float32)
# High Precision Math Reference
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal,
scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
# Low Precision Math Reference
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
query, key, value, dropout_mask=dropout_mask, dropout_p=dropout_p,
is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)[0]
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale,
dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0]
upstream_grad = torch.rand_like(out, requires_grad=False)
@@ -3692,33 +3675,17 @@ class TestSDPACudaOnly(NNTestCase):
'grad_value': 4,
}
if TEST_WITH_ROCM:
if TEST_WITH_CK:
fudge_factors['out'] = 5
fudge_factors['grad_key'] = 145.0
fudge_factors['grad_query'] = 855.0 # ck min = 855.0
fudge_factors['grad_value'] = 6
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 1550.0 # NEW CK MIN
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
else:
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 650.0
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
fudge_factors['grad_key'] = 45.0
fudge_factors['grad_query'] = 360.0
if seq_len_k >= 1024:
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 190.0
fudge_factors['grad_query'] = 650.0
if seq_len_q >= 2048:
fudge_factors['grad_query'] = 1100.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
check_out_and_grad(
(out_ref, out_lp_ref, out),

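The removed test code builds a CK-style dropout mask by thresholding a byte-valued debug matrix at `int((1.0 - dropout_p) * 255.0)`. The following is a small self-contained sketch of just that conversion, using random bytes in place of the kernel's debug output; the shapes and probability are arbitrary.

```python
import torch

# Hypothetical stand-in for the kernel's debug dropout values: one byte in
# [0, 255] per attention score.
torch.manual_seed(0)
randval = torch.randint(0, 256, (64, 64), dtype=torch.uint8)

dropout_p = 0.2

# CK-style keep mask mirrored by the removed test code: an element survives
# dropout iff its byte value is at or below the scaled keep-probability
# threshold.
keep_mask = (randval <= int((1.0 - dropout_p) * 255.0)).to(torch.float32)

# With uniformly distributed bytes the kept fraction tracks 1 - dropout_p,
# which is what lets the test feed this mask to the math reference.
print(f"kept fraction ~ {keep_mask.mean().item():.3f} (expected ~ {1 - dropout_p})")
```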
View File

@@ -2231,7 +2231,6 @@ def _is_flash_attention_available() -> _bool: ...
def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _is_ck_sdpa_available() -> _bool: ...
# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...

View File

@@ -15,7 +15,6 @@ __all__ = [
"preferred_linalg_library",
"preferred_blas_library",
"preferred_rocm_fa_library",
"is_ck_sdpa_available",
"cufft_plan_cache",
"matmul",
"SDPAParams",
@@ -333,16 +332,6 @@ SDPAParams.__module__ = "torch.backends.cuda"
SDPAParams.__name__ = "SDPAParams"
def is_ck_sdpa_available() -> bool:
r"""
.. warning:: This flag is beta and subject to change.
Returns whether composable_kernel may be used as the backend for
scaled-dot-product-attention.
"""
return torch._C._is_ck_sdpa_available()
def flash_sdp_enabled():
r"""
.. warning:: This flag is beta and subject to change.

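With `is_ck_sdpa_available()` removed by this revert, the ROCm-facing knob that remains visible in the diff is `torch.backends.cuda.preferred_rocm_fa_library`, which the removed test code uses both as a setter (passing a backend name such as "ck") and as a getter (no arguments). A hedged usage sketch follows; it only queries the current value and guards on a ROCm build.

```python
import torch

if torch.version.hip is not None:
    # With no argument the call acts as a getter and reports the currently
    # preferred ROCm flash-attention backend.
    current = torch.backends.cuda.preferred_rocm_fa_library()
    print("preferred ROCm FA backend:", current)
    # The setter form, e.g. preferred_rocm_fa_library("ck"), is what the
    # removed test parametrization used to switch between backends.
else:
    print("not a ROCm build; no CK flash-attention backend to select")
```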
View File

@@ -2454,14 +2454,6 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().getROCmFAPreferredBackend();
});
py_module.def("_is_ck_sdpa_available", []() {
#ifdef USE_ROCM
return at::globalContext().ckSupported() && at::globalContext().hasCKSDPA();
#else
return false;
#endif
});
py_module.def(
"_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
at::globalContext()._setSMCarveout_EXPERIMENTAL(val);

View File

@@ -66,12 +66,6 @@ def evaluate_platform_supports_flash_attention():
return not IS_WINDOWS and SM80OrLater
return False
def evaluate_platform_supports_ck_sdpa():
if TEST_WITH_ROCM:
return torch.backends.cuda.is_ck_sdpa_available()
else:
return False
def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
@@ -97,8 +91,6 @@ PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)
PLATFORM_SUPPORTS_CK_SDPA: bool = LazyVal(lambda: evaluate_platform_supports_ck_sdpa())
def evaluate_platform_supports_fp8():
if torch.cuda.is_available():
if torch.version.hip:
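The `PLATFORM_SUPPORTS_*` flags in this last file are built with `LazyVal`, so the platform probe only runs when a test actually consults the flag. Below is a simplified stand-in for that pattern, not the real helper from `torch.testing._internal`; the probe and flag names are hypothetical.

```python
# A minimal sketch of the deferred-evaluation pattern used for the
# PLATFORM_SUPPORTS_* flags above; this LazyVal is a simplified stand-in,
# not the actual implementation shipped with PyTorch's test internals.
class LazyVal:
    def __init__(self, fn):
        self._fn = fn
        self._value = None
        self._computed = False

    def __bool__(self):
        # The (possibly expensive) platform probe runs only on first use,
        # and the result is cached for later checks.
        if not self._computed:
            self._value = bool(self._fn())
            self._computed = True
        return self._value


def evaluate_platform_supports_something():
    # Hypothetical probe; a real flag would query torch.cuda / torch.version.
    return False


PLATFORM_SUPPORTS_SOMETHING = LazyVal(evaluate_platform_supports_something)

if PLATFORM_SUPPORTS_SOMETHING:
    print("supported")
else:
    print("not supported")
```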