From db74d604900d397e4ee524f93bcb256537679ce4 Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Thu, 28 Aug 2025 08:25:56 -0700
Subject: [PATCH] [Bugfix] Add fake mode around passes (#23349)

Signed-off-by: angelayi
---
 vllm/compilation/activation_quant_fusion.py |  2 +
 vllm/compilation/collective_fusion.py       |  6 ++
 vllm/compilation/fusion.py                  |  2 +
 vllm/compilation/fusion_attn.py             | 71 ++++++++++-----------
 vllm/compilation/inductor_pass.py           | 20 ++++++
 vllm/compilation/sequence_parallelism.py    |  2 +
 6 files changed, 64 insertions(+), 39 deletions(-)

diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index ce4e50a2b0..826014f770 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -61,6 +62,7 @@ class ActivationQuantFusionPass(VllmInductorPass):
     https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py
index 0c545d8cff..7a99aaff70 100644
--- a/vllm/compilation/collective_fusion.py
+++ b/vllm/compilation/collective_fusion.py
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 FP8_DTYPE = current_platform.fp8_dtype()
@@ -349,6 +350,7 @@ class AllGatherCutlassScaledMMPattern(BasePattern):
 
 
 class AsyncTPPass(VllmInductorPass):
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
@@ -1121,6 +1123,10 @@ class AllReduceFusionPass(VllmInductorPass):
             # in fallback path, when we don't use flashinfer
             fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
 
+        self.register_patterns()
+
+    @enable_fake_mode
+    def register_patterns(self):
         for epsilon in [1e-5, 1e-6]:
             AllReduceFusedRMSNormStaticQuantFP8Pattern(
                 epsilon,
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 0d8d562514..afa739c966 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.platforms import current_platform
 
 from .fx_utils import find_getitem_maybe
+from .inductor_pass import enable_fake_mode
 from .multi_output_match import MultiOutputMatch
 from .vllm_inductor_pass import VllmInductorPass
 
@@ -528,6 +529,7 @@ class FusionPass(VllmInductorPass):
         cls._instance.pass_config = config.compilation_config.pass_config
         return cls._instance
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         assert self.__class__._instance is None, \
             "FusionPass singleton instance already exists"
diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py
index f942afe6a2..3095f17110 100644
--- a/vllm/compilation/fusion_attn.py
+++ b/vllm/compilation/fusion_attn.py
@@ -7,8 +7,6 @@ import torch
 import torch._inductor.pattern_matcher as pm
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
-from torch._subclasses.fake_tensor import (FakeTensorMode,
-                                           unset_fake_temporarily)
 
 from vllm.attention import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -19,6 +17,7 @@ from vllm.platforms import current_platform
 from vllm.utils import round_up
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -139,24 +138,21 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
                 output_block_scale=None)
             return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
 
-        # Need custom fake mode, otherwise tracing happens with real tensors.
-        # That would not work for the unified_attention custom op.
-        with unset_fake_temporarily(), FakeTensorMode():
-            inputs = [
-                empty_bf16(5, self.num_heads, self.head_size),  # q
-                empty_bf16(5, self.num_heads, self.head_size),  # k
-                empty_bf16(5, self.num_heads, self.head_size),  # v
-                empty_bf16(5, self.num_heads, self.head_size),  # attn_output
-                self.empty_quant(5, self.num_heads *
-                                 self.head_size),  # quant_output
-                empty_fp32(1, 1)  # scale
-            ]
+        inputs = [
+            empty_bf16(5, self.num_heads, self.head_size),  # q
+            empty_bf16(5, self.num_heads, self.head_size),  # k
+            empty_bf16(5, self.num_heads, self.head_size),  # v
+            empty_bf16(5, self.num_heads, self.head_size),  # attn_output
+            self.empty_quant(5,
+                             self.num_heads * self.head_size),  # quant_output
+            empty_fp32(1, 1)  # scale
+        ]
 
-            pm.register_replacement(
-                pattern, replacement, inputs,
-                AttentionQuantPattern.wrap_trace_fn(
-                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-                pm_pass)
+        pm.register_replacement(
+            pattern, replacement, inputs,
+            AttentionQuantPattern.wrap_trace_fn(
+                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+            pm_pass)
 
 
 class AttentionNvfp4QuantPattern(AttentionQuantPattern):
@@ -219,27 +215,23 @@
                 [-1, self.num_heads * self.head_size // 2])
             return output, at2[2]
 
-        # Need custom fake mode, otherwise tracing happens with real tensors.
-        # That would not work for the unified_attention custom op.
-        with unset_fake_temporarily(), FakeTensorMode():
-            inputs = [
-                empty_bf16(5, self.num_heads, self.head_size),  # q
-                empty_bf16(5, self.num_heads, self.head_size),  # k
-                empty_bf16(5, self.num_heads, self.head_size),  # v
-                empty_bf16(5, self.num_heads, self.head_size),  # output_attn
-                self.empty_quant(5, self.num_heads * self.head_size //
-                                 2),  # output_quant
-                empty_i32(128,
-                          round_up(self.num_heads * self.head_size // 16,
-                                   4)),  # output_scale
-                empty_fp32(1, 1),  # input_scale
-            ]
+        inputs = [
+            empty_bf16(5, self.num_heads, self.head_size),  # q
+            empty_bf16(5, self.num_heads, self.head_size),  # k
+            empty_bf16(5, self.num_heads, self.head_size),  # v
+            empty_bf16(5, self.num_heads, self.head_size),  # output_attn
+            self.empty_quant(5, self.num_heads * self.head_size //
+                             2),  # output_quant
+            empty_i32(128, round_up(self.num_heads * self.head_size // 16,
+                                    4)),  # output_scale
+            empty_fp32(1, 1),  # input_scale
+        ]
 
-            pm.register_replacement(
-                pattern, replacement, inputs,
-                AttentionQuantPattern.wrap_trace_fn(
-                    AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
-                pm_pass)
+        pm.register_replacement(
+            pattern, replacement, inputs,
+            AttentionQuantPattern.wrap_trace_fn(
+                AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
+            pm_pass)
 
 
 class AttnFusionPass(VllmInductorPass):
@@ -255,6 +247,7 @@
     support are attention kernels, which need to support fusing output quant.
     """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py
index 2a149c65b3..e1b691df38 100644
--- a/vllm/compilation/inductor_pass.py
+++ b/vllm/compilation/inductor_pass.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import functools
 import hashlib
 import inspect
 import json
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
+from torch._subclasses.fake_tensor import (FakeTensorMode,
+                                           unset_fake_temporarily)
 
 from vllm.utils import is_torch_equal_or_newer
 
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
 
     def uuid(self) -> Any:
         return self._uuid
+
+
+def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    Applies a FakeTensorMode context. This is useful when you don't want to
+    create or run things with real tensors.
+    """
+
+    @functools.wraps(fn)
+    def fn_new(*args, **kwargs) -> Any:
+        with torch._guards.tracing(
+                None), unset_fake_temporarily(), FakeTensorMode():
+            result = fn(*args, **kwargs)
+
+        return result
+
+    return fn_new
diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py
index ebc025cba7..1758ed4c86 100644
--- a/vllm/compilation/sequence_parallelism.py
+++ b/vllm/compilation/sequence_parallelism.py
@@ -14,6 +14,7 @@ from vllm.distributed.parallel_state import (
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
@@ -436,6 +437,7 @@ class SequenceParallelismPass(VllmInductorPass):
     performance.
    """
 
+    @enable_fake_mode
     def __init__(self, config: VllmConfig):
         super().__init__(config)
 
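Illustrative usage of the enable_fake_mode decorator introduced in vllm/compilation/inductor_pass.py above. This is a sketch, not part of the patch: MyPass, hidden_size, and the tensor it builds are hypothetical, chosen only to show that code decorated this way constructs tensors under FakeTensorMode (shape/dtype metadata only) instead of allocating real device memory, which is what the decorated pass __init__ and register_patterns methods rely on:

    import torch

    from vllm.compilation.inductor_pass import enable_fake_mode


    class MyPass:  # hypothetical pass, for illustration only

        @enable_fake_mode
        def __init__(self, hidden_size: int):
            # Runs inside FakeTensorMode: this tensor carries only shape and
            # dtype metadata, so no real CPU/GPU allocation happens here.
            self.pattern_input = torch.empty(8, hidden_size,
                                             dtype=torch.bfloat16)


    pass_instance = MyPass(hidden_size=4096)  # safe even without a GPU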