[ROCm] Fix some kernels failed unit tests (#2498)

This commit is contained in:
Hongxia Yang
2024-02-05 17:25:36 -05:00
committed by GitHub
parent 72d3a30c63
commit 56f738ae9b
5 changed files with 62 additions and 12 deletions

View File

@ -0,0 +1,18 @@
import torch
# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {
torch.float16: 1e-3,
torch.bfloat16: 1.6e-2,
torch.float: 1.3e-6
}
def get_default_atol(output) -> float:
return default_atol[output.dtype]
def get_default_rtol(output) -> float:
return default_rtol[output.dtype]

View File

@ -2,6 +2,7 @@ import pytest
import torch
from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul
from allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
@ -33,7 +34,10 @@ def test_silu_and_mul(
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@ -57,7 +61,10 @@ def test_gelu_new(
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@ -80,4 +87,7 @@ def test_gelu_fast(
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))

View File

@ -8,6 +8,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import ops, cache_ops
from vllm.utils import get_max_shared_memory_bytes
from vllm.utils import is_hip
from allclose_default import get_default_atol, get_default_rtol
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
@ -17,12 +19,18 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
DTYPES = [torch.half, torch.bfloat16, torch.float]
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [torch.half, torch.bfloat16, torch.float
] if not is_hip() else [torch.half, torch.bfloat16]
NUM_GEN_SEQS = [7] # Arbitrary values for testing
NUM_PREFILL_SEQS = [3] # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES = [64, 80, 96, 112, 128, 256
] if not is_hip() else [64, 80, 96, 112, 128]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@ -251,9 +259,11 @@ def test_paged_attention(
# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8_e5m2":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
@ -357,4 +367,6 @@ def test_multi_query_kv_attention(
scale,
dtype,
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)

View File

@ -6,6 +6,7 @@ import torch
from typing import Tuple
from vllm._C import cache_ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -14,7 +15,10 @@ NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
# reduce the size for ROCm test to avoid HIP OOM
NUM_BLOCKS = [1024, 36000] if not is_hip else [
1024, 10000
] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [

View File

@ -2,7 +2,7 @@ from typing import Optional
import pytest
import torch
from allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope
IS_NEOX_STYLE = [True, False]
@ -64,5 +64,11 @@ def test_rotary_embedding(
ref_query, ref_key = rope._forward(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
assert torch.allclose(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))