[Bugfix] Add fake mode around passes (#23349)

Signed-off-by: angelayi <yiangela7@gmail.com>
Author: Angela Yi
Date: 2025-08-28 08:25:56 -07:00
Committed by: GitHub
Parent: 95089607fa
Commit: db74d60490
6 changed files with 64 additions and 39 deletions


@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
+from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
+@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)


@@ -19,6 +19,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
+from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
FP8_DTYPE = current_platform.fp8_dtype()
@@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
class AsyncTPPass(VllmInductorPass):
+@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
@@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
+self.register_patterns()
+@enable_fake_mode
+def register_patterns(self):
for epsilon in [1e-5, 1e-6]:
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,


@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
+from .inductor_pass import enable_fake_mode
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
@@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
cls._instance.pass_config = config.compilation_config.pass_config
return cls._instance
+@enable_fake_mode
def __init__(self, config: VllmConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"


@@ -7,8 +7,6 @@ import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
-from torch._subclasses.fake_tensor import (FakeTensorMode,
-unset_fake_temporarily)
from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -19,6 +17,7 @@ from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
+from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@@ -139,24 +138,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
-# Need custom fake mode, otherwise tracing happens with real tensors.
-# That would not work for the unified_attention custom op.
-with unset_fake_temporarily(), FakeTensorMode():
-inputs = [
-empty_bf16(5, self.num_heads, self.head_size), # q
-empty_bf16(5, self.num_heads, self.head_size), # k
-empty_bf16(5, self.num_heads, self.head_size), # v
-empty_bf16(5, self.num_heads, self.head_size), # attn_output
-self.empty_quant(5, self.num_heads *
-self.head_size), # quant_output
-empty_fp32(1, 1) # scale
-]
+inputs = [
+empty_bf16(5, self.num_heads, self.head_size), # q
+empty_bf16(5, self.num_heads, self.head_size), # k
+empty_bf16(5, self.num_heads, self.head_size), # v
+empty_bf16(5, self.num_heads, self.head_size), # attn_output
+self.empty_quant(5,
+self.num_heads * self.head_size), # quant_output
+empty_fp32(1, 1) # scale
+]
-pm.register_replacement(
-pattern, replacement, inputs,
-AttentionQuantPattern.wrap_trace_fn(
-AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-pm_pass)
+pm.register_replacement(
+pattern, replacement, inputs,
+AttentionQuantPattern.wrap_trace_fn(
+AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+pm_pass)
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
@@ -219,27 +215,23 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
-# Need custom fake mode, otherwise tracing happens with real tensors.
-# That would not work for the unified_attention custom op.
-with unset_fake_temporarily(), FakeTensorMode():
-inputs = [
-empty_bf16(5, self.num_heads, self.head_size), # q
-empty_bf16(5, self.num_heads, self.head_size), # k
-empty_bf16(5, self.num_heads, self.head_size), # v
-empty_bf16(5, self.num_heads, self.head_size), # output_attn
-self.empty_quant(5, self.num_heads * self.head_size //
-2), # output_quant
-empty_i32(128,
-round_up(self.num_heads * self.head_size // 16,
-4)), # output_scale
-empty_fp32(1, 1), # input_scale
-]
+inputs = [
+empty_bf16(5, self.num_heads, self.head_size), # q
+empty_bf16(5, self.num_heads, self.head_size), # k
+empty_bf16(5, self.num_heads, self.head_size), # v
+empty_bf16(5, self.num_heads, self.head_size), # output_attn
+self.empty_quant(5, self.num_heads * self.head_size //
+2), # output_quant
+empty_i32(128, round_up(self.num_heads * self.head_size // 16,
+4)), # output_scale
+empty_fp32(1, 1), # input_scale
+]
-pm.register_replacement(
-pattern, replacement, inputs,
-AttentionQuantPattern.wrap_trace_fn(
-AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-pm_pass)
+pm.register_replacement(
+pattern, replacement, inputs,
+AttentionQuantPattern.wrap_trace_fn(
+AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+pm_pass)
class AttnFusionPass(VllmInductorPass):
@@ -255,6 +247,7 @@ class AttnFusionPass(VllmInductorPass):
support are attention kernels, which need to support fusing output quant.
"""
+@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)
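The comments removed above spell out the motivation for this file: pattern tracing must not execute real kernels, and the unified_attention custom op in particular does not work when traced with real tensors. As a side note, a minimal sketch (not part of the commit) of how the two torch APIs involved behave:

import torch
from torch._subclasses.fake_tensor import (FakeTensor, FakeTensorMode,
                                            unset_fake_temporarily)

# Under FakeTensorMode, factory calls return FakeTensors: shape, dtype and
# device metadata propagate through ops, but no real memory is allocated and
# no kernels are executed.
with FakeTensorMode():
    q = torch.empty(5, 32, 128, dtype=torch.bfloat16)
    assert isinstance(q, FakeTensor)

    # unset_fake_temporarily suspends the ambient fake mode, so tensors
    # created inside this block are ordinary real tensors again.
    with unset_fake_temporarily():
        scale = torch.ones(1, 1)
        assert not isinstance(scale, FakeTensor)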


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
import hashlib
import inspect
import json
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import torch
from torch import fx
+from torch._subclasses.fake_tensor import (FakeTensorMode,
+unset_fake_temporarily)
from vllm.utils import is_torch_equal_or_newer
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def uuid(self) -> Any:
return self._uuid
+def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Applies a FakeTensorMode context. This is useful when you don't want to
+    create or run things with real tensors.
+    """
+    @functools.wraps(fn)
+    def fn_new(*args, **kwargs) -> Any:
+        with torch._guards.tracing(
+                None), unset_fake_temporarily(), FakeTensorMode():
+            result = fn(*args, **kwargs)
+        return result
+    return fn_new
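As a usage illustration only (not part of the commit; build_pattern_inputs is a hypothetical helper, and the import path assumes the decorator lives in vllm.compilation.inductor_pass), any function that merely builds tensors for tracing can be wrapped the same way the pass constructors above are:

import torch
from torch._subclasses.fake_tensor import FakeTensor

from vllm.compilation.inductor_pass import enable_fake_mode


@enable_fake_mode
def build_pattern_inputs(num_heads: int, head_size: int) -> list[torch.Tensor]:
    # Runs under FakeTensorMode: the tensors carry shape/dtype metadata only
    # and allocate no real storage, which is all pattern tracing needs.
    return [
        torch.empty(5, num_heads, head_size, dtype=torch.bfloat16),
        torch.empty(1, 1, dtype=torch.float32),
    ]


fake_inputs = build_pattern_inputs(32, 128)
assert all(isinstance(t, FakeTensor) for t in fake_inputs)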


@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
+from .inductor_pass import enable_fake_mode
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
performance.
"""
+@enable_fake_mode
def __init__(self, config: VllmConfig):
super().__init__(config)