mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI Perf] Only test bfloat16 for tests/compile/test_fusion_all_reduce.py (#23132)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@ -148,7 +148,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
|
Reference in New Issue
Block a user