mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: luka <luka@neuralmagic.com>
@@ -13,21 +13,26 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.

+    Inductor config can be modified directly by editing the inductor_config
+    property. This can be helpful for adding passes like the
+    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """

     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.current_config = config.shallow_copy_dict()
-        self.current_config['force_disable_caches'] = True
-        self.current_config['post_grad_custom_post_pass'] = self.post_pass
+        self.inductor_config = config.shallow_copy_dict()
+        self.inductor_config['force_disable_caches'] = True
+        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

     def __call__(self, graph: fx.GraphModule, example_inputs):
+        self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.current_config)
+                          config_patches=self.inductor_config)

     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
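For orientation, a hedged sketch (not part of the commit) of how this test backend is typically driven; the import path for TestBackend is an assumption based on the test layout:

    # Sketch: run one custom pass through TestBackend and inspect the graphs.
    import torch
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig
    from tests.compile.backend import TestBackend  # import path assumed

    pass_config = CompilationConfig.PassConfig(enable_noop=True, enable_fusion=False)
    backend = TestBackend(NoOpEliminationPass(pass_config))

    model = torch.nn.Linear(16, 16)
    compiled = torch.compile(model, backend=backend)
    compiled(torch.rand(8, 16))

    # graph_pre_pass / graph_post_pass hold the FX graph before and after the
    # custom passes; graph_pre_compile (added in this commit) holds the graph
    # exactly as Dynamo handed it to the backend.
    print(backend.graph_pre_pass)
    print(backend.graph_post_pass)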
@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
-from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig

 from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     torch.set_default_device("cuda")

     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
+                                          enable_noop=True)
+    noop_pass = NoOpEliminationPass(config)
     fusion_pass = FusionPass.instance(config)

-    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
+    passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
     func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
@@ -5,23 +5,25 @@ import torch
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
@@ -41,7 +43,8 @@ class TestModel(torch.nn.Module):
                                  self.w[0],
                                  self.wscale[0],
                                  self.scale[0],
-                                 use_per_token_if_dynamic=True)
+                                 use_per_token_if_dynamic=True,
+                                 cutlass_fp8_supported=self.cutlass_fp8_enabled)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

@@ -49,7 +52,8 @@ class TestModel(torch.nn.Module):
                                  self.w[1],
                                  self.wscale[1],
                                  self.scale[1],
-                                 use_per_token_if_dynamic=True)
+                                 use_per_token_if_dynamic=True,
+                                 cutlass_fp8_supported=self.cutlass_fp8_enabled)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

@@ -59,60 +63,67 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)

-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

         # First dimension dynamic
         x = torch.rand(num_tokens, hidden_size)
         torch._dynamo.mark_dynamic(x, 0)

         result = model(x)

         model2 = torch.compile(model, backend=backend)
         result2 = model2(x)

         # Higher tol for dynamic, even higher for bfloat16
         if static:
             ATOL, RTOL = (1e-3, 1e-3)
         elif dtype == torch.float16:
             ATOL, RTOL = (2e-3, 2e-3)
         else:
             ATOL, RTOL = (1e-2, 1e-2)

         torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

         # Check substitution worked
         pre_nodes = backend.graph_pre_pass.nodes
         post_nodes = backend.graph_post_pass.nodes

         # static is per-tensor, dynamic is per-token
         key = QuantKey(dtype=FP8_DTYPE,
                        static=static,
                        per_tensor=static,
                        symmetric=True)
         rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
         add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
         fp8_quant = QUANT_OPS[key]

-        # In pre-nodes, fp8 quant should be present and fused kernels should not
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
         assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
         assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
         find_auto_fn(pre_nodes, fp8_quant)

-        # In post-nodes, fused kernels should be present and fp8 quant should not
+        # In post-nodes, fused kernels should be there and fp8 quant should not
         find_auto_fn(post_nodes, rms_quant)
         find_auto_fn(post_nodes, add_rms_quant)
         assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
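The test setup now runs under set_current_vllm_config because apply_fp8_linear (see the w8a8_utils hunk below) reads the current compilation config to decide whether to pad. A minimal hedged sketch of that interaction, using only names that appear in this diff:

    import vllm.config
    from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
    with vllm.config.set_current_vllm_config(vllm_config):
        cfg = vllm.config.get_current_vllm_config().compilation_config
        # At PIECEWISE level, apply_fp8_linear skips the 17-token padding,
        # which is what keeps dynamic shapes intact under torch.compile.
        assert cfg.level == CompilationLevel.PIECEWISE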
vllm/compilation/noop_elimination.py — new file (135 lines)
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Iterable, Union
+
+import torch.fx
+from torch import SymInt
+
+from vllm.logger import init_logger
+
+from .fx_utils import is_func
+from .vllm_inductor_pass import VllmInductorPass
+
+logger = init_logger(__name__)
+
+
+class NoOpEliminationPass(VllmInductorPass):
+    """
+    This is an inductor pass that removes redundant reshape/slice operations.
+    It is required for RMSNorm-quant fusion to work properly.
+    That's because apply_fp8_linear adds a reshape, which is redundant
+    in the 2D-case. Additionally, torch internal no-op elimination pass does
+    not handle certain slice variants.
+
+    Example graph 1:
+    getitem_1: "f16[s0, 4096]" = ...
+    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
+    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Can be replaced with:
+    getitem_1: "f16[s0, 4096]" = ...
+    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Example graph 2:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
+    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
+    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
+
+    Can be replaced with:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
+    out: "f16[s0, 4096]" = at[1]
+
+    TODO(luka): This is currently tested in test_fusion,
+    but separate tests could be good.
+    """
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_noop_elimination")
+        count = 0
+        # Remove no-op reshapes/views:
+        for node in graph.nodes:
+            if is_func(node, torch.ops.aten.reshape.default):
+                input, shape = node.args[:2]
+                input_shape = input.meta["val"].shape
+                if len(shape) != len(input_shape):
+                    # Reshape changing rank, skip
+                    continue
+
+                if shape.count(-1) > 1:
+                    # Invalid reshape args, skip
+                    continue
+
+                if self.all_dims_equivalent(shape, input_shape):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice.Tensor):
+                input, dim_index, start, end = node.args[:4]
+                input_shape = input.meta["val"].shape
+                i_dim = input_shape[dim_index]
+
+                if start == 0 and self.dims_equivalent(end, i_dim):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice_scatter.default):
+                base, view, dim_index, start, end = node.args[:5]
+                base_shape = base.meta["val"].shape
+                view_shape = view.meta["val"].shape
+
+                view_dim = view_shape[dim_index]
+
+                # Check that view fully covers base and the full view is used
+                # (if the view fully covered the base after slicing but was not
+                # fully used, we could replace slice_scatter with a simple slice
+                # but that's a niche case).
+                if (base_shape == view_shape and start == 0
+                        and self.dims_equivalent(end, view_dim)):
+                    node.replace_all_uses_with(view)
+                    graph.erase_node(node)
+                    count += 1
+
+        logger.debug("Removed %s no-op reshapes and slices", count)
+        self.dump_graph(graph, "after_noop_elimination")
+        self.end_and_log()
+
+    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
+                            i_dims: Iterable[Union[int, SymInt]]):
+        return all(
+            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
+
+    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
+                        i_dim: Union[int, SymInt]) -> bool:
+        """
+        This function checks if two dimensions are equivalent.
+        :param dim: The dimension arg to reshape/slice
+        :param i_dim: The corresponding dimension in the input tensor
+        :return: Are the dimensions equivalent?
+
+        There are three cases in which the dimensions are equivalent:
+        1. The dimensions are equal (both integers)
+        2. The reshape dimension is -1 (i.e. inferred)
+        3. The dimensions both correspond to the same SymInt
+
+        While case 2 does not guarantee the dimensions are equal,
+        they are equal if all other dimensions are equal.
+
+        In case 3, the reshape dimension is a torch.fx.Node,
+        and its value is a SymInt. That value is equal to the
+        input dimension.
+        """
+        # Case 1 and 2
+        if dim == i_dim or dim == -1:
+            return True
+        # Case 3
+        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
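A small, hedged illustration of the dims_equivalent rules documented above (integer cases only; the SymInt case needs an FX node whose meta["val"] carries the symbol):

    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig

    noop = NoOpEliminationPass(CompilationConfig.PassConfig())

    assert noop.dims_equivalent(4096, 4096)        # case 1: equal integers
    assert noop.dims_equivalent(-1, 4096)          # case 2: -1 is inferred
    assert not noop.dims_equivalent(2048, 4096)
    assert noop.all_dims_equivalent([-1, 4096], (8, 4096))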
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
-from .reshapes import RedundantReshapesPass
+from .noop_elimination import NoOpEliminationPass

 logger = init_logger(__name__)

@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):

     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
-    2. default passes (RedundantReshapesPass, FusionPass)
+    2. default passes (NoopEliminationPass, FusionPass)
     3. config["post_grad_custom_post_pass"] (if it exists)
     4. fix_functionalization
     This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):

     def configure(self, pass_config: CompilationConfig.PassConfig):
         self.pass_config = pass_config
-        if pass_config.enable_reshape:
-            self.passes += [RedundantReshapesPass(pass_config)]
+        if pass_config.enable_noop:
+            self.passes += [NoOpEliminationPass(pass_config)]

         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]
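Restated outside the class for clarity, the configure() logic above amounts to the following (a hypothetical standalone helper, not vllm API):

    from vllm.compilation.fusion import FusionPass
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import CompilationConfig

    def default_passes(pass_config: CompilationConfig.PassConfig) -> list:
        # Mirrors PostGradPassManager.configure: no-op elimination first,
        # then fusion, each gated by its own flag.
        passes = []
        if pass_config.enable_noop:
            passes.append(NoOpEliminationPass(pass_config))
        if pass_config.enable_fusion:
            passes.append(FusionPass.instance(pass_config))
        return passes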
Deleted: vllm/compilation/reshapes.py (90 lines) — the old RedundantReshapesPass, superseded by NoOpEliminationPass.
@@ -1,90 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Union
-
-import torch.fx
-from torch import SymInt
-
-from vllm.logger import init_logger
-
-from .fx_utils import is_func
-from .vllm_inductor_pass import VllmInductorPass
-
-logger = init_logger(__name__)
-
-
-class RedundantReshapesPass(VllmInductorPass):
-    """
-    This is an inductor pass that removes redundant reshape operations.
-    It is required for RMSNorm-quant fusion to work properly.
-    That's because apply_fp8_linear adds a reshape, which is redundant
-    in the 2D-case.
-
-    Example graph:
-
-    getitem_1: "f16[s0, 4096]" = ...
-    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
-    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
-    out: "f8e4m3fn[s0, 4096]" = at[1]
-
-    Can be replaced with:
-    getitem_1: "f16[s0, 4096]" = ...
-    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
-    out: "f8e4m3fn[s0, 4096]" = at[1]
-    """
-
-    def __call__(self, graph: torch.fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_reshapes")
-        count = 0
-        # Remove no-op reshapes/views:
-        for node in graph.nodes:
-            if is_func(node, torch.ops.aten.reshape.default):
-                input, shape = node.args[:2]
-                input_shape = input.meta["val"].shape
-                if len(shape) != len(input_shape):
-                    # Reshape changing rank, skip
-                    continue
-
-                if shape.count(-1) > 1:
-                    # Invalid reshape args, skip
-                    continue
-
-                if all(
-                        self.dims_equivalent(s, i_s)
-                        for s, i_s in zip(shape, input_shape)):
-                    node.replace_all_uses_with(input)
-                    graph.erase_node(node)
-                    count += 1
-
-        logger.debug("Removed %s no-op reshapes", count)
-
-        self.dump_graph(graph, "after_reshapes")
-        self.end_and_log()
-
-    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
-                        i_dim: Union[int, SymInt]) -> bool:
-        """
-        This function checks if two dimensions are equivalent.
-        :param dim: The dimension arg to reshape
-        :param i_dim: The corresponding dimension in the input tensor
-        :return: Are the dimensions equivalent?
-
-        There are three cases in which the dimensions are equivalent:
-        1. The dimensions are equal (both integers)
-        2. The reshape dimension is -1 (i.e. inferred)
-        3. The dimensions both correspond to the same SymInt
-
-        While case 2 does not guarantee the dimensions are equal,
-        they are equal if all other dimensions are equal.
-
-        In case 3, the reshape dimension is a torch.fx.Node,
-        and its value is a SymInt. That value is equal to the
-        input dimension.
-        """
-        # Case 1 and 2
-        if dim == i_dim or dim == -1:
-            return True
-        # Case 3
-        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
@@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass):
         self.config = config
         self.pass_name = self.__class__.__name__

-    def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        if stage in self.config.dump_graph_stages:
+    def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
+        if stage in self.config.dump_graph_stages or always:
             # Make sure filename includes rank in the distributed setting
             parallel = p_is_init() and get_tp_world_size() > 1
             rank = f"-{get_tp_rank()}" if parallel else ""
@@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass):
         self._end_time = time.perf_counter_ns()
         duration_ms = float(self._end_time - self._start_time) / 1.0e6
         logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class PrinterInductorPass(VllmInductorPass):
+
+    def __init__(self,
+                 name: str,
+                 config: CompilationConfig.PassConfig,
+                 always=False):
+        super().__init__(config)
+        self.name = name
+        self.always = always
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.dump_graph(graph, self.name, always=self.always)
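A hedged usage sketch for the new PrinterInductorPass (module path assumed): it can be dropped into any pass list, e.g. a TestBackend, to dump the graph unconditionally at a named stage.

    from vllm.compilation.vllm_inductor_pass import PrinterInductorPass  # path assumed
    from vllm.config import CompilationConfig

    printer = PrinterInductorPass("after_my_pass",
                                  CompilationConfig.PassConfig(),
                                  always=True)
    # Calling printer(graph) inside a pipeline invokes
    # dump_graph(graph, "after_my_pass", always=True),
    # bypassing the dump_graph_stages filter.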
@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
         Each pass defines its own stages (before, after, maybe in-between).
         - dump_graph_dir: directory to dump the graphs. Default is .
         - enable_fusion: whether to enable the custom fusion pass.
-        - enable_reshape: whether to enable the custom reshape elimination pass.
-        TODO better pass enabling system.
+        - enable_noop: whether to enable the custom no-op elimination pass.
+        TODO(luka) better pass enabling system.
         """
         dump_graph_stages: List[str] = Field(default_factory=list)
         dump_graph_dir: Path = Field(default=Path("."))
         enable_fusion: bool = True
-        enable_reshape: bool = True
+        enable_noop: bool = True

         def uuid(self):
             """
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
             Do not include dump_graph_* in the hash - they don't affect
             compilation.
             """
-            dict_ = self.model_dump(
-                include={"enable_fusion", "enable_reshape"})
+            dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
             encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
             return hashlib.sha256(encoded).digest()

         def model_post_init(self, __context: Any) -> None:
-            if not self.enable_reshape and self.enable_fusion:
+            if not self.enable_noop and self.enable_fusion:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "RMSNorm + quant (fp8) fusion might not work")
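The uuid() change above keeps the hash limited to flags that affect compilation, so dump_graph_* settings never invalidate compilation caches. A minimal standalone restatement of that hashing (mirrors the code, not an additional API):

    import hashlib
    import json

    def pass_config_uuid(enable_fusion: bool, enable_noop: bool) -> bytes:
        # Same recipe as PassConfig.uuid(): sorted JSON of the two flags, SHA-256.
        dict_ = {"enable_fusion": enable_fusion, "enable_noop": enable_noop}
        encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
        return hashlib.sha256(encoded).digest()

    assert pass_config_uuid(True, True) != pass_config_uuid(True, False)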
@@ -3411,7 +3410,7 @@ class VllmConfig:
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
-            self.compilation_config.pass_config.enable_reshape = False
+            self.compilation_config.pass_config.enable_noop = False
             self.compilation_config.level = CompilationLevel.PIECEWISE

         self._set_cudagraph_sizes()
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from vllm import _custom_ops as ops
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.platforms import current_platform

 # Input scaling factors are no longer optional in _scaled_mm starting
@@ -161,10 +162,14 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        config = get_current_vllm_config().compilation_config
+        do_pad = config.level < CompilationLevel.PIECEWISE
         qinput, x_scale = ops.scaled_fp8_quant(
             input_2d,
             input_scale,
-            num_token_padding=17,
+            num_token_padding=17 if do_pad else None,
             use_per_token_if_dynamic=use_per_token_if_dynamic)

         per_tensor_weights = (weight_scale.numel() == 1)
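The padding decision added above, restated: padding to 17 tokens helps torch._scaled_mm for small batches but breaks dynamic shapes under torch.compile, so it is skipped at PIECEWISE level and above. A minimal sketch (assumes a current vllm config has been set, e.g. via set_current_vllm_config):

    from vllm.config import CompilationLevel, get_current_vllm_config

    config = get_current_vllm_config().compilation_config
    do_pad = config.level < CompilationLevel.PIECEWISE
    # None disables padding in scaled_fp8_quant, keeping the token dim dynamic.
    num_token_padding = 17 if do_pad else None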