[ROCm] Transformer/SDPA unit test parity (#163745)
## Major Changes

* Efficient Attention on ROCm requires the last dimension of the input tensors to be aligned to 16 bytes.
  - Unlike Flash Attention, Memory-Efficient Attention does not pad the input tensors inside `scaled_dot_product_attention`, so the alignment must already hold when the kernel is selected (see the sketch below).
* Fix `atomic_counter` handling in the varlen Flash Attention API.
* Unskip several unit tests.

Fixes #157120
Fixes #157121
Fixes #157122
Fixes #157167
Fixes #155217
Fixes #157043
Fixes #157060

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163745
Approved by: https://github.com/jeffdaily
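For context, the 16-byte rule enforced by the new check translates to a head dimension divisible by 8 for fp16/bf16 and by 4 for fp32. The snippet below is a minimal caller-side sketch, not code from this PR: the helper name `pad_head_dim_for_mem_efficient` and the zero-padding workaround are illustrative assumptions, and it presumes a ROCm (or CUDA) device where the memory-efficient backend can be selected.

```python
# Illustrative sketch only (not part of this PR): zero-pad the head dimension so
# the memory-efficient SDPA backend's ROCm alignment check is satisfied.
import torch
import torch.nn.functional as F


def pad_head_dim_for_mem_efficient(q, k, v):
    # 16-byte alignment: 8 elements for fp16/bf16, 4 elements for fp32.
    alignment = 8 if q.dtype in (torch.float16, torch.bfloat16) else 4
    head_dim = q.size(-1)
    pad = (-head_dim) % alignment
    if pad == 0:
        return q, k, v, head_dim
    # Zero columns leave q @ k.transpose(-2, -1) unchanged; padded value columns
    # only add output channels that are sliced off again by the caller.
    return F.pad(q, (0, pad)), F.pad(k, (0, pad)), F.pad(v, (0, pad)), head_dim


if torch.cuda.is_available():  # ROCm builds also expose a "cuda" device
    make = lambda: torch.randn(2, 8, 128, 60, device="cuda", dtype=torch.float16)
    q, k, v = make(), make(), make()
    q_p, k_p, v_p, head_dim = pad_head_dim_for_mem_efficient(q, k, v)
    # Keep the scale of the original head_dim so the padded call matches the
    # unpadded math, then slice the padding off the output.
    out = F.scaled_dot_product_attention(
        q_p, k_p, v_p, scale=head_dim ** -0.5
    )[..., :head_dim]
```

Passing `scale=head_dim ** -0.5` keeps the softmax scaling tied to the original head dimension rather than the padded one, so the sliced output matches the unpadded computation.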
Committed by: PyTorch MergeBot
Parent: 112e204797
Commit: 3cbfbbd691
@@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
     }
     return false;
   }
+  if constexpr(caller_is_meff) {
+    bool is_half = (params.query.dtype() == at::kHalf) ||
+        (params.query.dtype() == at::kBFloat16);
+    const int64_t alignment = is_half ? 8 : 4;
+    if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
+          value_size_last % alignment == 0 && value_size_last > 0)) {
+      if (debug) {
+        TORCH_WARN(
+            "Mem efficient attention requires last dimension of inputs to be divisible by ",
+            alignment,
+            ". ",
+            "Got Query.size(-1): ",
+            query_size_last,
+            ", Key.size(-1): ",
+            params.key.sym_size(-1),
+            ", Value.size(-1): ",
+            params.value.sym_size(-1),
+            " instead.");
+      }
+      return false;
+    }
+  }
   return true;
 }
 
@@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::mk_philoxtensor;
+  using sdp::aotriton_adapter::mk_atomictensor;
   using sdp::aotriton_adapter::cast_dtype;
   at::Tensor atomic_counter;
   if (is_causal) {
-    atomic_counter = at::zeros({1}, q.options());
+    atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
   }
   aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
   auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
@@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   auto nullscalar = mk_philoxtensor(nullptr);
   auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
   auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
-  auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
+  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
   if (uses_swa || AOTRITON_ALWAYS_V3_API) {
 #if AOTRITON_V3_API
     using aotriton::v3::flash::CausalType;
@@ -17,7 +17,6 @@ from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize as parametrize_test,
     run_tests,
-    skipIfRocm,
     TEST_NUMPY,
     TEST_WITH_CROSSREF,
 )
@@ -746,7 +745,6 @@ class TestMultiheadAttentionNN(NNTestCase):
 
 
 class TestMultiheadAttentionNNDeviceType(NNTestCase):
-    @skipIfRocm(msg="To investigate: yields NaN")
     def test_multihead_self_attn_two_masks_fast_path(self, device):
         """
         Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path
@@ -15,7 +15,6 @@ from torch.testing._internal.common_cuda import (
 )
 from torch.testing._internal.common_utils import (
     run_tests,
-    skipIfRocm,
     TEST_WITH_TORCHDYNAMO,
     TestCase,
 )
@@ -463,7 +462,6 @@ class TestFlopCounter(TestCase):
         self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
         self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
 
-    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION
@@ -683,7 +681,6 @@ class TestFlopCounter(TestCase):
             ),
         )
 
-    @skipIfRocm  # Nested tensor
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -39,7 +39,7 @@ from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, r
     parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
     skipIfTorchDynamo, gcIfJetson, set_default_dtype
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
-    PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
+    _get_torch_rocm_version
 from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
     module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
     ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -3167,7 +3167,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
                                            [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
         torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
-    @skipIfRocm(msg='Large numerical errors')
     def test_transformerdecoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
             d_model = 4
@@ -13020,8 +13019,6 @@ if __name__ == '__main__':
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.double, torch.float, torch.half)
     def test_transformerencoderlayer(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer
         d_model = 4
         nhead = 2
@@ -13243,8 +13240,6 @@ if __name__ == '__main__':
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.half, torch.float)
     def test_transformerencoderlayer_gelu(self, device, dtype):
-        if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
-            self.skipTest("Skip on ROCM due to Flash Attention tolerances")
         # this is a deterministic test for TransformerEncoderLayer with gelu activation
         d_model = 4
         nhead = 2
@@ -344,9 +344,6 @@ class TestTransformers(NNTestCase):
     @parametrize("key_padding_mask_dim", [2, None])
     @parametrize("mask_dtype", [torch.bool, torch.float32])
     def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
-        if TEST_WITH_ROCM:
-            if attn_mask_dim is not None and mask_dtype == torch.bool:
-                self.skipTest("boolean mask is not fully supported on ROCm yet.")
         # MHA converts all
         with torch.no_grad():
             B = 2
@@ -429,8 +426,7 @@ class TestTransformers(NNTestCase):
         # remove hook
         handle.remove()
 
-    @skipIfRocm
-    @tf32_on_and_off(0.001)
+    @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001)
     @parametrize("use_torchscript", [False])
     @parametrize("enable_nested_tensor", [True, False])
     @parametrize("use_autocast", [True, False])
@@ -1420,7 +1416,6 @@ class TestTransformers(NNTestCase):
             _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
             torch.cuda.synchronize()
 
-    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
     )
@@ -1713,7 +1708,7 @@ class TestSDPAFailureModes(NNTestCase):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            ctxmgr = self.assertRaises(RuntimeError)
             with ctxmgr:
                 torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
@@ -2611,7 +2606,6 @@ class TestSDPACudaOnly(NNTestCase):
         S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
         return S_converted[:, :, :seqlen_q, :seqlen_k]
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2635,7 +2629,6 @@ class TestSDPACudaOnly(NNTestCase):
 
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_gqa(self, device):
         batch = 4
@@ -2659,7 +2652,6 @@ class TestSDPACudaOnly(NNTestCase):
 
         self.assertEqual(output_math, output_cudnn)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_cudnn_attention_d256_heuristic(self, device):
         dtype = torch.bfloat16
@@ -2690,7 +2682,6 @@ class TestSDPACudaOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             test()
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2714,7 +2705,7 @@ class TestSDPACudaOnly(NNTestCase):
         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
 
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
+    @unittest.skipIf(True, "broken as of cuDNN 9.10")
     def test_cudnn_attention_fail_d128(self, device):
         # Test that cuDNN attention dispatching correctly bails out on d > 128
@@ -2736,7 +2727,6 @@ class TestSDPACudaOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
-    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_trivial_output_transpose(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2752,7 +2742,6 @@ class TestSDPACudaOnly(NNTestCase):
         o.backward(o)
         torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
 
-    @skipIfRocm  # No cuDNN Attention
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_nonmodulo64seqlen(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/137347
@@ -2792,7 +2781,6 @@ class TestSDPACudaOnly(NNTestCase):
         torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
         torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_preserves_query_layout(self, device):
 
@@ -2822,7 +2810,6 @@ class TestSDPACudaOnly(NNTestCase):
         for permute_order in permute_orders:
             test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
 
-    @skipIfRocm
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_compiles(self):
         q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
@@ -3232,7 +3219,6 @@ class TestSDPACudaOnly(NNTestCase):
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
             assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
 
-    @skipIfRocm
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")