[Bugfix][CI][V1] Work around V1 + CUDA Graph + torch._scaled_mm fallback issue (#13425)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Committed by: GitHub
Parent: cd4a72a28d
Commit: b3942e157e
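In short: the dummy scale tensor used by the torch._scaled_mm fallback is no longer allocated at module import time, and the forward path no longer checks its device and copies it with .to(); instead, the FP8 linear methods touched here call maybe_create_device_identity() from create_weights(), so the allocation happens once, before any forward pass that V1 captures with CUDA Graphs. A simplified sketch of the pattern being replaced (illustrative names only, not the actual vLLM code):

import torch

# Old pattern (removed in the hunks below): a module-level dummy scale, plus
# a device check and copy performed inside the fp8 matmul fallback, i.e.
# inside the forward code that may be captured and replayed by CUDA Graphs.
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

def fallback_forward(weight: torch.Tensor):
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY.device != weight.device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
    # ... the torch._scaled_mm fallback then uses TORCH_DEVICE_IDENTITY
    # as its scale arguments ...

# New pattern: the global starts as None and is materialized once via
# maybe_create_device_identity() when a layer's weights are created, so no
# allocation or rebinding of the global happens in the captured path.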
@@ -9,8 +9,8 @@ from torch.nn import Parameter
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -93,6 +93,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
@@ -17,7 +17,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
+    apply_fp8_linear, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -84,6 +85,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
         weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
-    requantize_with_max_scale)
+    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -162,6 +162,8 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
 
@@ -9,7 +9,7 @@ from vllm.platforms import current_platform
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 # The condition to determine if it is on a platform that supports
 # torch._scaled_mm rowwise feature.
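The comment in this hunk is the reason the dummy tensor exists at all. A minimal standalone illustration of the constraint (not vLLM code; assumes PyTorch >= 2.5 and a CUDA GPU with FP8 support, and the shapes are arbitrary):

import torch

M, K, N = 16, 32, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects its second operand in column-major layout.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Since PyTorch 2.5 the scale arguments are mandatory, so a ones(1) fp32
# tensor stands in when the real scaling is applied elsewhere.
identity = torch.ones(1, dtype=torch.float32, device="cuda")
out = torch._scaled_mm(a, b, scale_a=identity, scale_b=identity,
                       out_dtype=torch.float32)  # (M, N) in fp32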
@@ -113,6 +113,13 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -215,11 +222,6 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY.device != weight.device:
-                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
             # GEMM
             # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
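A sketch of the decomposition the surrounding comments describe (hypothetical helper and shapes, not the vLLM function itself): the GEMM runs with identity scales, and the per-token activation scale and per-channel weight scale are applied to the fp32 output afterwards, because this _scaled_mm path cannot take a vector for s_w. The identity tensor is assumed to already live on the same device as the operands, which is why the per-call device check above can be dropped from the captured forward path.

import torch

def fp8_fallback_linear(qinput: torch.Tensor,   # fp8, (M, K)
                        weight: torch.Tensor,   # fp8, column-major (K, N)
                        x_scale: torch.Tensor,  # fp32, (M, 1) per-token scales
                        w_scale: torch.Tensor,  # fp32, (1, N) per-channel scales
                        identity: torch.Tensor  # fp32 ones(1), same device
                        ) -> torch.Tensor:
    # GEMM with identity scales: computes C = (X * W) in fp32.
    out = torch._scaled_mm(qinput, weight,
                           scale_a=identity, scale_b=identity,
                           out_dtype=torch.float32)
    # Apply the real scales afterwards: C = s_x * s_w * (X * W).
    return out * x_scale * w_scale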