[Bugfix] Add fake mode around passes (#23349)
Signed-off-by: angelayi <yiangela7@gmail.com>
@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
     https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
 
 class AsyncTPPass(VllmInductorPass):
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
@@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
             # in fallback path, when we don't use flashinfer
             fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
 
+        self.register_patterns()
+
+    @enable_fake_mode
+    def register_patterns(self):
         for epsilon in [1e-5, 1e-6]:
             AllReduceFusedRMSNormStaticQuantFP8Pattern(
                 epsilon,
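
For context, the AllReduceFusionPass hunk above splits pattern registration out of `__init__` so that only that step runs under fake mode. A minimal sketch (not from this commit; `DemoPass` is a hypothetical stand-in, and the import path is assumed from the relative imports in this diff) of a decorated method called during construction:

    import torch
    from vllm.compilation.inductor_pass import enable_fake_mode

    class DemoPass:  # hypothetical stand-in for AllReduceFusionPass
        def __init__(self):
            self.register_patterns()

        @enable_fake_mode
        def register_patterns(self):
            # Runs entirely under FakeTensorMode: tensors built here are
            # FakeTensors, so pass construction touches no device memory.
            self.example_input = torch.empty(16, 16, dtype=torch.float32)

    p = DemoPass()
    print(type(p.example_input).__name__)  # FakeTensor
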
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.platforms import current_platform
 
 from .fx_utils import find_getitem_maybe
+from .inductor_pass import enable_fake_mode
 from .multi_output_match import MultiOutputMatch
 from .vllm_inductor_pass import VllmInductorPass
 
@@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
             cls._instance.pass_config = config.compilation_config.pass_config
         return cls._instance
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         assert self.__class__._instance is None, \
             "FusionPass singleton instance already exists"
@@ -7,8 +7,6 @@ import torch
 import torch._inductor.pattern_matcher as pm
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
-from torch._subclasses.fake_tensor import (FakeTensorMode,
-                                           unset_fake_temporarily)
 
 from vllm.attention import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -19,6 +17,7 @@ from vllm.platforms import current_platform
 from vllm.utils import round_up
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -139,24 +138,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
                                       output_block_scale=None)
             return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
 
-        # Need custom fake mode, otherwise tracing happens with real tensors.
-        # That would not work for the unified_attention custom op.
-        with unset_fake_temporarily(), FakeTensorMode():
-            inputs = [
-                empty_bf16(5, self.num_heads, self.head_size),  # q
-                empty_bf16(5, self.num_heads, self.head_size),  # k
-                empty_bf16(5, self.num_heads, self.head_size),  # v
-                empty_bf16(5, self.num_heads, self.head_size),  # attn_output
-                self.empty_quant(5, self.num_heads *
-                                 self.head_size),  # quant_output
-                empty_fp32(1, 1)  # scale
-            ]
+        inputs = [
+            empty_bf16(5, self.num_heads, self.head_size),  # q
+            empty_bf16(5, self.num_heads, self.head_size),  # k
+            empty_bf16(5, self.num_heads, self.head_size),  # v
+            empty_bf16(5, self.num_heads, self.head_size),  # attn_output
+            self.empty_quant(5,
+                             self.num_heads * self.head_size),  # quant_output
+            empty_fp32(1, 1)  # scale
+        ]
 
-            pm.register_replacement(
-                pattern, replacement, inputs,
-                AttentionQuantPattern.wrap_trace_fn(
-                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-                pm_pass)
+        pm.register_replacement(
+            pattern, replacement, inputs,
+            AttentionQuantPattern.wrap_trace_fn(
+                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+            pm_pass)
 
 
 class AttentionNvfp4QuantPattern(AttentionQuantPattern):
@@ -219,27 +215,23 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
                                 [-1, self.num_heads * self.head_size // 2])
             return output, at2[2]
 
-        # Need custom fake mode, otherwise tracing happens with real tensors.
-        # That would not work for the unified_attention custom op.
-        with unset_fake_temporarily(), FakeTensorMode():
-            inputs = [
-                empty_bf16(5, self.num_heads, self.head_size),  # q
-                empty_bf16(5, self.num_heads, self.head_size),  # k
-                empty_bf16(5, self.num_heads, self.head_size),  # v
-                empty_bf16(5, self.num_heads, self.head_size),  # output_attn
-                self.empty_quant(5, self.num_heads * self.head_size //
-                                 2),  # output_quant
-                empty_i32(128,
-                          round_up(self.num_heads * self.head_size // 16,
-                                   4)),  # output_scale
-                empty_fp32(1, 1),  # input_scale
-            ]
+        inputs = [
+            empty_bf16(5, self.num_heads, self.head_size),  # q
+            empty_bf16(5, self.num_heads, self.head_size),  # k
+            empty_bf16(5, self.num_heads, self.head_size),  # v
+            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
+            self.empty_quant(5, self.num_heads * self.head_size //
+                             2),  # output_quant
+            empty_i32(128, round_up(self.num_heads * self.head_size // 16,
+                                    4)),  # output_scale
+            empty_fp32(1, 1),  # input_scale
+        ]
 
-            pm.register_replacement(
-                pattern, replacement, inputs,
-                AttentionQuantPattern.wrap_trace_fn(
-                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-                pm_pass)
+        pm.register_replacement(
+            pattern, replacement, inputs,
+            AttentionQuantPattern.wrap_trace_fn(
+                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+            pm_pass)
 
 
 class AttnFusionPass(VllmInductorPass):
@@ -255,6 +247,7 @@ class AttnFusionPass(VllmInductorPass):
     support are attention kernels, which need to support fusing output quant.
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
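
Side note, not part of the diff: the inline fake-mode blocks removed above become redundant because `AttnFusionPass.__init__` (decorated in the hunk just above) performs the pattern registration, so it already runs under `@enable_fake_mode` (defined in the inductor_pass.py hunk below). A minimal sketch of what FakeTensorMode provides; the shape and dtype are arbitrary:

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Under an active FakeTensorMode, factory calls return FakeTensors:
    # shape/dtype/device metadata only, with no real storage allocated.
    # That metadata is all the inductor pattern matcher needs to trace
    # the pattern and replacement graphs.
    with FakeTensorMode():
        q = torch.empty(5, 32, 128, dtype=torch.bfloat16)
        print(type(q).__name__, tuple(q.shape))  # FakeTensor (5, 32, 128)
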
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import functools
 import hashlib
 import inspect
 import json
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
+from torch._subclasses.fake_tensor import (FakeTensorMode,
+                                           unset_fake_temporarily)
 
 from vllm.utils import is_torch_equal_or_newer
 
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
 
     def uuid(self) -> Any:
         return self._uuid
+
+
+def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Applies a FakeTensorMode context. This is useful when you don't want to
+    create or run things with real tensors.
+    """
+
+    @functools.wraps(fn)
+    def fn_new(*args, **kwargs) -> Any:
+        with torch._guards.tracing(
+                None), unset_fake_temporarily(), FakeTensorMode():
+            result = fn(*args, **kwargs)
+
+        return result
+
+    return fn_new
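
A note on the `with` clause above: `unset_fake_temporarily()` first exits any fake mode that happens to be active, so the wrapped function always runs under a single fresh `FakeTensorMode` rather than a nested one, and `torch._guards.tracing(None)` clears any ambient tracing context. A minimal sketch (not from this commit) of the reset-then-enter pattern:

    import torch
    from torch._subclasses.fake_tensor import (FakeTensorMode,
                                               unset_fake_temporarily)

    with FakeTensorMode():
        # Suppose a pass were constructed while an outer fake mode is
        # active. Resetting first avoids stacking a second fake mode on
        # top of the existing one.
        with unset_fake_temporarily(), FakeTensorMode():
            t = torch.empty(4, 4)
            print(type(t).__name__)  # FakeTensor, from the inner fresh mode
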
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
     performance.
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 