[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)
Signed-off-by: luka <luka@neuralmagic.com>
@@ -13,21 +13,26 @@ class TestBackend:
     This class provides a simple Inductor backend that can be used for testing.
     It takes a list of custom passes and runs them after Inductor's passes.
     It also saves the graph before and after the custom passes for inspection.
+
+    Inductor config can be modified directly by editing the inductor_config
+    property. This can be helpful for adding passes like the
+    'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
     """

     def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
                                                               None]]):
         self.custom_passes = list(passes)
         from torch._inductor import config
-        self.current_config = config.shallow_copy_dict()
-        self.current_config['force_disable_caches'] = True
-        self.current_config['post_grad_custom_post_pass'] = self.post_pass
+        self.inductor_config = config.shallow_copy_dict()
+        self.inductor_config['force_disable_caches'] = True
+        self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

     def __call__(self, graph: fx.GraphModule, example_inputs):
         self.graph_pre_compile = deepcopy(graph)
         from torch._inductor.compile_fx import compile_fx
         return compile_fx(graph,
                           example_inputs,
-                          config_patches=self.current_config)
+                          config_patches=self.inductor_config)

     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
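For orientation, a minimal usage sketch (not part of the diff) of how the tests below drive this backend; the model and pass objects are placeholders, while the attribute names come from the class above:

import torch
from .backend import TestBackend  # test-local helper shown above

def compile_with_passes(model: torch.nn.Module, x: torch.Tensor, *passes):
    # The passes are installed as Inductor's post_grad_custom_post_pass,
    # so they run after Inductor's own post-grad passes.
    backend = TestBackend(*passes)
    compiled = torch.compile(model, backend=backend)
    out = compiled(x)
    # Graphs saved for inspection:
    #   backend.graph_pre_compile - graph handed to compile_fx
    #   backend.graph_pre_pass    - graph just before the custom passes
    #   backend.graph_post_pass   - graph just after the custom passes
    return out, backend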
@@ -9,7 +9,7 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
-from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig

 from .backend import TestBackend
@@ -50,11 +50,11 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     torch.set_default_device("cuda")

     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
+                                          enable_noop=True)
+    noop_pass = NoOpEliminationPass(config)
     fusion_pass = FusionPass.instance(config)

-    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
+    passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
     func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
@@ -5,23 +5,25 @@ import torch
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
 import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
                                      FusionPass, QuantKey)
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
-from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.config import CompilationConfig
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear)
+    CUTLASS_FP8_SUPPORTED, apply_fp8_linear, maybe_create_device_identity)

 from .backend import TestBackend


 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
-                 **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool,
+                 cutlass_fp8_enabled: bool, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         if static:
@@ -41,7 +43,8 @@ class TestModel(torch.nn.Module):
                               self.w[0],
                               self.wscale[0],
                               self.scale[0],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

@@ -49,7 +52,8 @@ class TestModel(torch.nn.Module):
                               self.w[1],
                               self.wscale[1],
                               self.scale[1],
-                              use_per_token_if_dynamic=True)
+                              use_per_token_if_dynamic=True,
+                              cutlass_fp8_supported=self.cutlass_fp8_enabled)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

@@ -59,60 +63,67 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
+@pytest.mark.parametrize("cutlass_fp8_enabled",
+                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
+                              cutlass_fp8_enabled):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
+    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

-    # Reshape pass is needed for the fusion pass to work
-    config = CompilationConfig.PassConfig(enable_fusion=True,
-                                          enable_reshape=True)
-    reshape_pass = RedundantReshapesPass(config)
-    fusion_pass = FusionPass.instance(config)
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        # Reshape pass is needed for the fusion pass to work
+        config = CompilationConfig.PassConfig(enable_fusion=True,
+                                              enable_noop=True)
+        noop_pass = NoOpEliminationPass(config)
+        fusion_pass = FusionPass.instance(config)

-    backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps, static)
+        backend = TestBackend(noop_pass, fusion_pass)
+        model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

-    # First dimension dynamic
-    x = torch.rand(num_tokens, hidden_size)
-    torch._dynamo.mark_dynamic(x, 0)
+        # First dimension dynamic
+        x = torch.rand(num_tokens, hidden_size)
+        torch._dynamo.mark_dynamic(x, 0)

-    result = model(x)
+        result = model(x)

-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)

-    # Higher tol for dynamic, even higher for bfloat16
-    if static:
-        ATOL, RTOL = (1e-3, 1e-3)
-    elif dtype == torch.float16:
-        ATOL, RTOL = (2e-3, 2e-3)
-    else:
-        ATOL, RTOL = (1e-2, 1e-2)
+        # Higher tol for dynamic, even higher for bfloat16
+        if static:
+            ATOL, RTOL = (1e-3, 1e-3)
+        elif dtype == torch.float16:
+            ATOL, RTOL = (2e-3, 2e-3)
+        else:
+            ATOL, RTOL = (1e-2, 1e-2)

-    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+        torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
+        # Check substitution worked
+        pre_nodes = backend.graph_pre_pass.nodes
+        post_nodes = backend.graph_post_pass.nodes

-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
+        # static is per-tensor, dynamic is per-token
+        key = QuantKey(dtype=FP8_DTYPE,
+                       static=static,
+                       per_tensor=static,
+                       symmetric=True)
+        rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+        add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+        fp8_quant = QUANT_OPS[key]

-    # In pre-nodes, fp8 quant should be present and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+        # In pre-nodes, fp8 quant should be there and fused kernels should not
+        assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
+        assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
+        find_auto_fn(pre_nodes, fp8_quant)

-    # In post-nodes, fused kernels should be present and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+        # In post-nodes, fused kernels should be there and fp8 quant should not
+        find_auto_fn(post_nodes, rms_quant)
+        find_auto_fn(post_nodes, add_rms_quant)
+        assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
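The assertions at the end of this test reduce to a before/after membership check on the fused and unfused ops. A condensed restatement (not code from the commit) using the same find_auto_fn helpers:

def assert_fused(backend, fused_op, unfused_op):
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes
    # Before FusionPass runs, only the standalone quant op is in the graph.
    assert find_auto_fn_maybe(pre_nodes, fused_op) is None
    find_auto_fn(pre_nodes, unfused_op)
    # After FusionPass, the fused RMSNorm+quant kernel has replaced it.
    find_auto_fn(post_nodes, fused_op)
    assert find_auto_fn_maybe(post_nodes, unfused_op) is None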
vllm/compilation/noop_elimination.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Iterable, Union

import torch.fx
from torch import SymInt

from vllm.logger import init_logger

from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Example graph 1:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 2:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]

    TODO(luka): This is currently tested in test_fusion,
    but separate tests could be good.
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_noop_elimination")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if self.all_dims_equivalent(shape, input_shape):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice.Tensor):
                input, dim_index, start, end = node.args[:4]
                input_shape = input.meta["val"].shape
                i_dim = input_shape[dim_index]

                if start == 0 and self.dims_equivalent(end, i_dim):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

            elif is_func(node, torch.ops.aten.slice_scatter.default):
                base, view, dim_index, start, end = node.args[:5]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                view_dim = view_shape[dim_index]

                # Check that view fully covers base and the full view is used
                # (if the view fully covered the base after slicing but was not
                # fully used, we could replace slice_scatter with a simple slice
                # but that's a niche case).
                if (base_shape == view_shape and start == 0
                        and self.dims_equivalent(end, view_dim)):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)
        self.dump_graph(graph, "after_noop_elimination")
        self.end_and_log()

    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
                            i_dims: Iterable[Union[int, SymInt]]):
        return all(
            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.
        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
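A minimal, self-contained sketch (not part of the commit) of the pass acting on a hand-built FX graph. It assumes node.meta["val"] carries the fake-tensor values that shape propagation would normally record, since the pass reads shapes from there:

import torch
import torch.fx as fx

from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig

g = fx.Graph()
x = g.placeholder("x")
x.meta["val"] = torch.empty(8, 4096, dtype=torch.float16)

# A reshape that does not change the shape: [-1, 4096] on an [8, 4096] input.
view = g.call_function(torch.ops.aten.reshape.default, (x, [-1, 4096]))
view.meta["val"] = torch.empty(8, 4096, dtype=torch.float16)
out = g.call_function(torch.ops.aten.relu.default, (view, ))
g.output(out)

noop = NoOpEliminationPass(CompilationConfig.PassConfig(enable_noop=True))
noop(g)  # dims_equivalent treats -1 as inferred, so the reshape is a no-op

# The reshape is erased and relu now consumes x directly.
assert all(n.target is not torch.ops.aten.reshape.default for n in g.nodes)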
@@ -11,7 +11,7 @@ from vllm.logger import init_logger
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
-from .reshapes import RedundantReshapesPass
+from .noop_elimination import NoOpEliminationPass

 logger = init_logger(__name__)

@@ -36,7 +36,7 @@ class PostGradPassManager(Parent):

     The order of the post-grad post-passes is:
     1. passes (constructor parameter)
-    2. default passes (RedundantReshapesPass, FusionPass)
+    2. default passes (NoopEliminationPass, FusionPass)
     3. config["post_grad_custom_post_pass"] (if it exists)
     4. fix_functionalization
     This way, all passes operate on a functionalized graph.
@@ -54,8 +54,8 @@ class PostGradPassManager(Parent):

     def configure(self, pass_config: CompilationConfig.PassConfig):
         self.pass_config = pass_config
-        if pass_config.enable_reshape:
-            self.passes += [RedundantReshapesPass(pass_config)]
+        if pass_config.enable_noop:
+            self.passes += [NoOpEliminationPass(pass_config)]

         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]
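A hedged sketch of what configure() produces for the default flags; the no-argument manager construction is an assumption here, and only the configure() behaviour and the ordering docstring come from this file:

pass_config = CompilationConfig.PassConfig(enable_noop=True,
                                           enable_fusion=True)
pass_manager = PostGradPassManager()
pass_manager.configure(pass_config)
# pass_manager.passes now holds the defaults:
#   [NoOpEliminationPass(...), FusionPass.instance(...)]
# fix_functionalization runs last (step 4 in the ordering above), so every
# earlier pass still sees a functionalized graph.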
@@ -1,90 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch.fx
from torch import SymInt

from vllm.logger import init_logger

from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class RedundantReshapesPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case.

    Example graph:

    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
    """

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_reshapes")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if all(
                        self.dims_equivalent(s, i_s)
                        for s, i_s in zip(shape, input_shape)):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes", count)

        self.dump_graph(graph, "after_reshapes")
        self.end_and_log()

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.
        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
@@ -28,8 +28,8 @@ class VllmInductorPass(InductorPass):
         self.config = config
         self.pass_name = self.__class__.__name__

-    def dump_graph(self, graph: torch.fx.Graph, stage: str):
-        if stage in self.config.dump_graph_stages:
+    def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
+        if stage in self.config.dump_graph_stages or always:
             # Make sure filename includes rank in the distributed setting
             parallel = p_is_init() and get_tp_world_size() > 1
             rank = f"-{get_tp_rank()}" if parallel else ""
@@ -49,3 +49,17 @@ class VllmInductorPass(InductorPass):
         self._end_time = time.perf_counter_ns()
         duration_ms = float(self._end_time - self._start_time) / 1.0e6
         logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
+
+
+class PrinterInductorPass(VllmInductorPass):
+
+    def __init__(self,
+                 name: str,
+                 config: CompilationConfig.PassConfig,
+                 always=False):
+        super().__init__(config)
+        self.name = name
+        self.always = always
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.dump_graph(graph, self.name, always=self.always)
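A hedged sketch (not from the commit) of how PrinterInductorPass can be slotted between custom passes to dump the graph at that point even when dump_graph_stages is empty, via always=True. TestBackend and NoOpEliminationPass are the helpers from this PR; "after_noop" is just an illustrative stage label:

config = CompilationConfig.PassConfig(enable_noop=True)
backend = TestBackend(
    NoOpEliminationPass(config),
    PrinterInductorPass("after_noop", config, always=True),
)
compiled = torch.compile(model, backend=backend)  # `model` assumed defined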
@@ -2993,13 +2993,13 @@ class CompilationConfig(BaseModel):
           Each pass defines its own stages (before, after, maybe in-between).
         - dump_graph_dir: directory to dump the graphs. Default is .
         - enable_fusion: whether to enable the custom fusion pass.
-        - enable_reshape: whether to enable the custom reshape elimination pass.
-          TODO better pass enabling system.
+        - enable_noop: whether to enable the custom no-op elimination pass.
+          TODO(luka) better pass enabling system.
         """
         dump_graph_stages: List[str] = Field(default_factory=list)
         dump_graph_dir: Path = Field(default=Path("."))
         enable_fusion: bool = True
-        enable_reshape: bool = True
+        enable_noop: bool = True

         def uuid(self):
             """
@@ -3008,13 +3008,12 @@ class CompilationConfig(BaseModel):
             Do not include dump_graph_* in the hash - they don't affect
             compilation.
             """
-            dict_ = self.model_dump(
-                include={"enable_fusion", "enable_reshape"})
+            dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
             encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
             return hashlib.sha256(encoded).digest()

         def model_post_init(self, __context: Any) -> None:
-            if not self.enable_reshape and self.enable_fusion:
+            if not self.enable_noop and self.enable_fusion:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "RMSNorm + quant (fp8) fusion might not work")
@@ -3411,7 +3410,7 @@ class VllmConfig:
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
             self.compilation_config.pass_config.enable_fusion = False
-            self.compilation_config.pass_config.enable_reshape = False
+            self.compilation_config.pass_config.enable_noop = False
             self.compilation_config.level = CompilationLevel.PIECEWISE

             self._set_cudagraph_sizes()
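A short sketch of the renamed flag in practice (not from the commit): enable_noop replaces enable_reshape, uuid() now hashes it, and disabling it while fusion stays on triggers the warning above:

cfg = CompilationConfig.PassConfig(enable_fusion=True, enable_noop=False)
# model_post_init logs "Fusion enabled but reshape elimination disabled..."
# because fusion relies on no-op elimination to remove the extra reshape.
print(cfg.uuid().hex())  # digest over {"enable_fusion", "enable_noop"}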
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch

 from vllm import _custom_ops as ops
+from vllm.config import CompilationLevel, get_current_vllm_config
 from vllm.platforms import current_platform

 # Input scaling factors are no longer optional in _scaled_mm starting
@@ -161,10 +162,14 @@ def apply_fp8_linear(
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        config = get_current_vllm_config().compilation_config
+        do_pad = config.level < CompilationLevel.PIECEWISE
         qinput, x_scale = ops.scaled_fp8_quant(
             input_2d,
             input_scale,
-            num_token_padding=17,
+            num_token_padding=17 if do_pad else None,
             use_per_token_if_dynamic=use_per_token_if_dynamic)

         per_tensor_weights = (weight_scale.numel() == 1)
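For intuition, a small illustration (assumed semantics, not code from this file) of what the do_pad guard controls: outside piecewise torch.compile the token dimension is padded up to 17 rows before the scaled matmul, while under dynamic-shape compilation the input is quantized as-is:

import torch

x = torch.randn(4, 4096)  # 4 tokens, below the pad threshold
padded = torch.nn.functional.pad(x, (0, 0, 0, 17 - x.shape[0]))
print(padded.shape)  # torch.Size([17, 4096])
# With config.level >= CompilationLevel.PIECEWISE, do_pad is False and no
# padding happens, since a fixed pad size would conflict with graphs traced
# over a dynamic number of tokens.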