Mirror of https://github.com/vllm-project/vllm.git (synced 2025-11-14 06:32:22 +08:00)

Compare commits: 13 commits, wentao-opt...v0.11.1rc6

Commits (SHA1): 30700b1cd7, 4b94ed8f92, 6dec9f6109, bf6a3d0ff5, 40d33264c6, 9c84ca8293, 6d54336ae5, 34553b9d27, b039bfda8f, d0e186c16f, f080a83511, 40e2eeeb92, b06b9470ca
@@ -38,7 +38,7 @@ else()
  FetchContent_Declare(
        vllm-flash-attn
        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
        GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54
        GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309
        GIT_PROGRESS TRUE
        # Don't share the vllm-flash-attn build between build types
        BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(

  // This kernel currently only supports H % 128 == 0 and assumes a
  // fixed GROUP_SIZE of 128.
  static constexpr int GROUP_SIZE = 128;

  TORCH_CHECK(input.dtype() == torch::kBFloat16);
  TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
              y_q.dtype() == torch::kFloat8_e4m3fnuz);
  TORCH_CHECK(y_s.dtype() == torch::kFloat32);
  TORCH_CHECK(input.size(-1) % 256 == 0);
  TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);

  using Idx_t = int64_t;

@@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(

  Idx_t stride_counts_e = tokens_per_expert.stride(0);

  static constexpr int GROUP_SIZE = 128;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \

@@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(

  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;

  int const NUM_GROUPS = H / GROUP_SIZE;
  if (!use_ue8m0) {
    if (H >= 4096) {
    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
      /* 8 warps config */
      static constexpr int NUM_STAGES = 4;
      static constexpr int THREAD_COUNT = 256;
      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
    } else {
      /* 1 warp config */
      static constexpr int THREAD_COUNT = 32;
      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
    }
  } else {
    if (H >= 4096) {
    if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
      /* 8 warps config */
      static constexpr int NUM_STAGES = 4;
      static constexpr int THREAD_COUNT = 256;
      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
    } else {
      /* 1 warp config */
      static constexpr int THREAD_COUNT = 32;
      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
    }
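For readers not familiar with the kernel's contract, the following minimal Python sketch (an illustrative reference only, not the CUDA implementation; the helper name, shapes, and the 448.0 e4m3 max are assumptions spelled out in the comments) shows what the checks above enforce: silu(y[..., :H]) * y[..., H:] quantized to FP8 with one float32 scale per 128-element group, which requires H % 128 == 0 and hence input.size(-1) % 256 == 0. The ue8m0 variant would additionally round the scales to powers of two; that path is not modeled here.

    import torch

    def silu_mul_group_quant_ref(y: torch.Tensor, group_size: int = 128):
        # y: (E, T, 2*H) in bfloat16; the first half gates the second half.
        H = y.shape[-1] // 2
        assert H % group_size == 0, "kernel requires H % GROUP_SIZE == 0"
        x = torch.nn.functional.silu(y[..., :H].float()) * y[..., H:].float()
        groups = x.view(*x.shape[:-1], H // group_size, group_size)
        # One float32 scale per group; 448.0 is the largest e4m3fn value.
        scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 448.0
        q = (groups / scales).to(torch.float8_e4m3fn)
        return q.view_as(x), scales.squeeze(-1)

The dispatch change above additionally requires NUM_GROUPS (= H / 128) to be a multiple of 8 before taking the 8-warp configuration, falling back to the 1-warp path otherwise.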
@@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels

| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
@@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
"""

import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

@@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context

@@ -412,14 +415,12 @@ def test_mixtral_moe(
    huggingface."""

    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )
    # Force reload aiter_ops to pick up the new environment variables.
    if "rocm_aiter_ops" in sys.modules:
        importlib.reload(rocm_aiter_ops)

    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    if dtype == torch.float32:
        pytest.skip("AITER ROCm test skip for float32")
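The reload in the hunk above is needed because module-level constants read their environment variables at import time; re-importing is the only way for a test to see new values. A hedged sketch of that pattern (the helper name and signature are illustrative, not part of the vLLM test suite):

    import importlib
    import sys

    def reload_with_env(monkeypatch, module_name: str, env: dict[str, str]):
        # Set the environment first, then re-import so module-level,
        # env-derived constants are recomputed under the new values.
        for key, value in env.items():
            monkeypatch.setenv(key, value)
        if module_name in sys.modules:
            return importlib.reload(sys.modules[module_name])
        return importlib.import_module(module_name)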
@@ -25,6 +25,7 @@ CASES = [
    (8, 16, 128 * 2, fp8_dtype),
    (8, 16, 128 * 3, fp8_dtype),
    (8, 64, 7168, fp8_dtype),
    (8, 128, 128 * 33, fp8_dtype),
    (8, 128, 7168, fp8_dtype),
    (8, 512, 7168, fp8_dtype),
    (8, 1024, 7168, fp8_dtype),

@@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
    )

    # Run the SiLU V2 kernel
    # TODO (varun): use_e8m0 is set to false as the reference impl does
    # not handle that case.
    y_q, y_s = persistent_masked_m_silu_mul_quant(
        y, tokens_per_expert, group_size=group_size
        y, tokens_per_expert, group_size=group_size, use_ue8m0=False
    )

    torch.cuda.synchronize()
@@ -4,6 +4,7 @@
import pytest
import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (

@@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_topk_func,
    vllm_topk_softmax,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import (
    RMSNorm,
    dispatch_rocm_rmsnorm_func,

@@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
        RMSNorm(1024).enabled()


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    topk_func = dispatch_topk_func()
    is_rocm_aiter_moe_enabled.cache_clear()
    if current_platform.is_rocm() and int(use_rocm_aiter):
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax,
        )
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
    topk_func = dispatch_topk_func(use_rocm_aiter)

        assert topk_func == rocm_aiter_topk_softmax
    if current_platform.is_rocm() and use_rocm_aiter:
        assert topk_func == rocm_aiter_ops.topk_softmax
    else:
        assert topk_func == vllm_topk_softmax


@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(
    add_residual: bool,
    dtype: torch.dtype,
    use_rocm_aiter: str,
    use_rocm_aiter_norm: str,
    monkeypatch,
    add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)

    should_use_rocm_aiter = (
        current_platform.is_rocm()
        and int(use_rocm_aiter)
        and int(use_rocm_aiter_norm)
        and use_rocm_aiter
        and dtype in RMS_NORM_SUPPORTED_DTYPES
    )

    if add_residual and should_use_rocm_aiter:
        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
        assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
    elif should_use_rocm_aiter:
        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
        assert rms_norm_func == rocm_aiter_ops.rms_norm
    elif add_residual:
        assert rms_norm_func == fused_add_rms_norm
    else:
tests/v1/structured_output/test_backend_guidance.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer

from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import StructuredOutputOptions

TOKENIZER = "gpt2"


def test_backend_guidance_rollback_terminated():
    # Test that the backend guidance successfully rollbacks from a
    # terminated state. This can happen with speculative decoding,
    # where the draft model proposes EOS and it is verified by the
    # guidance backend. In that case we are in a stopped state, but
    # it should be reverted in case EOS is not accepted by the target
    # model.
    vllm_config = VllmConfig(
        decoding_config=StructuredOutputsConfig(
            backend="guidance",
        )
    )
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    backend = GuidanceBackend(
        vllm_config,
        tokenizer=tokenizer,
        vocab_size=50257,
    )

    grammar = backend.compile_grammar(
        StructuredOutputOptions.JSON, '{"type": "object"}'
    )

    prompt = tokenizer.encode('{"a": "b"}')
    assert len(prompt) > 1
    dummy_wrong = tokenizer.encode('{"a"}')
    for token in prompt:
        assert grammar.accept_tokens("", [token])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Giving any other token should also be accepted
    assert grammar.accept_tokens("", dummy_wrong)
    # Rollback is done from where state was terminated, so from '}' not EOS
    grammar.rollback(len(prompt) - 1)
    assert not grammar.is_terminated()
    assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
    assert grammar.validate_tokens(dummy_wrong) != dummy_wrong
    assert grammar.accept_tokens("", prompt[1:])
    assert not grammar.is_terminated()
    assert grammar.accept_tokens("", [tokenizer.eos_token_id])
    assert grammar.is_terminated()
    # Rollback of <= 0 should not change the terminated state
    grammar.rollback(0)
    assert grammar.is_terminated()
    grammar.rollback(-1)
    assert grammar.is_terminated()


def test_grammar_bitmask_with_specdec():
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)

    for i in range(1, 2):
        sampling_params = SamplingParams(
            structured_outputs=StructuredOutputsParams(
                json='{"type": "object"}',
            ),
        )
        sampling_params.structured_outputs._backend = "guidance"

        my_req_id = f"my_req_id_{i}"
        request = Request(
            my_req_id,
            prompt_token_ids=prompt[:i],
            sampling_params=sampling_params,
            pooling_params=None,
            eos_token_id=tokenizer.eos_token_id,
        )

        structured_output_manager.grammar_init(request)

        def grammar_bitmask(req: Request, tokens: list[int]) -> None:
            structured_output_manager.grammar_bitmask(
                requests={req.request_id: req},
                structured_output_request_ids={req.request_id: 0},
                scheduled_spec_decode_tokens={req.request_id: tokens},
            )
            # At this point, we rolled-back, so should not be terminated
            assert not req.structured_output_request.grammar.is_terminated()

        # The grammar might not yet be compiled, so we wait for it
        while not request.structured_output_request._check_grammar_completion():
            continue

        assert request.structured_output_request.grammar.accept_tokens(
            request.request_id, prompt[:i]
        )

        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
        grammar_bitmask(
            request, prompt[i:] + [tokenizer.eos_token_id] + prompt
        )  # EOS not the final token
        grammar_bitmask(request, prompt[i:])  # EOS not present
        grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
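The shape of the problem the tests above exercise is the speculative-decoding verify step: the grammar advances over all drafted tokens (possibly reaching a terminated state via EOS), then must roll back whatever the target model rejected. A minimal, hedged sketch of that call pattern (illustrative only; the grammar object stands in for whatever the backend returns, and the argument order mirrors the tests):

    def verify_with_grammar(grammar, draft_tokens: list[int], num_accepted: int) -> None:
        # Advance over every drafted token, even if one of them terminates the grammar.
        assert grammar.accept_tokens("", draft_tokens)
        # Undo the tokens the target model rejected; this must also undo termination.
        num_rejected = len(draft_tokens) - num_accepted
        if num_rejected > 0:
            grammar.rollback(num_rejected)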
vllm/_aiter_ops.py (new file, 941 lines)
@@ -0,0 +1,941 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer


def is_aiter_found() -> bool:
    from importlib.util import find_spec

    return find_spec("aiter") is not None


# `find_spec` is not torch.compile compatible.
# In cases where aiter availability might have
# been checked in forward passes that are torch compiled.
# we keep this global outside to not cause torch compile breaks.
IS_AITER_FOUND = is_aiter_found()


def if_aiter_supported(func: Callable) -> Callable:
    """Decorator that only executes the function if
    ROCm AITER package is supported on gfx9 archs.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # checks the platform, device arch and aiter library existance.

        from vllm.platforms.rocm import on_gfx9

        if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
            return func(*args, **kwargs)
        else:
            # Return None or do nothing if not supported
            return None

    return wrapper
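A minimal usage sketch of the decorator defined above (the decorated helper is hypothetical, not part of the file): on anything other than ROCm gfx9 with aiter installed, the wrapper simply returns None instead of running the body, so the call site does not need its own platform check.

    @if_aiter_supported
    def warmup_aiter_kernels() -> None:
        # Only runs when ROCm + gfx9 + aiter are all present;
        # otherwise the wrapper short-circuits without importing aiter.
        import aiter  # noqa: F401

    warmup_aiter_kernels()  # safe to call unconditionally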
def _rocm_aiter_fused_moe_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    activation = ActivationType(activation_method)
    quant_type = QuantType(quant_method)

    return fused_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        activation,
        quant_type,
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def _rocm_aiter_fused_moe_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
    quant_method: int = 0,
    doweight_stage1: bool = False,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


def _rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    activation = ActivationType(activation_method)

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=activation,
    )


def _rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: torch.Tensor | None = None,
    fc2_scale: torch.Tensor | None = None,
    fc1_smooth_scale: torch.Tensor | None = None,
    fc2_smooth_scale: torch.Tensor | None = None,
    a16: bool = False,
    per_tensor_quant_scale: torch.Tensor | None = None,
    expert_mask: torch.Tensor | None = None,
    activation_method: int = 0,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)
def _rocm_aiter_topk_softmax_impl(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    from aiter import topk_softmax

    topk_softmax(
        topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
    )


def _rocm_aiter_topk_softmax_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    token_expert_indices: torch.Tensor,
    gating_output: torch.Tensor,
    renormalize: bool,
) -> None:
    pass


def _rocm_aiter_biased_grouped_topk_impl(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    from aiter import biased_grouped_topk

    biased_grouped_topk(
        gating_output,
        correction_bias,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        routed_scaling_factor,
    )


def _rocm_aiter_biased_grouped_topk_fake(
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    pass


def _rocm_aiter_grouped_topk_impl(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    is_softmax = scoring_func == "softmax"
    from aiter import grouped_topk

    grouped_topk(
        gating_output,
        topk_weights,
        topk_ids,
        num_expert_group,
        topk_group,
        need_renorm,
        is_softmax,
        routed_scaling_factor,
    )


def _rocm_aiter_grouped_topk_fake(
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_expert_group: int,
    topk_group: int,
    need_renorm: bool,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,  # mul to topk_weights
) -> None:
    pass


def _rocm_aiter_mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    from aiter.mla import mla_decode_fwd

    mla_decode_fwd(
        q,
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )


def _rocm_aiter_mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    pass
def _rocm_aiter_gemm_w8a8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    from aiter import gemm_a8w8_CK

    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
    # a to be [M, K]
    # b to be [N, K]
    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)


def _rocm_aiter_gemm_w8a8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


def _rocm_aiter_gemm_w8a8_blockscale_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    from aiter import gemm_a8w8_blockscale

    return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


def _rocm_aiter_gemm_w8a8_blockscale_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    m = A.shape[0]
    n = B.shape[0]
    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
    return Y


def _rocm_aiter_rms_norm_impl(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    from aiter import rms_norm

    if x.dim() > 2:
        x_original_shape = x.shape
        x = x.reshape(-1, x_original_shape[-1])
        x = rms_norm(x, weight, variance_epsilon)
        return x.reshape(x_original_shape)

    return rms_norm(x, weight, variance_epsilon)


def _rocm_aiter_rms_norm_fake(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    return torch.empty_like(x)


def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter import rmsnorm2d_fwd_with_add

    residual_out = torch.empty_like(residual)
    output = torch.empty_like(x)
    rmsnorm2d_fwd_with_add(
        output,  # output
        x,  # input
        residual,  # residual input
        residual_out,  # residual output
        weight,
        variance_epsilon,
    )
    return output, residual_out


def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x), torch.empty_like(residual)


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
class rocm_aiter_ops:
    _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
    _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
    _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
    _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
    _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
    _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
    _MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
    _TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
    _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
    _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
    _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS

    @classmethod
    @if_aiter_supported
    def is_enabled(cls) -> bool:
        """Verifies device specs and availability of aiter main env variable."""
        return cls._AITER_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._LINEAR_ENABLED

    @classmethod
    @if_aiter_supported
    def is_linear_fp8_enaled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()

    @classmethod
    @if_aiter_supported
    def is_rmsnorm_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fused_moe_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._FMOE_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fusion_moe_shared_experts_enabled(cls) -> bool:
        return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mla_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MLA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_mha_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._MHA_ENABLED

    @classmethod
    @if_aiter_supported
    def is_pa_attn_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_triton_unified_attn_enabled(cls) -> bool:
        """ "Verifies device specs and availability of env variable."""
        return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED

    @classmethod
    @if_aiter_supported
    def is_fp8bmm_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._FP8BMM_ENABLED

    @classmethod
    @if_aiter_supported
    def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM

    @classmethod
    @if_aiter_supported
    def is_triton_rotary_embed_enabled(cls) -> bool:
        return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

    @staticmethod
    @if_aiter_supported
    def register_ops_once() -> None:
        global _OPS_REGISTERED
        if not _OPS_REGISTERED:
            tags = (
                tuple()
                if is_torch_equal_or_newer("2.7.0")
                else (torch.Tag.needs_fixed_stride_order,)
            )

            # register all the custom ops here
            direct_register_custom_op(
                op_name="rocm_aiter_asm_moe_tkw1",
                op_func=_rocm_aiter_asm_moe_tkw1_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_fused_moe",
                op_func=_rocm_aiter_fused_moe_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_fused_moe_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_topk_softmax",
                op_func=_rocm_aiter_topk_softmax_impl,
                mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
                fake_impl=_rocm_aiter_topk_softmax_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_biased_grouped_topk",
                op_func=_rocm_aiter_biased_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_biased_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_grouped_topk",
                op_func=_rocm_aiter_grouped_topk_impl,
                mutates_args=["topk_weights", "topk_ids"],
                fake_impl=_rocm_aiter_grouped_topk_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_mla_decode_fwd",
                op_func=_rocm_aiter_mla_decode_fwd_impl,
                mutates_args=["o"],
                fake_impl=_rocm_aiter_mla_decode_fwd_fake,
                tags=tags,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_w8a8",
                op_func=_rocm_aiter_gemm_w8a8_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_gemm_w8a8_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_gemm_w8a8_blockscale",
                op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rms_norm",
                op_func=_rocm_aiter_rms_norm_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_rms_norm_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            direct_register_custom_op(
                op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
                op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
                mutates_args=[],
                fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
                dispatch_key=current_platform.dispatch_key,
            )

            _OPS_REGISTERED = True
    @staticmethod
    def rms_norm2d_with_add(
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
            x, residual, weight, variance_epsilon
        )

    @staticmethod
    def rms_norm(
        x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)

    @staticmethod
    def gemm_w8a8(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None = None,
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)

    @staticmethod
    def gemm_w8a8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
            A, B, As, Bs, output_dtype
        )

    @staticmethod
    def fused_moe(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weight: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
        quant_method: int = 0,
        doweight_stage1: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        a1_scale: torch.Tensor | None = None,
        a2_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation_method,
            quant_method,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )

    @staticmethod
    def asm_moe_tkw1(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        fc1_scale: torch.Tensor | None = None,
        fc2_scale: torch.Tensor | None = None,
        fc1_smooth_scale: torch.Tensor | None = None,
        fc2_smooth_scale: torch.Tensor | None = None,
        a16: bool = False,
        per_tensor_quant_scale: torch.Tensor | None = None,
        expert_mask: torch.Tensor | None = None,
        activation_method: int = 0,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale,
            fc2_scale,
            fc1_smooth_scale,
            fc2_smooth_scale,
            a16,
            per_tensor_quant_scale,
            expert_mask,
            activation_method,
        )

    @staticmethod
    def topk_softmax(
        topk_weights: torch.Tensor,
        topk_indices: torch.Tensor,
        token_expert_indices: torch.Tensor,
        gating_output: torch.Tensor,
        renormalize: bool,
    ) -> tuple[torch.Tensor, ...]:
        torch.ops.vllm.rocm_aiter_topk_softmax(
            topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
        )
        return topk_weights, topk_indices

    @staticmethod
    def biased_grouped_topk(
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        routed_scaling_factor: float = 1.0,
    ) -> None:
        torch.ops.vllm.rocm_aiter_biased_grouped_topk(
            gating_output,
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            routed_scaling_factor,
        )

    @staticmethod
    def grouped_topk(
        gating_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_expert_group: int,
        topk_group: int,
        need_renorm: bool,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
    ) -> None:
        torch.ops.vllm.rocm_aiter_grouped_topk(
            gating_output,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            need_renorm,
            scoring_func,
            routed_scaling_factor,
        )

    @staticmethod
    def mla_decode_fwd(
        q: torch.Tensor,
        kv_buffer: torch.Tensor,
        o: torch.Tensor,
        sm_scale: float,
        qo_indptr: torch.Tensor,
        max_seqlen_qo: int,
        kv_indptr: torch.Tensor | None = None,
        kv_indices: torch.Tensor | None = None,
        kv_last_page_lens: torch.Tensor | None = None,
        logit_cap: float = 0.0,
    ):
        torch.ops.vllm.rocm_aiter_mla_decode_fwd(
            q,
            kv_buffer.view(-1, 1, 1, q.shape[-1]),
            o,
            qo_indptr,
            max_seqlen_qo,
            kv_indptr,
            kv_indices,
            kv_last_page_lens,
            sm_scale=sm_scale,
            logit_cap=logit_cap,
        )
    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype | None = torch.bfloat16,
        x_scales: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
        from aiter.ops.triton.quant import dynamic_mxfp4_quant

        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q = x
            x_s = x_scales

        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y

    @staticmethod
    def triton_rotary_embed(
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        head_size: int,
        rotary_dim: int,
        is_neox_style: bool,
    ):
        from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

        num_tokens = positions.numel()
        cos, sin = cos_sin_cache.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        rotate_style = 0 if is_neox_style else 1

        query = query.view(num_tokens, -1, head_size)
        key = key.view(num_tokens, -1, head_size)
        query_ = query[..., :rotary_dim]
        key_ = key[..., :rotary_dim]
        positions = positions.view(*query.shape[:1])
        rope_cached_thd_positions_2c_fwd_inplace(
            positions,
            sin,
            cos,
            query_,
            key_,
            rotate_style,
            reuse_freqs_front_part=True,
            is_nope_first=False,
        )
        query = query.view(query_shape)
        key = key.view(key_shape)

    @staticmethod
    def triton_fp8_bmm(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        group_size: int = 128,
        bias: torch.Tensor | None = None,
        dtype: torch.dtype | None = torch.bfloat16,
        splitK: int | None = None,
        YQ: torch.Tensor | None = None,
        transpose_bm: bool | None = False,
        config: dict | None = None,
    ) -> torch.Tensor:
        # ruff: noqa: E501 # isort: skip
        from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
            batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
        )

        return aiter_triton_fp8_bmm(
            X,
            WQ,
            w_scale,
            group_size=group_size,
            bias=bias,
            dtype=dtype,
            splitK=splitK,
            YQ=YQ,
            transpose_bm=transpose_bm,
            config=config,
        )

    @staticmethod
    def triton_gemm_a8w8_blockscale(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_size: list[int],
        output_dtype: torch.dtype = torch.float16,
    ) -> torch.Tensor:
        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale

        return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

    @staticmethod
    def per_1x128_fp8_quant(
        input_2d: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Only applies quantization method for fp8 data type only."""
        from aiter import QuantType, dtypes, get_hip_quant

        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)

    @staticmethod
    def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
        return (n, k) in [
            (1024, 8192),
            (2112, 7168),
            (3072, 1536),
            (32768, 8192),
            (4096, 7168),
            (4608, 7168),
            (512, 7168),
            (7168, 2048),
            (7168, 256),
            (8192, 1024),
            (8192, 32768),
        ]

    @staticmethod
    def shuffle_weight(
        self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> torch.Tensor:
        from aiter.ops.shuffle import shuffle_weight

        return shuffle_weight(tensor, layout=layout)

    @staticmethod
    def shuffle_weights(
        *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
    ) -> tuple[torch.Tensor, ...]:
        """
        Applies shuffle_weight function from AITER to each
        input tensor and returns them.

        Rearranges (shuffles) the input tensor/s
        into a specified block layout for optimized computation.

        Args:
            *tensors: Variable number of torch.Tensor objects.
            layout: A pair of integers specifying the block sizes used to divide
                the tensors during shuffling. Default is (16, 16).

        Returns:
            A Tuple of shuffled tensors.
        """
        from aiter.ops.shuffle import shuffle_weight

        return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


rocm_aiter_ops.register_ops_once()
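Call sites are expected to gate on the `is_*_enabled()` helpers and then go through the static wrappers, which dispatch to the registered `torch.ops.vllm` custom ops. A hedged usage sketch (the helper name, tensor shapes and dtypes below are assumptions for illustration, not part of the module):

    import torch
    from vllm._aiter_ops import rocm_aiter_ops

    def maybe_aiter_topk_softmax(gating_output: torch.Tensor, topk: int):
        if not rocm_aiter_ops.is_fused_moe_enabled():
            return None  # caller falls back to the vLLM kernels elsewhere
        M = gating_output.shape[0]
        topk_weights = torch.empty(M, topk, dtype=torch.float32, device=gating_output.device)
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=gating_output.device)
        token_expert_indices = torch.empty_like(topk_ids)
        # Dispatches to torch.ops.vllm.rocm_aiter_topk_softmax under the hood.
        return rocm_aiter_ops.topk_softmax(
            topk_weights, topk_ids, token_expert_indices, gating_output, True
        )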
@@ -1,105 +0,0 @@ (file removed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer


def get_aiter_mla_metadata(
    max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
    paged_kv_indices = torch.zeros(
        max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
    )
    paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
    paged_kv_last_page_lens = torch.full(
        (max_batch_size,), block_size, dtype=torch.int32
    )
    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


def aiter_mla_decode_fwd(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    sm_scale: float,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    logit_cap: float = 0.0,
):
    torch.ops.vllm.rocm_aiter_mla_decode_fwd(
        q,
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        max_seqlen_qo,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )


def mla_decode_fwd_impl(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    from aiter.mla import mla_decode_fwd

    mla_decode_fwd(
        q,
        kv_buffer.view(-1, 1, 1, q.shape[-1]),
        o,
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        max_seqlen_qo,
        sm_scale=sm_scale,
        logit_cap=logit_cap,
    )


def mla_decode_fwd_fake(
    q: torch.Tensor,
    kv_buffer: torch.Tensor,
    o: torch.Tensor,
    qo_indptr: torch.Tensor,
    max_seqlen_qo: int,
    kv_indptr: torch.Tensor | None = None,
    kv_indices: torch.Tensor | None = None,
    kv_last_page_lens: torch.Tensor | None = None,
    sm_scale: float = 1.0,
    logit_cap: float = 0.0,
) -> None:
    pass


if current_platform.is_rocm():
    if is_torch_equal_or_newer("2.7.0"):
        tags = ()
    else:
        tags = ((torch.Tag.needs_fixed_stride_order,),)
    direct_register_custom_op(
        op_name="rocm_aiter_mla_decode_fwd",
        op_func=mla_decode_fwd_impl,
        mutates_args=["o"],
        fake_impl=mla_decode_fwd_fake,
        tags=tags,
    )
@@ -608,17 +608,19 @@ class VllmConfig:
        )
        current_platform.check_and_update_config(self)

        assert (
            self.parallel_config.dcp_kv_cache_interleave_size
            <= self.cache_config.block_size
            and self.cache_config.block_size
            % self.parallel_config.dcp_kv_cache_interleave_size
            == 0
        ), (
            f"Block_size({self.cache_config.block_size}) should be "
            "greater than or equal to and divisible by dcp_kv_cache_interleave_size "
            f"({self.parallel_config.dcp_kv_cache_interleave_size})."
        )
        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
            assert (
                self.parallel_config.dcp_kv_cache_interleave_size
                <= self.cache_config.block_size
                and self.cache_config.block_size
                % self.parallel_config.dcp_kv_cache_interleave_size
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
                "than or equal to and divisible by dcp_kv_cache_interleave_size "
                f"({self.parallel_config.dcp_kv_cache_interleave_size})."
            )

        assert (
            self.parallel_config.dcp_kv_cache_interleave_size == 1
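The reworked check only fires when decode context parallel is actually enabled. A tiny worked example of the constraint itself (values are hypothetical; the helper is illustrative, not vLLM API):

    def dcp_interleave_ok(block_size: int, interleave: int) -> bool:
        # Mirrors the VllmConfig assertion above.
        return interleave <= block_size and block_size % interleave == 0

    assert dcp_interleave_ok(16, 8)        # 8 <= 16 and 16 % 8 == 0
    assert not dcp_interleave_ok(16, 32)   # interleave larger than the block
    assert not dcp_interleave_ok(16, 6)    # block size not divisible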
@@ -1375,6 +1375,8 @@ class OpenAIServing:
                for tool_call in tool_call_info.tool_calls
            )
            content = tool_call_info.content
            if content and content.strip() == "":
                content = None
        else:
            # No tool calls.
            return None, content
@@ -109,7 +109,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_AITER_MLA: bool = True
    VLLM_ROCM_USE_AITER_MHA: bool = True
    VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
    VLLM_ROCM_USE_TRITON_ROPE: bool = False
    VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
    VLLM_ROCM_USE_AITER_FP8BMM: bool = True
    VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
    VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True

@@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    ),
    # Whether to use aiter rope.
    # By default is disabled.
    "VLLM_ROCM_USE_TRITON_ROPE": lambda: (
        os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
    "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
        os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
    ),
    # Whether to use aiter triton fp8 bmm kernel
    # By default is enabled.

@@ -1589,7 +1589,7 @@ def compute_hash() -> str:
        "VLLM_ROCM_USE_AITER_MLA",
        "VLLM_ROCM_USE_AITER_MHA",
        "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
        "VLLM_ROCM_USE_TRITON_ROPE",
        "VLLM_ROCM_USE_AITER_TRITON_ROPE",
        "VLLM_ROCM_USE_AITER_FP8BMM",
        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_ROCM_USE_AITER_TRITON_GEMM",
@@ -23,6 +23,7 @@ def mm_k(
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
    USE_GDC: tl.constexpr,
    base_k,
):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of

@@ -47,32 +48,62 @@ def mm_k(
        matrix dtype.
    b_dtype: datatype of the B matrix
    USE_GDC: Whether to use PDL. True indicates use.
    base_k: Base offset along K dimension for current SPLIT_K group
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):

    # Step size along K for each iteration
    STEP_K = BLOCK_K * SPLIT_K

    # Total number of iterations (compile-time constant)
    num_iters = tl.cdiv(K, STEP_K)

    for k in range(num_iters):
        # Current iteration's global K offset
        iter_k = k * STEP_K + base_k

        # Check if this iteration is completely valid (no masking needed)
        block_end = iter_k + BLOCK_K

        if EVEN_K:
            # pre-fetech lora weight
            # K is divisible by BLOCK_K, no masking ever needed
            # pre-fetch lora weight
            tiled_b = tl.load(b_ptr)
            if USE_GDC:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            accumulator += tl.dot(tiled_a, tiled_b)
        else:
            tiled_b = tl.load(
                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
            )
            if USE_GDC:
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(
                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
            )
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            accumulator += tl.dot(
                tiled_a,
                tiled_b,
            )
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
            # Check if we need element-wise masking
            if iter_k >= K:
                # Entire block out of range, skip
                pass
            elif block_end <= K:
                # Entire block in range, no masking needed (fast path)
                tiled_b = tl.load(b_ptr)
                if USE_GDC:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                accumulator += tl.dot(tiled_a, tiled_b)
            else:
                # Partial block, need masking (only last iteration)
                k_offsets = tl.arange(0, BLOCK_K)
                mask = iter_k + k_offsets < K
                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
                if USE_GDC:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                accumulator += tl.dot(tiled_a, tiled_b)

        a_ptr += STEP_K * ak_stride
        b_ptr += STEP_K * bk_stride

    return accumulator


@@ -178,6 +209,7 @@ def do_expand_kernel(
        CAST_TYPE,
        cur_lora_ptr.dtype.element_ty,
        USE_GDC,
        base_k=0,
    )

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)

@@ -284,6 +316,7 @@ def do_shrink_kernel(
        False,
        cur_lora_ptr.dtype.element_ty,
        False,  # USE_GDC is always False in shrink kernel
        base_k=pid_sk * BLOCK_K,
    )
    # GDC launch dependents hints the runtime system to launch dependent kernels.
    if USE_GDC:
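The restructured loop above distinguishes three cases per SPLIT_K iteration: the block lies entirely past K (skip), entirely inside K (unmasked fast path), or straddles K (element-wise mask on the last iteration only). A NumPy sketch of the same decision, ignoring the Triton-specific loads and GDC hints (names and shapes are illustrative):

    import numpy as np

    def mm_k_ref(a, b, block_k: int, split_k: int, base_k: int):
        # a: (M, K), b: (K, N); one SPLIT_K slice accumulates BLOCK_K-wide steps.
        M, K = a.shape
        acc = np.zeros((M, b.shape[1]), dtype=np.float32)
        step = block_k * split_k
        for it in range(-(-K // step)):        # ceil division, matching tl.cdiv
            k0 = it * step + base_k
            if k0 >= K:
                continue                       # whole block out of range: skip
            k1 = min(k0 + block_k, K)          # partial block only on the last iteration
            acc += a[:, k0:k1].astype(np.float32) @ b[k0:k1, :].astype(np.float32)
        return acc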
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    num_parallel_tokens=16,
    group_size: int = 128,
    use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
    y has shape (E, T, 2*H). The first half of the last dimension is

@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
        device=y.device,
    )

    use_ue8m0 = is_deep_gemm_e8m0_used()
    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()

    cuda_arch = current_platform.get_device_capability(
        device_id=y.device.index
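With the new keyword, callers can pin the scale format instead of inheriting the DeepGEMM global default. A hedged usage sketch (the import of the function is omitted, and the shapes and the int32 dtype of tokens_per_expert are assumptions for illustration):

    import torch

    # E experts, T tokens per expert, hidden size H; y packs [gate, up] along the last dim.
    E, T, H = 8, 16, 7168
    y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.bfloat16)
    tokens_per_expert = torch.full((E,), T, device="cuda", dtype=torch.int32)

    # use_ue8m0=False forces plain float32 group scales even when DeepGEMM's
    # global E8M0 mode is active; leaving it as None keeps the old behaviour.
    y_q, y_s = persistent_masked_m_silu_mul_quant(
        y, tokens_per_expert, group_size=128, use_ue8m0=False
    )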
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union

import torch

@@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
    return a_shape, w_shape


# The type of method in top-K routing
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    # Default: Softmax -> TopK
    Default = (0,)
    # Renormalize: TopK -> Softmax
    Renormalize = (1,)
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
    # -> Top8 experts from the Top4 groups
    DeepSeekV3 = (2,)
    # Llama4: Top1 -> Sigmoid
    Llama4 = (3,)
    # RenormalizeNaive: Softmax -> TopK -> Renormalize
    RenormalizeNaive = (4,)
    # TopK: TopK (no softmax)
    TopK = (5,)
    # Unspecified
    Unspecified = 6.0


@dataclass
class FusedMoEQuantDesc:
    """
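The difference between the first two routing methods is just the order of softmax and top-k; a small torch sketch (illustrative, not the FlashInfer implementation):

    import torch

    def route_default(logits: torch.Tensor, k: int):
        # Default: Softmax -> TopK (probabilities normalized over all experts).
        probs = logits.softmax(dim=-1)
        return probs.topk(k, dim=-1)

    def route_renormalize(logits: torch.Tensor, k: int):
        # Renormalize: TopK -> Softmax over only the selected experts.
        vals, ids = logits.topk(k, dim=-1)
        return vals.softmax(dim=-1), ids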
@@ -215,7 +215,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        )
        assert M_sum % block_m == 0

        workspace1 = (M_sum, max(N, K))
        workspace1 = (M_sum, N)
        workspace2 = (M_sum, max(N // 2, K))
        output = (M, K)
        return (workspace1, workspace2, output)
@ -3,6 +3,7 @@

import torch

from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0,
routing_method_type: int = RoutingMethodType.DeepSeekV3,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe

topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert top_k <= 10
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 256
assert global_num_experts <= 256
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512

a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(
x.shape[0], top_k, global_num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)

@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)
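With `num_expert_group` and `topk_group` now optional, the blockscale FP8 path also accepts non-grouped routing, and the expert-count limits move from 256 experts / top-8 to 512 experts / top-10. A compact sketch reproducing only a subset of the new checks shown above (the helper name is made up):

```python
def check_blockscale_routing(
    top_k: int,
    global_num_experts: int,
    topk_group: int | None,
    block_shape: list[int],
) -> int:
    # None is normalized to 0, as in the hunk above.
    topk_group = topk_group if topk_group is not None else 0
    assert top_k <= global_num_experts
    assert top_k <= 10
    assert global_num_experts % 4 == 0
    assert block_shape == [128, 128]
    # Routing kernel expects #experts <= #threads 512
    assert global_num_experts <= 512
    return topk_group


# Grouped (DeepSeek-style) routing and ungrouped routing both pass now:
assert check_blockscale_routing(8, 256, 4, [128, 128]) == 4
assert check_blockscale_routing(8, 128, None, [128, 128]) == 0
```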
@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
|
||||
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices


def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax

return rocm_aiter_topk_softmax
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax


@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)

topk_func = dispatch_topk_func()
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
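`dispatch_topk_func` now receives the AITER decision as an explicit argument instead of probing the environment itself, so the choice is visible at the call site (`fused_topk` passes `rocm_aiter_ops.is_fused_moe_enabled()`). A standalone sketch of the same dispatch pattern, with dummy stand-ins for the two kernels:

```python
from typing import Callable

import torch


def vllm_topk_softmax_stub(gating_output: torch.Tensor, topk: int):
    # Placeholder for the default path: plain softmax + topk.
    probs = torch.softmax(gating_output, dim=-1)
    return torch.topk(probs, topk, dim=-1)


def aiter_topk_softmax_stub(gating_output: torch.Tensor, topk: int):
    # Placeholder for rocm_aiter_ops.topk_softmax.
    return vllm_topk_softmax_stub(gating_output, topk)


def dispatch_topk_func(use_rocm_aiter: bool = False) -> Callable:
    return aiter_topk_softmax_stub if use_rocm_aiter else vllm_topk_softmax_stub


topk_func = dispatch_topk_func(use_rocm_aiter=False)
weights, ids = topk_func(torch.randn(4, 16), topk=2)
assert ids.shape == (4, 2)
```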
@ -13,6 +13,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.distributed import (
|
||||
@ -30,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
biased_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
|
||||
@ -41,8 +43,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
init_aiter_topK_meta_data,
|
||||
is_rocm_aiter_fusion_shared_expert_enabled,
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -92,13 +92,11 @@ else:
|
||||
return topk_ids
|
||||
|
||||
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk,
|
||||
)
|
||||
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_grouped_topk as grouped_topk_aiter,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
|
||||
if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
@ -463,7 +461,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
|
||||
@ -620,13 +619,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
# Padding the weight for better performance on ROCm
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -1002,6 +997,7 @@ def determine_expert_map(
|
||||
global_num_experts: int,
|
||||
expert_placement_strategy: ExpertPlacementStrategy = "linear",
|
||||
num_fused_shared_experts: int = 0,
|
||||
return_expert_mask: bool = False,
|
||||
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Calculates how many experts should be assigned to each rank for EP and
|
||||
@ -1064,7 +1060,7 @@ def determine_expert_map(
|
||||
)
|
||||
|
||||
expert_mask = None
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if return_expert_mask:
|
||||
expert_mask = torch.ones(
|
||||
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
|
||||
)
|
||||
@ -1218,6 +1214,7 @@ class FusedMoE(CustomOp):
|
||||
zero_expert_type: str | None = None,
|
||||
expert_mapping: list[tuple[str, str, int, str]] | None = None,
|
||||
n_shared_experts: int | None = None,
|
||||
routing_method_type: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -1292,14 +1289,18 @@ class FusedMoE(CustomOp):
|
||||
self.logical_replica_count: torch.Tensor | None = None
|
||||
|
||||
# ROCm aiter shared experts fusion
|
||||
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
self.aiter_fmoe_shared_expert_enabled = (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
)
|
||||
|
||||
self.num_fused_shared_experts = (
|
||||
n_shared_experts
|
||||
if n_shared_experts is not None
|
||||
and is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
|
||||
else 0
|
||||
)
|
||||
if (
|
||||
not is_rocm_aiter_fusion_shared_expert_enabled()
|
||||
not self.aiter_fmoe_shared_expert_enabled
|
||||
and self.num_fused_shared_experts != 0
|
||||
):
|
||||
raise ValueError(
|
||||
@ -1346,6 +1347,7 @@ class FusedMoE(CustomOp):
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_placement_strategy=expert_placement_strategy,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
|
||||
"Only softmax scoring function is supported for non-grouped topk."
|
||||
)
|
||||
|
||||
# ToDo: Better logic to determine the routing method type
|
||||
if routing_method_type is not None:
|
||||
self.routing_method_type = routing_method_type
|
||||
else:
|
||||
if scoring_func == "sigmoid":
|
||||
if self.use_grouped_topk:
|
||||
self.routing_method_type = RoutingMethodType.DeepSeekV3
|
||||
elif self.top_k == 1:
|
||||
self.routing_method_type = RoutingMethodType.Llama4
|
||||
elif self.scoring_func == "softmax":
|
||||
self.routing_method_type = (
|
||||
RoutingMethodType.Renormalize
|
||||
if not self.renormalize
|
||||
else RoutingMethodType.RenormalizeNaive
|
||||
)
|
||||
else:
|
||||
self.routing_method_type = RoutingMethodType.TopK
|
||||
|
||||
self.moe_config: FusedMoEConfig = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
@ -1570,13 +1590,16 @@ class FusedMoE(CustomOp):
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._init_aiter_shared_experts_topK_buffer(
|
||||
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
|
||||
)
|
||||
if self.aiter_fmoe_shared_expert_enabled:
|
||||
self._init_aiter_shared_experts_topK_buffer(
|
||||
vllm_config=get_current_vllm_config(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
)
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
@ -1753,20 +1776,19 @@ class FusedMoE(CustomOp):
|
||||
def _init_aiter_shared_experts_topK_buffer(
|
||||
self, vllm_config: VllmConfig, dp_size: int
|
||||
):
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled():
|
||||
if self.num_fused_shared_experts > 0:
|
||||
init_aiter_topK_meta_data(
|
||||
n_routed_experts=self.global_num_experts,
|
||||
n_shared_experts=self.num_fused_shared_experts,
|
||||
top_k=self.top_k,
|
||||
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
|
||||
tp_size=self.ep_size if self.use_ep else self.tp_size,
|
||||
shared_experts_score=1.0,
|
||||
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
|
||||
* dp_size,
|
||||
is_EP=self.use_ep,
|
||||
)
|
||||
self.local_num_experts += self.num_fused_shared_experts
|
||||
if self.num_fused_shared_experts > 0:
|
||||
init_aiter_topK_meta_data(
|
||||
n_routed_experts=self.global_num_experts,
|
||||
n_shared_experts=self.num_fused_shared_experts,
|
||||
top_k=self.top_k,
|
||||
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
|
||||
tp_size=self.ep_size if self.use_ep else self.tp_size,
|
||||
shared_experts_score=1.0,
|
||||
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
|
||||
* dp_size,
|
||||
is_EP=self.use_ep,
|
||||
)
|
||||
self.local_num_experts += self.num_fused_shared_experts
|
||||
|
||||
@overload
|
||||
def weight_loader(
|
||||
@ -2208,15 +2230,16 @@ class FusedMoE(CustomOp):
|
||||
elif use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
if not is_rocm_aiter_fusion_shared_expert_enabled():
|
||||
if rocm_aiter_ops.is_fused_moe_enabled():
|
||||
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
|
||||
assert num_fused_shared_experts == 0
|
||||
grouped_topk_impl = partial(
|
||||
grouped_topk_aiter,
|
||||
rocm_aiter_grouped_topk,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
)
|
||||
else:
|
||||
grouped_topk_impl = grouped_topk
|
||||
|
||||
topk_weights, topk_ids = grouped_topk_impl(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@ -2448,7 +2471,7 @@ class FusedMoE(CustomOp):
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
@ -2612,7 +2635,7 @@ class FusedMoE(CustomOp):
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map
|
||||
if not is_rocm_aiter_moe_enabled()
|
||||
if not self.rocm_aiter_fmoe_enabled
|
||||
else self.expert_mask,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import IntEnum
|
||||
from functools import cache, lru_cache
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
class QuantMethod(IntEnum):
|
||||
@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
|
||||
GELU = 1
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_moe_enabled() -> bool:
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER_MOE
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def use_mxfp4_aiter_moe() -> bool:
|
||||
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
|
||||
return (
|
||||
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
|
||||
)
|
||||
|
||||
|
||||
aiter_topK_meta_data = None
|
||||
|
||||
|
||||
@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
|
||||
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
|
||||
|
||||
|
||||
def rocm_aiter_asm_moe_tkw1_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType
|
||||
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||
|
||||
activation = ActivationType(activation_method)
|
||||
|
||||
return asm_moe_tkw1(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
fc1_scale=fc1_scale,
|
||||
fc2_scale=fc2_scale,
|
||||
fc1_smooth_scale=fc1_smooth_scale,
|
||||
fc2_smooth_scale=fc2_smooth_scale,
|
||||
a16=a16,
|
||||
per_tensor_quant_scale=per_tensor_quant_scale,
|
||||
expert_mask=expert_mask,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_asm_moe_tkw1_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
fc1_scale: torch.Tensor | None = None,
|
||||
fc2_scale: torch.Tensor | None = None,
|
||||
fc1_smooth_scale: torch.Tensor | None = None,
|
||||
fc2_smooth_scale: torch.Tensor | None = None,
|
||||
a16: bool = False,
|
||||
per_tensor_quant_scale: torch.Tensor | None = None,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax_impl(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> None:
|
||||
from aiter import topk_softmax
|
||||
|
||||
topk_softmax(
|
||||
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax_fake(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_biased_grouped_topk_impl(
|
||||
gating_output: torch.Tensor,
|
||||
correction_bias: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
from aiter import biased_grouped_topk
|
||||
|
||||
biased_grouped_topk(
|
||||
gating_output,
|
||||
correction_bias,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
need_renorm,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_biased_grouped_topk_fake(
|
||||
gating_output: torch.Tensor,
|
||||
correction_bias: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk_impl(
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
from aiter import grouped_topk
|
||||
|
||||
grouped_topk(
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
need_renorm,
|
||||
scoring_func,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk_fake(
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
need_renorm: bool,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0, # mul to topk_weights
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def rocm_aiter_fused_moe_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
activation = ActivationType(activation_method)
|
||||
quant_type = QuantType(quant_method)
|
||||
|
||||
return fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
expert_mask,
|
||||
activation,
|
||||
quant_type,
|
||||
doweight_stage1,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_fused_moe_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_mask: torch.Tensor | None = None,
|
||||
activation_method: int = ActivationMethod.SILU.value,
|
||||
quant_method: int = QuantMethod.NO.value,
|
||||
doweight_stage1: bool = False,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_asm_moe_tkw1",
|
||||
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_fused_moe",
|
||||
op_func=rocm_aiter_fused_moe_impl,
|
||||
fake_impl=rocm_aiter_fused_moe_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_topk_softmax",
|
||||
op_func=rocm_aiter_topk_softmax_impl,
|
||||
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
||||
fake_impl=rocm_aiter_topk_softmax_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_biased_grouped_topk",
|
||||
op_func=rocm_aiter_biased_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_biased_grouped_topk_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_grouped_topk",
|
||||
op_func=rocm_aiter_grouped_topk_impl,
|
||||
mutates_args=["topk_weights", "topk_ids"],
|
||||
fake_impl=rocm_aiter_grouped_topk_fake,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
token = hidden_states.shape[0]
|
||||
device = hidden_states.device
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
|
||||
if (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
and num_fused_shared_experts > 0
|
||||
):
|
||||
assert aiter_topK_meta_data is not None, (
|
||||
"AITER topK meta data is not initialized. "
|
||||
"Please ensure that init_aiter_topK_meta_data "
|
||||
@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
|
||||
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
|
||||
rocm_aiter_ops.biased_grouped_topk(
|
||||
gating_output,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
topk_weights,
|
||||
@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
|
||||
)
|
||||
else:
|
||||
assert scoring_func == "softmax" or scoring_func == "sigmoid"
|
||||
torch.ops.vllm.rocm_aiter_grouped_topk(
|
||||
rocm_aiter_ops.grouped_topk(
|
||||
gating_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
|
||||
if (
|
||||
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
|
||||
and num_fused_shared_experts > 0
|
||||
):
|
||||
return total_topk_weights, total_topk_ids
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
|
||||
"Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
)
|
||||
|
||||
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
|
||||
return rocm_aiter_ops.asm_moe_tkw1(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
|
||||
|
||||
else:
|
||||
quant_method = QuantMethod.NO.value
|
||||
|
||||
# quark moe for mxfp4 w_dtype
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
quant_method = QuantMethod.BLOCK_1X32.value
|
||||
# w8a8 block-scaled
|
||||
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
|
||||
assert not apply_router_weight_on_input, (
|
||||
@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
|
||||
"Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
)
|
||||
|
||||
return torch.ops.vllm.rocm_aiter_fused_moe(
|
||||
return rocm_aiter_ops.fused_moe(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
|
||||
a2_scale=quant_config.a2_scale,
|
||||
doweight_stage1=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_topk_softmax(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
token_expert_indices: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
renormalize: bool,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
torch.ops.vllm.rocm_aiter_topk_softmax(
|
||||
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
|
||||
)
|
||||
return topk_weights, topk_indices
|
||||
|
||||
|
||||
def shuffle_weights(
|
||||
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Applies shuffle_weight function from AITER to each
|
||||
input tensor and returns them.
|
||||
|
||||
Rearranges (shuffles) the input tensor/s
|
||||
into a specified block layout for optimized computation.
|
||||
|
||||
Args:
|
||||
*tensors: Variable number of torch.Tensor objects.
|
||||
layout: A pair of integers specifying the block sizes used to divide
|
||||
the tensors during shuffling. Default is (16, 16).
|
||||
|
||||
Returns:
|
||||
A Tuple of shuffled tensors.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
|
||||
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
|
||||
|
||||
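`shuffle_weights`, removed here in favour of `rocm_aiter_ops.shuffle_weights`, is a thin variadic wrapper around AITER's `shuffle_weight`. A stand-in sketch of the wrapper shape with a no-op in place of the AITER kernel (the tensor shapes are arbitrary):

```python
import torch


def _shuffle_weight_stub(tensor: torch.Tensor, layout=(16, 16)) -> torch.Tensor:
    # Placeholder for aiter.ops.shuffle.shuffle_weight; the real op reorders
    # data into blocks of the given layout, here only the interface is kept.
    return tensor.contiguous()


def shuffle_weights(*tensors: torch.Tensor, layout=(16, 16)) -> tuple[torch.Tensor, ...]:
    return tuple(_shuffle_weight_stub(t, layout=layout) for t in tensors)


w13, w2 = torch.randn(8, 64, 32), torch.randn(8, 32, 64)
shuffled_w13, shuffled_w2 = shuffle_weights(w13, w2)
assert shuffled_w13.shape == w13.shape and shuffled_w2.shape == w2.shape
```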
@ -28,13 +28,18 @@ class SharedFusedMoE(FusedMoE):
super().__init__(**kwargs)
self._shared_experts = shared_experts

# Disable shared expert overlap if we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
# Disable shared expert overlap if we are using eplb, because of
# correctness issues, or if using flashinfer with DP, since there
# is nothing to be gained in this case. Disabling the overlap
# optimization also prevents the shared experts from being hidden
# from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
and not (
# TODO(wentao): find the root cause and remove this condition
self.enable_eplb
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)

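The rewritten predicate additionally disables overlap under EPLB. A tiny truth-table check of the boolean structure above, using plain booleans in place of the layer attributes:

```python
def use_overlapped(
    requested: bool,
    enable_eplb: bool,
    flashinfer_cutlass: bool,
    dp_size: int,
    has_shared_experts: bool,
) -> bool:
    # Mirrors the boolean structure of the hunk above.
    return (
        requested
        and not (enable_eplb or (flashinfer_cutlass and dp_size > 1))
        and has_shared_experts
    )


assert use_overlapped(True, False, False, 1, True) is True
assert use_overlapped(True, True, False, 1, True) is False   # EPLB disables overlap
assert use_overlapped(True, False, True, 2, True) is False   # flashinfer + DP disables it
assert use_overlapped(True, False, True, 1, True) is True    # flashinfer without DP is fine
```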
@ -6,18 +6,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
rms_norm_batch_invariant,
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
def rms_norm(
|
||||
@ -58,80 +53,34 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
def poly_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
import aiter as rocm_aiter
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
x = x.reshape(-1, x_original_shape[-1])
|
||||
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
return x.reshape(x_original_shape)
|
||||
|
||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
residual_out = torch.empty_like(residual)
|
||||
output = torch.empty_like(x)
|
||||
rocm_aiter.rmsnorm2d_fwd_with_add(
|
||||
output, # output
|
||||
x, # input
|
||||
residual, # residual input
|
||||
residual_out, # residual output
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
return out
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_fake(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=rocm_aiter_rms_norm_impl,
|
||||
fake_impl=rocm_aiter_rms_norm_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
def dispatch_rocm_rmsnorm_func(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]

if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
return rocm_aiter_ops.rms_norm

# fall back to CUDA implementation
if with_fused_add:
@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(self.weight)

if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)

@staticmethod
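As with the top-k dispatcher, the RMSNorm dispatch now takes the AITER decision from the caller and only honours it for fp16/bf16 weights. A hedged sketch of the selection logic, with stand-in callables in place of the registered ops:

```python
import torch


def _aiter_rms_norm(*args): ...            # stand-in for rocm_aiter_ops.rms_norm
def _aiter_rms_norm_with_add(*args): ...   # stand-in for rocm_aiter_ops.rms_norm2d_with_add
def _cuda_rms_norm(*args): ...             # stand-in for the default CUDA op
def _cuda_fused_add_rms_norm(*args): ...   # stand-in for the fused-add CUDA op


def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    use_aiter = use_aiter and dtype in [torch.float16, torch.bfloat16]
    if use_aiter and with_fused_add:
        return _aiter_rms_norm_with_add
    if use_aiter:
        return _aiter_rms_norm
    return _cuda_fused_add_rms_norm if with_fused_add else _cuda_rms_norm


# fp32 never takes the AITER path, even when it is enabled:
assert dispatch_rocm_rmsnorm_func(False, torch.float32, use_aiter=True) is _cuda_rms_norm
assert dispatch_rocm_rmsnorm_func(True, torch.bfloat16, use_aiter=True) is _aiter_rms_norm_with_add
```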
@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# cutlass path
|
||||
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
|
||||
@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
|
||||
@ -7,12 +7,12 @@ import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
)
|
||||
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
|
||||
if self.weight_block_size is not None:
|
||||
assert not self.is_static_input_scheme
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
@ -27,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
RoutingMethodType,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
@ -56,7 +58,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
check_aiter_fp8_linear_support,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
@ -369,7 +370,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
@ -869,12 +870,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
@ -916,7 +913,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -962,7 +959,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
@ -1042,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
start += shard_size
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight, layer.w2_weight
|
||||
)
|
||||
|
||||
@ -1226,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert activation == "silu", (
|
||||
f"Expected 'silu' activation but got {activation}"
|
||||
)
|
||||
assert scoring_func == "sigmoid", (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}"
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
|
||||
assert (
|
||||
renormalize and use_grouped_topk and custom_routing_function is None
|
||||
)
|
||||
e_score_correction_bias = (
|
||||
e_score_correction_bias.to(x.dtype)
|
||||
if e_score_correction_bias is not None
|
||||
else None
|
||||
)
|
||||
routing_method_type = layer.routing_method_type
|
||||
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits=router_logits.to(torch.float32),
|
||||
routing_logits=router_logits.to(torch.float32)
|
||||
if routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
else router_logits,
|
||||
routing_bias=e_score_correction_bias,
|
||||
x=x,
|
||||
w13_weight=layer.w13_weight,
|
||||
@ -1256,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
expert_offset=layer.ep_rank * layer.local_num_experts,
|
||||
local_num_experts=layer.local_num_experts,
|
||||
block_shape=self.weight_block_size,
|
||||
routing_method_type=routing_method_type,
|
||||
routed_scaling=routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -4,54 +4,14 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
from aiter import gemm_a8w8_CK
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
m = A.shape[0]
|
||||
n = B.shape[0]
|
||||
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
|
||||
return Y
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8",
|
||||
op_func=rocm_aiter_gemm_w8a8_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_fake,
|
||||
)
|
||||
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
+ "installed on ROCm.",
|
||||
)
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
|
||||
if not (rocm_aiter_ops.is_linear_enabled()):
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel is disabled. "
|
||||
@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
|
||||
x_q, w_q.t(), x_s, w_s, bias, out_dtype
|
||||
)
|
||||
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE,
|
||||
@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
use_mxfp4_aiter_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
)
|
||||
@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
# Property to determine if AITER is used
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
shuffle_weights,
|
||||
)
|
||||
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
@ -458,6 +451,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
|
||||
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
|
||||
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
|
||||
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
|
||||
|
||||
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
|
||||
self.input_dtype, self.weight_dtype
|
||||
@ -469,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"not implemented. Please open an issue."
|
||||
)
|
||||
|
||||
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
|
||||
self.emulate = not current_platform.supports_mx() or not (
|
||||
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
)
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
|
||||
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
|
||||
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
@ -581,6 +577,17 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
|
||||
if self.fp4_dtype is not None:
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
layer.w13_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w13_weight.requires_grad,
|
||||
)
|
||||
layer.w2_weight = torch.nn.Parameter(
|
||||
layer.w2_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w2_weight.requires_grad,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
@ -644,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
if not self.emulate:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
aiter_acts = {
|
||||
ActivationType.No.name.lower(): ActivationType.No,
|
||||
ActivationType.Silu.name.lower(): ActivationType.Silu,
|
||||
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
|
||||
}
|
||||
assert activation in aiter_acts, (
|
||||
f"Aiter CK fp4 MoE doesn't support activation {activation}"
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_fused_experts,
|
||||
)
|
||||
out = fused_moe(
|
||||
|
||||
out = rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=aiter_acts[activation],
|
||||
doweight_stage1=False,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: move registration of custom op to aiter_ops.py
|
||||
# `from vllm._aiter_ops import rocm_aiter_ops`
|
||||
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
|
||||
# for envs checks which does not require @cache anymore.
|
||||
# triton kernel is torch compile compatible.
|
||||
# does not require direct registeration.
|
||||
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
|
||||
@cache
|
||||
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||
return (
|
||||
|
||||
@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):


def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2

# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
# with the necessary kernels is released.
tile_tokens_dim = 8

# from flashinfer import next_positive_power_of_2

# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# A factor considering tokens are not perfectly balanced among experts.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-max_tile_tokens_dim tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

return tile_tokens_dim

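The dynamic tile-size computation is restored: the tokens-per-expert estimate is scaled by a 1.3 imbalance factor, padded to the next power of two, and clamped to [8, 64]. A worked example with a local helper standing in for flashinfer's `next_positive_power_of_2` (the input numbers are illustrative):

```python
def next_positive_power_of_2(x: int) -> int:
    # Local stand-in with the same behaviour for the inputs used here.
    return 1 if x < 1 else 1 << (x - 1).bit_length()


def calculate_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    imbalance_factor = 1.3
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
    return min(max(tile_tokens_dim, 8), 64)


# 1024 tokens, top_k=8, 256 experts -> 32 tokens/expert -> 41 after the
# imbalance factor -> padded to 64, which is already the upper clamp.
assert calculate_tile_tokens_dim(1024, 8, 256) == 64
# Very small batches bottom out at the lower clamp of 8.
assert calculate_tile_tokens_dim(4, 8, 256) == 8
```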
@ -12,6 +12,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@ -68,78 +69,6 @@ def cutlass_scaled_mm(
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_impl(
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
def is_aiter_triton_kernel_tuned(n, k):
|
||||
return (n, k) in [
|
||||
(1024, 8192),
|
||||
(2112, 7168),
|
||||
(3072, 1536),
|
||||
(32768, 8192),
|
||||
(4096, 7168),
|
||||
(4608, 7168),
|
||||
(512, 7168),
|
||||
(7168, 2048),
|
||||
(7168, 256),
|
||||
(8192, 1024),
|
||||
(8192, 32768),
|
||||
]
|
||||
|
||||
n, k = weight.shape
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
|
||||
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
|
||||
|
||||
# MI350 case uses triton kernel
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
group_size,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
else:
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
import aiter as rocm_aiter
|
||||
from aiter import gemm_a8w8_blockscale, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
|
||||
q_input, input_scale = aiter_per1x128_quant(
|
||||
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
|
||||
)
|
||||
|
||||
return gemm_a8w8_blockscale(
|
||||
q_input, weight, input_scale, weight_scale, dtype=output_dtype
|
||||
)
|
||||
|
||||
|
||||
def rocm_aiter_gemm_w8a8_blockscale_fake(
|
||||
input_2d: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_size: int,
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
m = input_2d.shape[0]
|
||||
n = weight.shape[0]
|
||||
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_gemm_w8a8_blockscale",
|
||||
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
|
||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||
)
|
||||
|
||||
|
||||
# TODO we should be able to change the type of block_size to GroupShape
|
||||
# after we resolve GroupShape compilation issue
|
||||
# https://github.com/vllm-project/vllm/issues/25270
|
||||
@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
|
||||
input_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
|
||||
input_2d,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
self.act_quant_group_shape.col,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
n, k = weight.shape
|
||||
if input_scale is not None:
|
||||
q_input = input_2d
|
||||
|
||||
# MI350 case uses triton kernel
|
||||
if (
|
||||
not current_platform.is_fp8_fnuz()
|
||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||
):
|
||||
q_input, input_scale = per_token_group_quant_fp8(
|
||||
input_2d,
|
||||
self.act_quant_group_shape.col,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=False,
|
||||
)
|
||||
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
# MI300 uses tuned AITER ASM/C++ kernel
|
||||
else:
|
||||
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
|
||||
return rocm_aiter_ops.gemm_w8a8_blockscale(
|
||||
q_input,
|
||||
weight,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
input_2d.dtype,
|
||||
)
|
||||
|
||||
def _run_triton(
|
||||
self,
|
||||
@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
|
||||
s_old.copy_(s_requant)
|
||||
|
||||
|
||||
def check_aiter_fp8_linear_support() -> bool:
|
||||
"""AITER is only supported on ROCm for MI3XX"""
|
||||
return (
|
||||
current_platform.is_rocm()
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||
)
|
||||
|
||||
|
||||
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad the weight tensor. This is an optimization on ROCm platform, which
|
||||
can benefit from tensors located far enough from one another in memory"""
|
||||
|
||||
@ -472,7 +472,7 @@ class Fp8LinearOp:
|
||||
# Example:
|
||||
# When the number of token is 1, per-token scale is [[1]]
|
||||
# When per-tensor scale is [1] or ().
|
||||
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
|
||||
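Dropping `dim() < 2` from the weight-scale check means a single-element scale counts as per-tensor regardless of its rank, while the activation check keeps the extra condition so a `[[1.]]` single-token per-token scale is not misclassified. A small illustration with made-up shapes:

```python
import torch

weight_scale = torch.ones(1, 1)           # single element, but 2-D
x_scale_single_token = torch.ones(1, 1)   # per-token scale for one token: [[1.]]
x_scale_per_tensor = torch.tensor(1.0)    # true per-tensor activation scale

per_tensor_weights = weight_scale.numel() == 1                           # new check
per_tensor_activations = (
    x_scale_single_token.numel() == 1 and x_scale_single_token.dim() < 2
)

assert per_tensor_weights          # 2-D single-element weight scale is per-tensor now
assert not per_tensor_activations  # [[1.]] is still treated as per-token
assert x_scale_per_tensor.numel() == 1 and x_scale_per_tensor.dim() < 2
```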
@ -4,13 +4,10 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from .common import apply_rotary_emb_torch
|
||||
from .rocm_aiter_rope_ops import (
|
||||
is_rocm_triton_rotary_embedding_enabled,
|
||||
rocm_aiter_rotary_emb,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
self.is_rocm_triton_rotary_embedding_enabled = (
|
||||
is_rocm_triton_rotary_embedding_enabled()
|
||||
self.is_rocm_triton_rotary_embed_enabled = (
|
||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.is_rocm_triton_rotary_embedding_enabled:
|
||||
if self.is_rocm_triton_rotary_embed_enabled:
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
rocm_aiter_rotary_emb(
|
||||
rocm_aiter_ops.triton_rotary_embed(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
|
||||
@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
|
||||
key = key_rot
|
||||
return query, key
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
||||
@ -1,94 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op


def is_rocm_triton_rotary_embedding_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_TRITON_ROPE
)


def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter.ops.triton.rope as ops

ops.rope_cached_thd_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)


def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass


if is_rocm_triton_rotary_embedding_enabled():
direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
dispatch_key=current_platform.dispatch_key,
)


def rocm_aiter_rotary_emb(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1

query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
positions,
sin,
cos,
query_,
key_,
rotate_style,
False,
)
query = query.view(query_shape)
key = key.view(key_shape)
@ -33,6 +33,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config

from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton

@ -50,10 +51,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,

@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)

if (
config.n_shared_experts is None
or is_rocm_aiter_fusion_shared_expert_enabled()
):
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts

@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not is_rocm_aiter_moe_enabled()
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
else None,
)

@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not is_rocm_aiter_moe_enabled():
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None

@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),

@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,

@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model

is_fuse_shared_experts_layer = (
is_rocm_aiter_fusion_shared_expert_enabled()
and ("mlp.shared_experts" in name)
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
"mlp.shared_experts" in name
)

for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||
context_len: int = 0,
|
||||
seq_len: int | None = None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
@ -1569,7 +1567,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
|
||||
@ -1622,8 +1622,6 @@ class Glm4vForConditionalGeneration(
|
||||
image_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor | None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
context_len: int = 0,
|
||||
seq_len: int | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
@ -1754,7 +1752,6 @@ class Glm4vForConditionalGeneration(
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
|
||||
@ -625,8 +625,6 @@ class GLM4VForCausalLM(
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||
context_len: int = 0,
|
||||
seq_len: int | None = None,
|
||||
second_per_grid_ts: list[float] | None = None,
|
||||
audio_feature_lengths: torch.Tensor | None = None,
|
||||
use_audio_in_video: bool = False,
|
||||
@ -758,7 +756,6 @@ class GLM4VForCausalLM(
|
||||
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
|
||||
@ -995,8 +995,6 @@ class SupportsMRoPE(Protocol):
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1012,8 +1010,6 @@ class SupportsMRoPE(Protocol):
image_grid_thw: Image grid dimensions (t, h, w)
video_grid_thw: Video grid dimensions (t, h, w)
second_per_grid_ts: Seconds per grid timestep for videos
context_len: Context length
seq_len: Sequence length
audio_feature_lengths: Audio feature lengths for multimodal models
use_audio_in_video: Whether to use audio in video for interleaving

@ -1630,8 +1630,6 @@ class KeyeForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,

@ -1759,6 +1757,5 @@ class KeyeForConditionalGeneration(

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

@ -600,8 +600,6 @@ class KeyeVL1_5ForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,

@ -729,6 +727,5 @@ class KeyeVL1_5ForConditionalGeneration(

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

@ -1179,8 +1179,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1293,7 +1291,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

@ -927,8 +927,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1125,7 +1123,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
mrope_position_delta = (
torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
)
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

@ -1118,8 +1118,6 @@ class Qwen2_5_VLForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1232,7 +1230,6 @@ class Qwen2_5_VLForConditionalGeneration(

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

@ -1240,8 +1240,6 @@ class Qwen2VLForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1360,7 +1358,6 @@ class Qwen2VLForConditionalGeneration(

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta
@ -43,6 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,

@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)

self.gate = ReplicatedLinear(

@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)

@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@ -1417,8 +1417,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -1419,8 +1419,6 @@ class Qwen3VLForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,

@ -1519,7 +1517,7 @@ class Qwen3VLForConditionalGeneration(

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions, mrope_position_delta

def get_language_model(self) -> torch.nn.Module:

@ -371,8 +371,6 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

@ -390,7 +388,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
video_grid_thw=video_grid_thw,
)

mrope_positions = mrope_positions[:, 0, context_len:seq_len]
mrope_positions = mrope_positions[:, 0]
mrope_position_delta = mrope_position_delta[0].item()

return mrope_positions, mrope_position_delta
@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops

GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])

@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)

@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]

@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec

from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend

if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA

if on_gfx9() and find_spec("flash_attn") is not None:

@ -228,19 +233,23 @@ class RocmPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend

if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled,

if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)

if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
if is_aiter_mla_enabled() or block_size == 1
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)

@ -265,12 +274,12 @@ class RocmPlatform(Platform):
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (
@ -198,6 +198,7 @@ from tqdm import tqdm

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,

@ -270,28 +271,15 @@ except ImportError:
flashinfer_available = False


def is_rocm_aiter_fp8bmm_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP8BMM
and envs.VLLM_ROCM_USE_AITER
)


if is_rocm_aiter_fp8bmm_enabled():
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
)

def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
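The `dynamic_per_batched_tensor_quant` helper above derives one symmetric FP8 scale for the whole batched weight and returns the quantized tensor plus the reciprocal scale. A small usage sketch (the helper body is copied from the hunk above; the driver code is illustrative and assumes a PyTorch build that provides `torch.float8_e4m3fn`):

```python
import torch

def dynamic_per_batched_tensor_quant(
    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
    # One symmetric scale for the whole (batched) tensor, clamped to the
    # fp8 representable range, as in the diff above.
    DTYPE_MAX = torch.finfo(dtype).max
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
    scale = DTYPE_MAX / amax
    x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

# Quantize a fake batched weight (e.g. a (heads, head_dim, lora_rank) tensor)
# and check that the dequantization error stays small.
w = torch.randn(16, 128, 512, dtype=torch.bfloat16)
w_fp8, w_inv_scale = dynamic_per_batched_tensor_quant(w)
w_deq = w_fp8.to(torch.float32) * w_inv_scale
print((w_deq - w.to(torch.float32)).abs().max())
```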
logger = init_logger(__name__)

@ -1109,6 +1097,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):

@ -1158,7 +1147,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(

@ -1187,7 +1176,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)

@ -1196,7 +1185,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:

@ -1208,10 +1197,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm(
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
# Convert from (B, N, V) to (B, N * V)

@ -1571,7 +1559,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)

if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(

@ -1600,7 +1588,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)

@ -1609,7 +1597,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:

@ -1958,7 +1946,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)

# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))

@ -1966,9 +1953,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded

if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,
@ -6,9 +6,8 @@ from typing import ClassVar

import torch

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (

@ -22,10 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec


def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA


class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:

@ -284,7 +279,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,
@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext

logger = init_logger(__name__)

@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
continue

# Schedule newly needed KV blocks for the request.
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)

if new_blocks is not None:
# The request can be scheduled.
break

# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()

self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)

self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
if new_blocks is not None:
# The request can be scheduled.
break

# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[
preempted_req.request_id
]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()

self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)

self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break

if new_blocks is None:
# Cannot schedule this request.

@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
)
)
)

# Construct the scheduler output.
new_reqs_data = [

@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
)
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)

# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()

@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta

self._update_after_schedule(scheduler_output)
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output

def _update_after_schedule(
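One behavioral detail worth noting in the scheduler hunk above: under the priority policy, the preemption victim is picked with `max(...)` over `(priority, arrival_time)`, i.e. the lowest-priority request, breaking ties toward the most recently arrived one. A toy sketch with hypothetical request records (not the scheduler's real classes):

```python
from dataclasses import dataclass

@dataclass
class Req:
    request_id: str
    priority: int        # assumption for this sketch: larger value == lower scheduling priority
    arrival_time: float

running = [
    Req("a", priority=0, arrival_time=1.0),
    Req("b", priority=2, arrival_time=3.0),
    Req("c", priority=2, arrival_time=5.0),
]

# Same selection rule as the diff's preemption branch.
victim = max(running, key=lambda r: (r.priority, r.arrival_time))
print(victim.request_id)  # -> "c": lowest priority, latest arrival among the tied requests
```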
@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)

@ -315,17 +316,21 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)

with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

@ -363,32 +368,49 @@ class EngineCore:
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0

if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()

if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)

@ -408,27 +430,34 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False

# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))

return engine_core_outputs, model_executed
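Most of the engine-core changes in this range only wrap existing steps in `record_function_or_nullcontext(...)` so that each phase appears as a named range when profiling is enabled. A rough sketch of what such a helper can look like — this is an assumption for illustration, not the actual `vllm.v1.utils.record_function_or_nullcontext`:

```python
import contextlib
import os

import torch

def record_function_or_nullcontext(name: str):
    # Illustrative gate: emit a named torch profiler range only when profiling
    # is requested (here via an env var), otherwise return a no-op context so
    # the hot path pays nothing. The real helper's condition may differ.
    if os.environ.get("VLLM_TORCH_PROFILER_DIR"):
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()

# Usage mirrors the hunks above:
with record_function_or_nullcontext("core step: schedule"):
    pass  # e.g. scheduler_output = scheduler.schedule()
```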
@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

@ -280,28 +281,32 @@ class LLMEngine:
return []

# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
with record_function_or_nullcontext("llm_genine step: get_output"):
outputs = self.engine_core.get_output()

# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
with record_function_or_nullcontext("llm_genine step: process_outputs"):
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
with record_function_or_nullcontext("llm_genine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

# 4) Record stats
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
with record_function_or_nullcontext("llm_genine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()

return processed_outputs.request_outputs
@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
vocab_size: int
printed_error: bool = False
terminated: bool = False
rollback_lag: int = 0

def check_error(self):
if not self.printed_error:

@ -127,6 +128,8 @@ class GuidanceGrammar(StructuredOutputGrammar):
"""

if self.ll_tokenizer.eos_token in tokens:
if self.ll_matcher.is_stopped() and not self.terminated:
self.rollback_lag = 1
self.terminated = True

if self.ll_matcher.is_stopped():

@ -163,8 +166,11 @@ class GuidanceGrammar(StructuredOutputGrammar):
return tokens[:num_tokens]

def rollback(self, num_tokens: int) -> None:
self.ll_matcher.rollback(num_tokens)
self.check_error()
if num_tokens > 0:
self.ll_matcher.rollback(num_tokens - self.rollback_lag)
self.terminated = False
self.rollback_lag = 0
self.check_error()

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
# this will automatically return [EOS] mask if the matcher is stopped
@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)

@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("Forward"),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
model_output = self._model_forward(

@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**model_kwargs,
)

with record_function_or_nullcontext("Postprocess"):
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.
hidden_states, aux_hidden_states = model_output

@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, grammar_output, self.input_batch, logits
)

with record_function_or_nullcontext("Sample"):
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)

def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
sampled_token_ids,

@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)

with record_function_or_nullcontext("Bookkeep"):
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
num_nans_in_logits,
logprobs_lists,

@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)

with record_function_or_nullcontext("EPLB"):
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()

output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)

if not self.use_async_scheduling:
return output

async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)

# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"
):
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)

return async_output