mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bug] DeepGemm: Fix Cuda Init Error (#21312)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
if not has_deep_gemm():
|
||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||
_grouped_impl: Callable[..., Any] | None = None
|
||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||
_per_block_cast_impl: Callable[..., Any] | None = None
|
||||
else:
|
||||
_dg = importlib.import_module("deep_gemm") # type: ignore
|
||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||
_grouped_impl: Callable[..., Any] | None = None
|
||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||
_per_block_cast_impl: Callable[..., Any] | None = None
|
||||
|
||||
_fp8_gemm_nt_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"fp8_gemm_nt",
|
||||
"gemm_fp8_fp8_bf16_nt",
|
||||
)
|
||||
|
||||
def _lazy_init() -> None:
|
||||
"""Import deep_gemm and resolve symbols on first use."""
|
||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
|
||||
_per_block_cast_impl
|
||||
|
||||
# fast path
|
||||
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
|
||||
or _grouped_masked_impl is not None
|
||||
or _per_block_cast_impl is not None):
|
||||
return
|
||||
|
||||
if not has_deep_gemm():
|
||||
return
|
||||
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
|
||||
_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
|
||||
"gemm_fp8_fp8_bf16_nt")
|
||||
_grouped_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
|
||||
)
|
||||
_dg, "m_grouped_fp8_gemm_nt_contiguous",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
|
||||
_grouped_masked_impl = _resolve_symbol(
|
||||
_dg,
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||
)
|
||||
|
||||
_dg, "fp8_m_grouped_gemm_nt_masked",
|
||||
"m_grouped_gemm_fp8_fp8_bf16_nt_masked")
|
||||
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
|
||||
try:
|
||||
_math_mod = importlib.import_module(
|
||||
@ -80,24 +86,28 @@ else:
|
||||
|
||||
|
||||
def fp8_gemm_nt(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _fp8_gemm_nt_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _fp8_gemm_nt_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||
_lazy_init()
|
||||
if _grouped_masked_impl is None:
|
||||
return _missing(*args, **kwargs)
|
||||
return _grouped_masked_impl(*args, **kwargs)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(x, *args, **kwargs):
|
||||
_lazy_init()
|
||||
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
|
||||
return _per_block_cast_impl(x, use_ue8m0=True)
|
||||
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
|
||||
|
Reference in New Issue
Block a user