mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import vllm.plugins
|
|
from vllm.compilation.fusion import (
|
|
FUSED_OPS,
|
|
QUANT_OPS,
|
|
FusedRMSQuantKey,
|
|
RMSNormQuantFusionPass,
|
|
)
|
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
|
from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
GroupShape,
|
|
QuantKey,
|
|
ScaleDesc,
|
|
)
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
Fp8LinearOp,
|
|
cutlass_fp8_supported,
|
|
maybe_create_device_identity,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
|
|
from ..utils import override_cutlass_fp8_supported
|
|
from .backend import TestBackend
|
|
|
|
# FP8 dtype reported by the current platform; used in this file as the
# quantized dtype of the quant key and of the test weights.
FP8_DTYPE = current_platform.fp8_dtype()
|
|
|
|
|
|
class TestModel(torch.nn.Module):
    """Small model used to exercise the RMSNorm + FP8-quant fusion pass.

    The forward pass chains three ``RMSNorm`` layers with two
    ``Fp8LinearOp`` applications in between, so the outputs of ``norm[0]``
    (called without a residual) and ``norm[1]`` (called with a residual)
    each feed a quantized linear.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float,
        static: bool,
        cuda_force_torch: bool,
        *args,
        **kwargs,
    ):
        """Build the norm layers, weights, scales and the FP8 linear op.

        Args:
            hidden_size: Size passed to each ``RMSNorm`` and used for both
                dimensions of the square weight matrices.
            eps: Epsilon passed to each ``RMSNorm``.
            static: If true, use ``GroupShape.PER_TENSOR`` with explicit
                random input scales; otherwise use ``GroupShape.PER_TOKEN``
                and pass ``None`` as the input scale.
            cuda_force_torch: The linear op is constructed under
                ``override_cutlass_fp8_supported(not cuda_force_torch)``.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.cuda_force_torch = cuda_force_torch
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]

        # The quant key identifies which quant op / fused op to expect in
        # the graph (see ops_in_model_before / ops_in_model_after).
        group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
        quant_scale = ScaleDesc(torch.float32, static, group_shape)
        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)

        if static:
            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
        else:
            self.scale = [None for _ in range(2)]

        # Square FP8 weights, stored transposed.
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

        with override_cutlass_fp8_supported(not cuda_force_torch):
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=static,
                act_quant_group_shape=group_shape,
            )

    def forward(self, x):
        """Run norm -> fp8 linear -> norm -> fp8 linear -> norm on ``x``.

        Returns the normalized output of the last ``RMSNorm``; the updated
        residual is discarded.
        """
        resid = torch.sqrt(x)
        y = self.norm[0](x)

        x2 = self.fp8_linear.apply(
            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
        )
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = self.fp8_linear.apply(
            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
        )
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3

    def ops_in_model_before(self):
        """Return the ops expected in the graph before fusion."""
        return [QUANT_OPS[self.key]]

    def ops_in_model_after(self):
        """Return the fused ops expected in the graph after fusion.

        Both ``FusedRMSQuantKey`` variants (second argument ``False`` and
        ``True``) for this model's quant key are looked up in ``FUSED_OPS``.
        """
        return [
            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
            FUSED_OPS[FusedRMSQuantKey(self.key, True)],
        ]
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
    dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
):
    """Check that RMSNorm + FP8 quant pairs in ``TestModel`` get fused.

    The model is run eagerly and through ``torch.compile`` with a
    ``TestBackend`` that applies the no-op elimination, fusion and cleanup
    passes. The test asserts that both outputs agree within tolerance, that
    the fusion pass reports exactly two matches, and that the backend's
    before/after op checks pass.
    """
    # NOTE: these defaults are process-global and are not restored here.
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(1)
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            custom_ops=["+rms_norm", "+quant_fp8"],
            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
        )
    )
    with vllm.config.set_current_vllm_config(vllm_config):
        # The no-op elimination pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = RMSNormQuantFusionPass(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)

        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
        model = TestModel(hidden_size, eps, static, cuda_force_torch)

        # First dimension dynamic
        x = torch.rand(num_tokens, hidden_size)
        torch._dynamo.mark_dynamic(x, 0)

        # Eager reference output
        result = model(x)

        # Output of the compiled model, with the passes applied
        model2 = torch.compile(model, backend=backend)
        result2 = model2(x)

        # Higher tol for dynamic, even higher for bfloat16
        if static:
            atol, rtol = (1e-3, 1e-3)
        elif dtype == torch.float16:
            atol, rtol = (2e-3, 2e-3)
        else:
            atol, rtol = (1e-2, 1e-2)

        torch.testing.assert_close(result, result2, atol=atol, rtol=rtol)

        assert fusion_pass.matched_count == 2

        # In pre-nodes, fp8 quant should be there and fused kernels should not
        backend.check_before_ops(model.ops_in_model_before())

        # In post-nodes, fused kernels should be there and fp8 quant should not
        backend.check_after_ops(model.ops_in_model_after())
|