[ROCm] Transformer/SDPA unit test parity (#163745)

## Major Changes

* Efficient Attention on ROCm requires the last dimension of the input tensors to be 16-byte aligned, i.e. the head dimension must be divisible by 8 for fp16/bf16 and by 4 for fp32.
  - Unlike Flash Attention (FA), Memory-Efficient (ME) attention does not pad input tensors inside `scaled_dot_product_attention`, so this check is required (see the sketch below this list).
* Fixes `atomic_counter` handling in the varlen FA API: the counter is now allocated as an `int32` tensor and passed through `mk_atomictensor`, matching the pointer type the kernel consumes.
* Unskips a number of previously skipped unit tests.

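A minimal sketch of what the alignment requirement means for callers, assuming a ROCm build (exposed as the `"cuda"` device) where the memory-efficient backend is available; `pad_head_dim` is a hypothetical helper, not part of this PR. Zero-padding Q/K/V along the head dimension leaves the attention scores unchanged as long as the original softmax scale is passed explicitly, and the padded V columns only add zero output columns that can be sliced off:

```python
# Illustrative only: meeting the ROCm mem-efficient alignment requirement by padding.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

def pad_head_dim(t: torch.Tensor, alignment: int) -> torch.Tensor:
    # Zero-pad the last (head) dimension up to the next multiple of `alignment`.
    rem = t.size(-1) % alignment
    return t if rem == 0 else F.pad(t, (0, alignment - rem))

q, k, v = (torch.randn(2, 8, 128, 81, device="cuda", dtype=torch.float16)
           for _ in range(3))

# 16-byte alignment: last dim divisible by 8 for fp16/bf16 (8 * 2 B), by 4 for fp32 (4 * 4 B).
alignment = 8 if q.dtype in (torch.float16, torch.bfloat16) else 4
head_dim = q.size(-1)
q, k, v = (pad_head_dim(t, alignment) for t in (q, k, v))  # 81 -> 88

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    # Keep the original 1/sqrt(81) scale and drop the zero columns added by padding.
    out = F.scaled_dot_product_attention(q, k, v, scale=head_dim ** -0.5)[..., :head_dim]
```
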
Fixes #157120
Fixes #157121
Fixes #157122
Fixes #157167
Fixes #155217
Fixes #157043
Fixes #157060

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163745
Approved by: https://github.com/jeffdaily
Author:       Xinya Zhang
Date:         2025-09-25 17:14:16 +00:00
Committed by: PyTorch MergeBot
Commit:       3cbfbbd691 (parent 112e204797)
6 changed files with 29 additions and 30 deletions

View File

@@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
     }
     return false;
   }
+  if constexpr(caller_is_meff) {
+    bool is_half = (params.query.dtype() == at::kHalf) ||
+        (params.query.dtype() == at::kBFloat16);
+    const int64_t alignment = is_half ? 8 : 4;
+    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
+          value_size_last % alignment == 0 && value_size_last > 0)) {
+      if (debug) {
+        TORCH_WARN(
+            "Mem efficient attention requires last dimension of inputs to be divisible by ",
+            alignment,
+            ". ",
+            "Got Query.size(-1): ",
+            query_size_last,
+            ", Key.size(-1): ",
+            params.key.sym_size(-1),
+            ", Value.size(-1): ",
+            params.value.sym_size(-1),
+            " instead.");
+      }
+      return false;
+    }
+  }
   return true;
 }
 

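For reference, a rough Python rendering of the check added above; the function name is made up and this is not code from the PR:

```python
import torch

def meff_last_dim_aligned(query: torch.Tensor, value: torch.Tensor) -> bool:
    # fp16/bf16 need the last dim divisible by 8, other dtypes by 4 (16 bytes either way).
    alignment = 8 if query.dtype in (torch.half, torch.bfloat16) else 4
    q_last, v_last = query.size(-1), value.size(-1)
    return (q_last > 0 and v_last > 0
            and q_last % alignment == 0 and v_last % alignment == 0)
```

When this predicate fails, the memory-efficient backend is rejected for that input, and the `TORCH_WARN` above explains why when `debug` is set.
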
View File

@@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::mk_philoxtensor;
+  using sdp::aotriton_adapter::mk_atomictensor;
   using sdp::aotriton_adapter::cast_dtype;
   at::Tensor atomic_counter;
   if (is_causal) {
-    atomic_counter = at::zeros({1}, q.options());
+    atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
   }
   aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
   auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
@@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   auto nullscalar = mk_philoxtensor(nullptr);
   auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
   auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
-  auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
+  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
   if (uses_swa || AOTRITON_ALWAYS_V3_API) {
 #if AOTRITON_V3_API
     using aotriton::v3::flash::CausalType;

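In PyTorch terms, the allocation change above boils down to giving the counter an explicit 32-bit integer dtype instead of inheriting `q`'s floating-point dtype; a minimal sketch, not code from the PR:

```python
import torch

# Before: created with q.options(), i.e. q's floating-point dtype, which does not
# match the 32-bit integer counter the kernel consumes through mk_atomictensor.
# atomic_counter = torch.zeros(1, device="cuda", dtype=q.dtype)

# After: a zero-initialized int32 counter whose data pointer is a valid int32_t*.
atomic_counter = torch.zeros(1, device="cuda", dtype=torch.int32)
```
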
View File

@ -17,7 +17,6 @@ from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize as parametrize_test,
run_tests,
skipIfRocm,
TEST_NUMPY,
TEST_WITH_CROSSREF,
)
@ -746,7 +745,6 @@ class TestMultiheadAttentionNN(NNTestCase):
class TestMultiheadAttentionNNDeviceType(NNTestCase):
@skipIfRocm(msg="To investigate: yields NaN")
def test_multihead_self_attn_two_masks_fast_path(self, device):
"""
Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path

View File

@ -15,7 +15,6 @@ from torch.testing._internal.common_cuda import (
)
from torch.testing._internal.common_utils import (
run_tests,
skipIfRocm,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@ -463,7 +462,6 @@ class TestFlopCounter(TestCase):
self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
@skipIfRocm # Nested tensor
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION
@ -683,7 +681,6 @@ class TestFlopCounter(TestCase):
),
)
@skipIfRocm # Nested tensor
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,

View File

@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, r
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
_get_torch_rocm_version
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@ -3167,7 +3167,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
[2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
@skipIfRocm(msg='Large numerical errors')
def test_transformerdecoder(self):
def get_a_test_layer(use_cuda, activation, batch_first=False):
d_model = 4
@ -13020,8 +13019,6 @@ if __name__ == '__main__':
@dtypes(torch.float)
@dtypesIfCUDA(torch.double, torch.float, torch.half)
def test_transformerencoderlayer(self, device, dtype):
if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
self.skipTest("Skip on ROCM due to Flash Attention tolerances")
# this is a deterministic test for TransformerEncoderLayer
d_model = 4
nhead = 2
@ -13243,8 +13240,6 @@ if __name__ == '__main__':
@dtypes(torch.float)
@dtypesIfCUDA(torch.half, torch.float)
def test_transformerencoderlayer_gelu(self, device, dtype):
if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
self.skipTest("Skip on ROCM due to Flash Attention tolerances")
# this is a deterministic test for TransformerEncoderLayer with gelu activation
d_model = 4
nhead = 2

View File

@ -344,9 +344,6 @@ class TestTransformers(NNTestCase):
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
if TEST_WITH_ROCM:
if attn_mask_dim is not None and mask_dtype == torch.bool:
self.skipTest("boolean mask is not fully supported on ROCm yet.")
# MHA converts all
with torch.no_grad():
B = 2
@ -429,8 +426,7 @@ class TestTransformers(NNTestCase):
# remove hook
handle.remove()
@skipIfRocm
@tf32_on_and_off(0.001)
@tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001)
@parametrize("use_torchscript", [False])
@parametrize("enable_nested_tensor", [True, False])
@parametrize("use_autocast", [True, False])
@ -1420,7 +1416,6 @@ class TestTransformers(NNTestCase):
_ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
torch.cuda.synchronize()
@skipIfRocm # Missing EFFICIENT_ATTENTION
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
)
@ -1713,7 +1708,7 @@ class TestSDPAFailureModes(NNTestCase):
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
ctxmgr = self.assertRaises(RuntimeError)
with ctxmgr:
torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
@ -2611,7 +2606,6 @@ class TestSDPACudaOnly(NNTestCase):
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
@ -2635,7 +2629,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_gqa(self, device):
batch = 4
@ -2659,7 +2652,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(output_math, output_cudnn)
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_d256_heuristic(self, device):
dtype = torch.bfloat16
@ -2690,7 +2682,6 @@ class TestSDPACudaOnly(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
test()
@skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_fused_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
@ -2714,7 +2705,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.skipIf(True, "broken as of cuDNN 9.10")
def test_cudnn_attention_fail_d128(self, device):
# Test that cuDNN attention dispatching correctly bails out on d > 128
@ -2736,7 +2727,6 @@ class TestSDPACudaOnly(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
torch.nn.functional.scaled_dot_product_attention(q, k, v)
@skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_trivial_output_transpose(self, device):
# see also: https://github.com/pytorch/pytorch/issues/134001
@ -2752,7 +2742,6 @@ class TestSDPACudaOnly(NNTestCase):
o.backward(o)
torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
@skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_nonmodulo64seqlen(self, device):
# see also: https://github.com/pytorch/pytorch/issues/137347
@ -2792,7 +2781,6 @@ class TestSDPACudaOnly(NNTestCase):
torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_preserves_query_layout(self, device):
@ -2822,7 +2810,6 @@ class TestSDPACudaOnly(NNTestCase):
for permute_order in permute_orders:
test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
@skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_compiles(self):
q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
@ -3232,7 +3219,6 @@ class TestSDPACudaOnly(NNTestCase):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
@skipIfRocm
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")