[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)

Signed-off-by: luka <luka@neuralmagic.com>
Luka Govedič
2025-02-28 18:20:11 -05:00
committed by GitHub
parent 084bbac8cc
commit bd56c983d6
9 changed files with 239 additions and 160 deletions

View File

@@ -13,21 +13,26 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.
+    Inductor config can be modified directly by editing the inductor_config
+    property. This can be helpful for adding passes like the
+    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """

     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.current_config = config.shallow_copy_dict()
-        self.current_config['force_disable_caches'] = True
-        self.current_config['post_grad_custom_post_pass'] = self.post_pass
+        self.inductor_config = config.shallow_copy_dict()
+        self.inductor_config['force_disable_caches'] = True
+        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

     def __call__(self, graph: fx.GraphModule, example_inputs):
+        self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.current_config)
+                          config_patches=self.inductor_config)

     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
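For context, here is a minimal usage sketch of the backend above (illustrative only, not part of this commit): it wires the renamed no-op pass into TestBackend and shows where the now-public inductor_config property can be edited before compilation.

import torch
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig
from .backend import TestBackend  # the test helper shown above

def compile_with_noop_pass(model: torch.nn.Module, x: torch.Tensor):
    # Build the backend with the custom pass; more passes can be appended.
    pass_config = CompilationConfig.PassConfig(enable_noop=True)
    backend = TestBackend(NoOpEliminationPass(pass_config))
    # The inductor config is exposed directly, e.g. a pre-grad pass could be
    # added here (my_pre_grad_pass is hypothetical):
    # backend.inductor_config["pre_grad_custom_pass"] = my_pre_grad_pass
    compiled = torch.compile(model, backend=backend)
    out = compiled(x)
    # Captured graphs are available for inspection:
    # backend.graph_pre_compile, backend.graph_pre_pass, backend.graph_post_pass
    return out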

View File

@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
-from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig

 from .backend import TestBackend

@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     torch.set_default_device("cuda")

     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
+                                          enable_noop=True)
+    noop_pass = NoOpEliminationPass(config)
     fusion_pass = FusionPass.instance(config)

-    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
+    passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
     func_pass = FixFunctionalizationPass(config)

     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)

View File

@@ -5,23 +5,25 @@ import torch
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
+import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:

@@ -41,7 +43,8 @@ class TestModel(torch.nn.Module):
                                self.w[0],
                                self.wscale[0],
                                self.scale[0],
-                               use_per_token_if_dynamic=True)
+                               use_per_token_if_dynamic=True,
+                               cutlass_fp8_supported=self.cutlass_fp8_enabled)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

@@ -49,7 +52,8 @@ class TestModel(torch.nn.Module):
                                self.w[1],
                                self.wscale[1],
                                self.scale[1],
-                               use_per_token_if_dynamic=True)
+                               use_per_token_if_dynamic=True,
+                               cutlass_fp8_supported=self.cutlass_fp8_enabled)

         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
@@ -59,60 +63,67 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)

-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

         # First dimension dynamic
         x = torch.rand(num_tokens, hidden_size)
         torch._dynamo.mark_dynamic(x, 0)

         result = model(x)

         model2 = torch.compile(model, backend=backend)
         result2 = model2(x)

         # Higher tol for dynamic, even higher for bfloat16
         if static:
             ATOL, RTOL = (1e-3, 1e-3)
         elif dtype == torch.float16:
             ATOL, RTOL = (2e-3, 2e-3)
         else:
             ATOL, RTOL = (1e-2, 1e-2)

         torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

         # Check substitution worked
         pre_nodes = backend.graph_pre_pass.nodes
         post_nodes = backend.graph_post_pass.nodes

         # static is per-tensor, dynamic is per-token
         key = QuantKey(dtype=FP8_DTYPE,
                        static=static,
                        per_tensor=static,
                        symmetric=True)
         rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
         add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
         fp8_quant = QUANT_OPS[key]

-    # In pre-nodes, fp8 quant should be present and fused kernels should not
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
         assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
         assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
         find_auto_fn(pre_nodes, fp8_quant)

-    # In post-nodes, fused kernels should be present and fp8 quant should not
+        # In post-nodes, fused kernels should be there and fp8 quant should not
         find_auto_fn(post_nodes, rms_quant)
         find_auto_fn(post_nodes, add_rms_quant)
         assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
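For readers unfamiliar with the helpers used in the assertions above, here is a rough, simplified sketch of find_auto_fn_maybe and find_auto_fn; the real implementations live in vllm/compilation/fx_utils.py and may differ in detail.

from typing import Iterable, Optional

import torch.fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
                       op) -> Optional[torch.fx.Node]:
    # Return the auto_functionalized node wrapping `op`, or None.
    for node in nodes:
        if (node.op == "call_function" and node.target is auto_functionalized
                and node.args[0] is op):
            return node
    return None

def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
    # Like find_auto_fn_maybe, but insist that the node exists.
    node = find_auto_fn_maybe(nodes, op)
    assert node is not None, f"auto_functionalized {op} not found"
    return node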

View File

@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Iterable, Union
+
+import torch.fx
+from torch import SymInt
+
+from vllm.logger import init_logger
+
+from .fx_utils import is_func
+from .vllm_inductor_pass import VllmInductorPass
+
+logger = init_logger(__name__)
+
+
+class NoOpEliminationPass(VllmInductorPass):
+    """
+    This is an inductor pass that removes redundant reshape/slice operations.
+    It is required for RMSNorm-quant fusion to work properly.
+    That's because apply_fp8_linear adds a reshape, which is redundant
+    in the 2D-case. Additionally, torch internal no-op elimination pass does
+    not handle certain slice variants.
+
+    Example graph 1:
+    getitem_1: "f16[s0, 4096]" = ...
+    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
+    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Can be replaced with:
+    getitem_1: "f16[s0, 4096]" = ...
+    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
+    out: "f8e4m3fn[s0, 4096]" = at[1]
+
+    Example graph 2:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
+    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
+    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
+
+    Can be replaced with:
+    arg0: "s0" = SymInt(s0)
+    scaled_mm: "f16[s0, 4096]" = ...
+    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
+    out: "f16[s0, 4096]" = at[1]
+
+    TODO(luka): This is currently tested in test_fusion,
+    but separate tests could be good.
+    """
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_noop_elimination")
+        count = 0
+        # Remove no-op reshapes/views:
+        for node in graph.nodes:
+            if is_func(node, torch.ops.aten.reshape.default):
+                input, shape = node.args[:2]
+                input_shape = input.meta["val"].shape
+                if len(shape) != len(input_shape):
+                    # Reshape changing rank, skip
+                    continue
+
+                if shape.count(-1) > 1:
+                    # Invalid reshape args, skip
+                    continue
+
+                if self.all_dims_equivalent(shape, input_shape):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice.Tensor):
+                input, dim_index, start, end = node.args[:4]
+                input_shape = input.meta["val"].shape
+                i_dim = input_shape[dim_index]
+
+                if start == 0 and self.dims_equivalent(end, i_dim):
+                    node.replace_all_uses_with(input)
+                    graph.erase_node(node)
+                    count += 1
+
+            elif is_func(node, torch.ops.aten.slice_scatter.default):
+                base, view, dim_index, start, end = node.args[:5]
+                base_shape = base.meta["val"].shape
+                view_shape = view.meta["val"].shape
+                view_dim = view_shape[dim_index]
+
+                # Check that view fully covers base and the full view is used
+                # (if the view fully covered the base after slicing but was not
+                # fully used, we could replace slice_scatter with a simple slice
+                # but that's a niche case).
+                if (base_shape == view_shape and start == 0
+                        and self.dims_equivalent(end, view_dim)):
+                    node.replace_all_uses_with(view)
+                    graph.erase_node(node)
+                    count += 1
+
+        logger.debug("Removed %s no-op reshapes and slices", count)
+        self.dump_graph(graph, "after_noop_elimination")
+        self.end_and_log()
+
+    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
+                            i_dims: Iterable[Union[int, SymInt]]):
+        return all(
+            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
+
+    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
+                        i_dim: Union[int, SymInt]) -> bool:
+        """
+        This function checks if two dimensions are equivalent.
+        :param dim: The dimension arg to reshape/slice
+        :param i_dim: The corresponding dimension in the input tensor
+        :return: Are the dimensions equivalent?
+
+        There are three cases in which the dimensions are equivalent:
+        1. The dimensions are equal (both integers)
+        2. The reshape dimension is -1 (i.e. inferred)
+        3. The dimensions both correspond to the same SymInt
+
+        While case 2 does not guarantee the dimensions are equal,
+        they are equal if all other dimensions are equal.
+
+        In case 3, the reshape dimension is a torch.fx.Node,
+        and its value is a SymInt. That value is equal to the
+        input dimension.
+        """
+        # Case 1 and 2
+        if dim == i_dim or dim == -1:
+            return True
+        # Case 3
+        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
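As the TODO above notes, the pass is currently only exercised through test_fusion. A standalone smoke test could look roughly like the sketch below (hypothetical, written against the TestBackend helper from this commit; the toy model and test name are illustrative).

import torch

from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig

from .backend import TestBackend  # tests/compile helper shown earlier

class ReshapeModel(torch.nn.Module):

    def forward(self, x: torch.Tensor):
        # A rank-preserving reshape to the same shape is a no-op that the
        # pass should remove if it survives into the post-grad graph.
        y = x.reshape(-1, x.shape[-1])
        return torch.relu(y)

def test_noop_elimination_smoke():
    config = CompilationConfig.PassConfig(enable_noop=True)
    backend = TestBackend(NoOpEliminationPass(config))

    model = ReshapeModel()
    x = torch.rand(8, 16)
    torch._dynamo.mark_dynamic(x, 0)

    compiled = torch.compile(model, backend=backend)
    torch.testing.assert_close(compiled(x), model(x))
    # backend.graph_pre_pass and backend.graph_post_pass can then be compared
    # to confirm the redundant reshape was eliminated.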

View File

@@ -11,7 +11,7 @@ from vllm.logger import init_logger
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
-from .reshapes import RedundantReshapesPass
+from .noop_elimination import NoOpEliminationPass

 logger = init_logger(__name__)

@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):
     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
-    2. default passes (RedundantReshapesPass, FusionPass)
+    2. default passes (NoopEliminationPass, FusionPass)
     3. config["post_grad_custom_post_pass"] (if it exists)
     4. fix_functionalization
     This way, all passes operate on a functionalized graph.

@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):
     def configure(self, pass_config: CompilationConfig.PassConfig):
         self.pass_config = pass_config
-        if pass_config.enable_reshape:
-            self.passes += [RedundantReshapesPass(pass_config)]
+        if pass_config.enable_noop:
+            self.passes += [NoOpEliminationPass(pass_config)]

         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]

View File

@@ -1,90 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Union
-
-import torch.fx
-from torch import SymInt
-
-from vllm.logger import init_logger
-
-from .fx_utils import is_func
-from .vllm_inductor_pass import VllmInductorPass
-
-logger = init_logger(__name__)
-
-
-class RedundantReshapesPass(VllmInductorPass):
-    """
-    This is an inductor pass that removes redundant reshape operations.
-    It is required for RMSNorm-quant fusion to work properly.
-    That's because apply_fp8_linear adds a reshape, which is redundant
-    in the 2D-case.
-
-    Example graph:
-    getitem_1: "f16[s0, 4096]" = ...
-    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
-    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
-    out: "f8e4m3fn[s0, 4096]" = at[1]
-
-    Can be replaced with:
-    getitem_1: "f16[s0, 4096]" = ...
-    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
-    out: "f8e4m3fn[s0, 4096]" = at[1]
-    """
-
-    def __call__(self, graph: torch.fx.Graph):
-        self.begin()
-        self.dump_graph(graph, "before_reshapes")
-        count = 0
-        # Remove no-op reshapes/views:
-        for node in graph.nodes:
-            if is_func(node, torch.ops.aten.reshape.default):
-                input, shape = node.args[:2]
-                input_shape = input.meta["val"].shape
-                if len(shape) != len(input_shape):
-                    # Reshape changing rank, skip
-                    continue
-
-                if shape.count(-1) > 1:
-                    # Invalid reshape args, skip
-                    continue
-
-                if all(
-                        self.dims_equivalent(s, i_s)
-                        for s, i_s in zip(shape, input_shape)):
-                    node.replace_all_uses_with(input)
-                    graph.erase_node(node)
-                    count += 1
-
-        logger.debug("Removed %s no-op reshapes", count)
-        self.dump_graph(graph, "after_reshapes")
-        self.end_and_log()
-
-    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
-                        i_dim: Union[int, SymInt]) -> bool:
-        """
-        This function checks if two dimensions are equivalent.
-        :param dim: The dimension arg to reshape
-        :param i_dim: The corresponding dimension in the input tensor
-        :return: Are the dimensions equivalent?
-
-        There are three cases in which the dimensions are equivalent:
-        1. The dimensions are equal (both integers)
-        2. The reshape dimension is -1 (i.e. inferred)
-        3. The dimensions both correspond to the same SymInt
-
-        While case 2 does not guarantee the dimensions are equal,
-        they are equal if all other dimensions are equal.
-
-        In case 3, the reshape dimension is a torch.fx.Node,
-        and its value is a SymInt. That value is equal to the
-        input dimension.
-        """
-        # Case 1 and 2
-        if dim == i_dim or dim == -1:
-            return True
-        # Case 3
-        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

View File

@@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass):
         self.config = config
         self.pass_name = self.__class__.__name__

-    def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        if stage in self.config.dump_graph_stages:
+    def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
+        if stage in self.config.dump_graph_stages or always:
             # Make sure filename includes rank in the distributed setting
             parallel = p_is_init() and get_tp_world_size() > 1
             rank = f"-{get_tp_rank()}" if parallel else ""

@@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass):
         self._end_time = time.perf_counter_ns()
         duration_ms = float(self._end_time - self._start_time) / 1.0e6
         logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class PrinterInductorPass(VllmInductorPass):
+
+    def __init__(self,
+                 name: str,
+                 config: CompilationConfig.PassConfig,
+                 always=False):
+        super().__init__(config)
+        self.name = name
+        self.always = always
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.dump_graph(graph, self.name, always=self.always)
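PrinterInductorPass is not used anywhere in this diff; a plausible debugging use, sketched below, is to interleave it between custom passes so the intermediate graph is dumped unconditionally via always=True. The pass wiring is illustrative, not from this commit, and assumes a CUDA build for FusionPass.

from vllm.compilation.fusion import FusionPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.vllm_inductor_pass import PrinterInductorPass
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(enable_noop=True,
                                           enable_fusion=True)
passes = [
    NoOpEliminationPass(pass_config),
    # Dump the graph between the two passes, regardless of dump_graph_stages.
    PrinterInductorPass("after_noop", pass_config, always=True),
    FusionPass.instance(pass_config),
    PrinterInductorPass("after_fusion", pass_config, always=True),
]
# e.g. backend = TestBackend(*passes) inside a compile test.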

View File

@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
         Each pass defines its own stages (before, after, maybe in-between).
         - dump_graph_dir: directory to dump the graphs. Default is .
         - enable_fusion: whether to enable the custom fusion pass.
-        - enable_reshape: whether to enable the custom reshape elimination pass.
-        TODO better pass enabling system.
+        - enable_noop: whether to enable the custom no-op elimination pass.
+        TODO(luka) better pass enabling system.
         """
         dump_graph_stages: List[str] = Field(default_factory=list)
         dump_graph_dir: Path = Field(default=Path("."))
         enable_fusion: bool = True
-        enable_reshape: bool = True
+        enable_noop: bool = True

         def uuid(self):
             """
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
             Do not include dump_graph_* in the hash - they don't affect
             compilation.
             """
-            dict_ = self.model_dump(
-                include={"enable_fusion", "enable_reshape"})
+            dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
             encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
             return hashlib.sha256(encoded).digest()

         def model_post_init(self, __context: Any) -> None:
-            if not self.enable_reshape and self.enable_fusion:
+            if not self.enable_noop and self.enable_fusion:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "RMSNorm + quant (fp8) fusion might not work")
@@ -3411,7 +3410,7 @@ class VllmConfig:
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
-            self.compilation_config.pass_config.enable_reshape = False
+            self.compilation_config.pass_config.enable_noop = False
             self.compilation_config.level = CompilationLevel.PIECEWISE

         self._set_cudagraph_sizes()
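On the user-facing side, the rename means a pass config that previously set enable_reshape now sets enable_noop; the sketch below is illustrative, not part of this diff.

from vllm.config import CompilationConfig

# Previously: CompilationConfig.PassConfig(enable_fusion=True, enable_reshape=True)
compilation_config = CompilationConfig(
    pass_config=CompilationConfig.PassConfig(enable_fusion=True,
                                             enable_noop=True))

# Setting enable_noop=False while leaving enable_fusion=True triggers the
# "RMSNorm + quant (fp8) fusion might not work" warning shown above.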

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from vllm import _custom_ops as ops
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.platforms import current_platform

 # Input scaling factors are no longer optional in _scaled_mm starting

@@ -161,10 +162,14 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        config = get_current_vllm_config().compilation_config
+        do_pad = config.level < CompilationLevel.PIECEWISE
         qinput, x_scale = ops.scaled_fp8_quant(
             input_2d,
             input_scale,
-            num_token_padding=17,
+            num_token_padding=17 if do_pad else None,
             use_per_token_if_dynamic=use_per_token_if_dynamic)

         per_tensor_weights = (weight_scale.numel() == 1)
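The padding gate added above can be read as a small predicate; the helper below is a rephrasing for clarity, and its name is hypothetical rather than an API in the tree.

from vllm.config import CompilationLevel, get_current_vllm_config

def _should_pad_fp8_input() -> bool:
    # Padding the token dimension to more than 16 helps torch._scaled_mm on
    # small batches, but it breaks dynamic shapes under torch.compile, so it
    # is skipped once the compilation level reaches PIECEWISE.
    level = get_current_vllm_config().compilation_config.level
    return level < CompilationLevel.PIECEWISE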