[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič
2025-02-28 18:20:11 -05:00
committed by GitHub
parent 084bbac8cc
commit bd56c983d6
9 changed files with 239 additions and 160 deletions

View File

@@ -13,21 +13,26 @@ class TestBackend:
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.
The Inductor config can be modified directly by editing the inductor_config
property. This can be helpful for adding passes like
'pre_grad_custom_pass' and 'post_grad_custom_pre_pass'.
"""
def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config = config.shallow_copy_dict()
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
config_patches=self.inductor_config)
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
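For reference, a minimal usage sketch of TestBackend (assuming vLLM is importable; the toy model and input are placeholders, not part of this commit):
# Hedged sketch, not part of the diff: run one custom pass through Inductor via
# TestBackend and inspect the FX graphs it captured around the custom passes.
import torch
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(enable_noop=True, enable_fusion=False)
backend = TestBackend(NoOpEliminationPass(pass_config))  # TestBackend as defined above

model = torch.nn.Linear(16, 16)  # placeholder model
compiled = torch.compile(model, backend=backend)
compiled(torch.rand(8, 16))

# backend.graph_pre_pass / backend.graph_post_pass now hold the graphs before and
# after the custom passes, as used by the fusion test below.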

View File

@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)

View File

@@ -5,23 +5,25 @@ import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, static: bool, *args,
**kwargs):
def __init__(self, hidden_size: int, eps: float, static: bool,
cutlass_fp8_enabled: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
if static:
@@ -41,7 +43,8 @@ class TestModel(torch.nn.Module):
self.w[0],
self.wscale[0],
self.scale[0],
use_per_token_if_dynamic=True)
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
# make sure resid is used for the replacement to work
y2, resid = self.norm[1](x2, resid)
@@ -49,7 +52,8 @@ class TestModel(torch.nn.Module):
self.w[1],
self.wscale[1],
self.scale[1],
use_per_token_if_dynamic=True)
use_per_token_if_dynamic=True,
cutlass_fp8_supported=self.cutlass_fp8_enabled)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
@@ -59,60 +63,67 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
with vllm.config.set_current_vllm_config(vllm_config):
# The no-op elimination pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
noop_pass = NoOpEliminationPass(config)
fusion_pass = FusionPass.instance(config)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps, static)
backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]
# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]
# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
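For reference, the pre/post checks rely on the find_auto_fn(_maybe) helpers from vllm.compilation.fx_utils; a rough sketch of what such a helper plausibly does (not the actual implementation) is to scan the graph for auto_functionalized higher-order calls that wrap a given custom op:
# Sketch only -- the real helpers live in vllm.compilation.fx_utils.
from typing import Iterable, Optional

import torch
from torch.fx import Node

def find_auto_fn_maybe_sketch(nodes: Iterable[Node], op) -> Optional[Node]:
    """Return the first auto_functionalized(op, ...) node, or None if absent."""
    for node in nodes:
        if (node.op == "call_function"
                and node.target is torch.ops.higher_order.auto_functionalized
                and node.args[0] == op):
            return node
    return None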

View File

@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Union
import torch.fx
from torch import SymInt
from vllm.logger import init_logger
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class NoOpEliminationPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape/slice operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D case. Additionally, torch's internal no-op elimination pass does
not handle certain slice variants.
Example graph 1:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Example graph 2:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, 0, 0, arg0)
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
Can be replaced with:
arg0: "s0" = SymInt(s0)
scaled_mm: "f16[s0, 4096]" = ...
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
out: "f16[s0, 4096]" = at[1]
TODO(luka): This is currently tested in test_fusion,
but separate tests could be good.
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_noop_elimination")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
input, shape = node.args[:2]
input_shape = input.meta["val"].shape
if len(shape) != len(input_shape):
# Reshape changing rank, skip
continue
if shape.count(-1) > 1:
# Invalid reshape args, skip
continue
if self.all_dims_equivalent(shape, input_shape):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice.Tensor):
input, dim_index, start, end = node.args[:4]
input_shape = input.meta["val"].shape
i_dim = input_shape[dim_index]
if start == 0 and self.dims_equivalent(end, i_dim):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
elif is_func(node, torch.ops.aten.slice_scatter.default):
base, view, dim_index, start, end = node.args[:5]
base_shape = base.meta["val"].shape
view_shape = view.meta["val"].shape
view_dim = view_shape[dim_index]
# Check that view fully covers base and the full view is used
# (if the view fully covered the base after slicing but was not
# fully used, we could replace slice_scatter with a simple slice
# but that's a niche case).
if (base_shape == view_shape and start == 0
and self.dims_equivalent(end, view_dim)):
node.replace_all_uses_with(view)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes and slices", count)
self.dump_graph(graph, "after_noop_elimination")
self.end_and_log()
def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
i_dims: Iterable[Union[int, SymInt]]):
return all(
self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape/slice
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are three cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The reshape dimension is -1 (i.e. inferred)
3. The dimensions both correspond to the same SymInt
While case 2 does not guarantee the dimensions are equal,
they are equal if all other dimensions are equal.
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
"""
# Case 1 and 2
if dim == i_dim or dim == -1:
return True
# Case 3
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
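To make the three cases concrete, here is a dependency-free paraphrase of dims_equivalent with a stand-in object in place of torch.fx.Node and SymInt (illustration only, not the pass's code):
from typing import Any, Union

class FakeNode:
    """Stand-in for a torch.fx.Node whose meta['val'] carries a SymInt-like value."""
    def __init__(self, val: Any):
        self.meta = {"val": val}

def dims_equivalent(dim: Union[int, FakeNode], i_dim: Any) -> bool:
    if dim == i_dim or dim == -1:  # case 1 (equal ints) and case 2 (inferred -1)
        return True
    # case 3: the dimension arg is a graph node whose value is the same symbol
    return isinstance(dim, FakeNode) and dim.meta["val"] == i_dim

assert dims_equivalent(4096, 4096)        # case 1
assert dims_equivalent(-1, 7)             # case 2 (safe only if the other dims match)
s0 = object()                             # pretend SymInt
assert dims_equivalent(FakeNode(s0), s0)  # case 3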

View File

@@ -11,7 +11,7 @@ from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .reshapes import RedundantReshapesPass
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
The order of the post-grad post-passes is:
1. passes (constructor parameter)
2. default passes (RedundantReshapesPass, FusionPass)
2. default passes (NoOpEliminationPass, FusionPass)
3. config["post_grad_custom_post_pass"] (if it exists)
4. fix_functionalization
This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):
def configure(self, pass_config: CompilationConfig.PassConfig):
self.pass_config = pass_config
if pass_config.enable_reshape:
self.passes += [RedundantReshapesPass(pass_config)]
if pass_config.enable_noop:
self.passes += [NoOpEliminationPass(pass_config)]
if pass_config.enable_fusion:
self.passes += [FusionPass.instance(pass_config)]
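A hedged paraphrase of the resulting order (the real manager assembles it incrementally; this is only an illustration): user passes first, then the defaults selected by the pass config, then any configured inductor post-pass, and FixFunctionalizationPass last.
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig

def ordered_passes(user_passes: list,
                   pass_config: CompilationConfig.PassConfig,
                   inductor_custom_post_pass=None) -> list:
    passes = list(user_passes)                            # 1. constructor passes
    if pass_config.enable_noop:
        passes.append(NoOpEliminationPass(pass_config))   # 2. default passes
    if pass_config.enable_fusion:
        passes.append(FusionPass.instance(pass_config))
    if inductor_custom_post_pass is not None:
        passes.append(inductor_custom_post_pass)          # 3. inductor config hook
    passes.append(FixFunctionalizationPass(pass_config))  # 4. defunctionalize last
    return passes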

View File

@@ -1,90 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Union
import torch.fx
from torch import SymInt
from vllm.logger import init_logger
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
class RedundantReshapesPass(VllmInductorPass):
"""
This is an inductor pass that removes redundant reshape operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D-case.
Example graph:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
"""
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_reshapes")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
input, shape = node.args[:2]
input_shape = input.meta["val"].shape
if len(shape) != len(input_shape):
# Reshape changing rank, skip
continue
if shape.count(-1) > 1:
# Invalid reshape args, skip
continue
if all(
self.dims_equivalent(s, i_s)
for s, i_s in zip(shape, input_shape)):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
logger.debug("Removed %s no-op reshapes", count)
self.dump_graph(graph, "after_reshapes")
self.end_and_log()
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are three cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The reshape dimension is -1 (i.e. inferred)
3. The dimensions both correspond to the same SymInt
While case 2 does not guarantee the dimensions are equal,
they are equal if all other dimensions are equal.
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
"""
# Case 1 and 2
if dim == i_dim or dim == -1:
return True
# Case 3
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

View File

@@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass):
self.config = config
self.pass_name = self.__class__.__name__
def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
if stage in self.config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
@@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass):
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
class PrinterInductorPass(VllmInductorPass):
def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
super().__init__(config)
self.name = name
self.always = always
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name, always=self.always)
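A hedged usage sketch of the new PrinterInductorPass (the stage name is arbitrary, and TestBackend is the test helper from the first file): wedging it between two passes forces a graph dump at that point, regardless of dump_graph_stages.
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.vllm_inductor_pass import PrinterInductorPass
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(enable_noop=True, enable_fusion=False)
noop_pass = NoOpEliminationPass(pass_config)

# always=True bypasses the dump_graph_stages check introduced above.
printer = PrinterInductorPass("after_noop", pass_config, always=True)
backend = TestBackend(noop_pass, printer)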

View File

@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graphs. Default is .
- enable_fusion: whether to enable the custom fusion pass.
- enable_reshape: whether to enable the custom reshape elimination pass.
TODO better pass enabling system.
- enable_noop: whether to enable the custom no-op elimination pass.
TODO(luka) better pass enabling system.
"""
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
enable_reshape: bool = True
enable_noop: bool = True
def uuid(self):
"""
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
Do not include dump_graph_* in the hash - they don't affect
compilation.
"""
dict_ = self.model_dump(
include={"enable_fusion", "enable_reshape"})
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).digest()
def model_post_init(self, __context: Any) -> None:
if not self.enable_reshape and self.enable_fusion:
if not self.enable_noop and self.enable_fusion:
logger.warning_once(
"Fusion enabled but reshape elimination disabled. "
"RMSNorm + quant (fp8) fusion might not work")
@@ -3411,7 +3410,7 @@ class VllmConfig:
self.compilation_config.use_inductor = True
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.pass_config.enable_noop = False
self.compilation_config.level = CompilationLevel.PIECEWISE
self._set_cudagraph_sizes()
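Putting the PassConfig fields from this file's first hunk together, a hedged construction sketch (the stage names passed to dump_graph_stages are hypothetical, not taken from this commit):
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(
    enable_fusion=True,   # RMSNorm + quant (fp8) fusion
    enable_noop=True,     # no-op reshape/slice elimination; fusion relies on it
    dump_graph_stages=["before_fusion", "after_fusion"],  # hypothetical stage names
)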

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
# Input scaling factors are no longer optional in _scaled_mm starting
@@ -161,10 +162,14 @@ def apply_fp8_linear(
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# because padding breaks dynamic shapes.
config = get_current_vllm_config().compilation_config
do_pad = config.level < CompilationLevel.PIECEWISE
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=17,
num_token_padding=17 if do_pad else None,
use_per_token_if_dynamic=use_per_token_if_dynamic)
per_tensor_weights = (weight_scale.numel() == 1)
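A hedged paraphrase of the new padding gate in apply_fp8_linear, shown in isolation (the helper name is made up for illustration):
from typing import Optional

from vllm.config import CompilationLevel, get_current_vllm_config

def fp8_quant_token_padding() -> Optional[int]:
    """Pad to >= 17 tokens for eager-mode _scaled_mm performance, but skip the
    padding under piecewise torch.compile, where it breaks dynamic shapes."""
    level = get_current_vllm_config().compilation_config.level
    return 17 if level < CompilationLevel.PIECEWISE else None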