Compare commits

...

13 Commits

Author SHA1 Message Date
30700b1cd7 [CI] Fix Plugin Tests Tests (#28413)
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
2025-11-10 22:36:11 +00:00
4b94ed8f92 [Frontend][2/n] remove empty content from _parse_tool_calls_from_content (#28331)
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
2025-11-10 14:07:49 -08:00
6dec9f6109 [BugFix] Fix DeepGEMM over-allocating workspace (#28254)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
2025-11-10 17:01:17 -05:00
bf6a3d0ff5 [Misc] Add more scoping for improved trace (#28329)
Signed-off-by: Wei Wei <wwei6@meta.com>
2025-11-10 21:03:21 +00:00
40d33264c6 [Bugfix][EPLB] Disabled shared expert overlap when EPLB is enabled (#28377)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Sage Moore <sagemoore@utexas.edu>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
2025-11-10 20:39:19 +00:00
9c84ca8293 [FA/Chore] Bump FA version for FP8 two-level accumulation (#27889)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
2025-11-10 12:06:04 -08:00
6d54336ae5 [Bugfix] Fix llguidance backend, rollback when EOS was encountered (#25905)
Signed-off-by: Rémi Delacourt <remi@mistral.ai>
Signed-off-by: remi <remi@mistral.ai>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
2025-11-10 14:53:32 -05:00
34553b9d27 [Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
2025-11-10 12:34:57 -05:00
b039bfda8f [Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2025-11-10 09:21:52 -08:00
d0e186c16f [V0 Deprecation] Remove unused context_len and seq_len from M-RoPE (#28395)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-11-11 00:30:06 +08:00
f080a83511 [RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-11-10 08:20:53 -08:00
40e2eeeb92 [Kernel] Optimization of the mm_k operator. (#28280)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
2025-11-10 16:03:46 +00:00
b06b9470ca [Rocm][fused_moe][fp4] view weight to torch.float4_e2m1fn_x2 when running aiter fused moe for fp4 model (#27474)
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
2025-11-10 10:38:56 -05:00
58 changed files with 1697 additions and 1186 deletions

View File

@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG a893712401d70362fbb299cd9c4b3476e8e9ed54
GIT_TAG 8e1b01d56210dc72030a2d0d41c2d8d266ba6309
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@ -578,11 +578,13 @@ void persistent_masked_m_silu_mul_quant(
// This kernel currently only supports H % 128 == 0 and assumes a
// fixed GROUP_SIZE of 128.
static constexpr int GROUP_SIZE = 128;
TORCH_CHECK(input.dtype() == torch::kBFloat16);
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
TORCH_CHECK(input.size(-1) % 256 == 0);
TORCH_CHECK(input.size(-1) % (GROUP_SIZE * 2) == 0);
using Idx_t = int64_t;
@ -601,8 +603,6 @@ void persistent_masked_m_silu_mul_quant(
Idx_t stride_counts_e = tokens_per_expert.stride(0);
static constexpr int GROUP_SIZE = 128;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES) \
@ -628,21 +628,26 @@ void persistent_masked_m_silu_mul_quant(
static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
int const NUM_GROUPS = H / GROUP_SIZE;
if (!use_ue8m0) {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
}
} else {
if (H >= 4096) {
if (H >= 4096 && (NUM_GROUPS % 8 == 0)) {
/* 8 warps config */
static constexpr int NUM_STAGES = 4;
static constexpr int THREAD_COUNT = 256;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
} else {
/* 1 warp config */
static constexpr int THREAD_COUNT = 32;
KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
}
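For orientation, here is a minimal Python sketch of the revised launch heuristic above: the 8-warp configuration is now selected only when the hidden size is large and the number of 128-wide groups is divisible by 8; otherwise the kernel falls back to the 1-warp configuration. The helper name and return format are illustrative, not part of the patch.

```python
# Illustrative only: mirrors the C++ dispatch logic above.
GROUP_SIZE = 128

def pick_silu_v2_config(H: int) -> tuple[int, int]:
    """Return (thread_count, num_stages) for the SiLU V2 kernel launch."""
    num_groups = H // GROUP_SIZE
    if H >= 4096 and num_groups % 8 == 0:
        return 256, 4   # 8 warps, 4 stages
    return 32, 2        # 1 warp, 2 stages

assert pick_silu_v2_config(7168) == (256, 4)
assert pick_silu_v2_config(128 * 33) == (32, 2)  # 33 groups is not divisible by 8
```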

View File

@ -97,7 +97,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_moe_impl] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched<sup>4</sup> | batched | int8,</br>fp8 | G,A,T | silu, gelu | <sup>6</sup> | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |

View File

@ -6,6 +6,8 @@ Run `pytest tests/kernels/test_moe.py`.
"""
import functools
import importlib
import sys
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
@ -20,6 +22,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
@ -412,14 +415,12 @@ def test_mixtral_moe(
huggingface."""
# clear the cache before every test
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
# Force reload aiter_ops to pick up the new environment variables.
if "rocm_aiter_ops" in sys.modules:
importlib.reload(rocm_aiter_ops)
is_rocm_aiter_moe_enabled.cache_clear()
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
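A hedged sketch of the pattern this test now relies on: `vllm/_aiter_ops.py` (added later in this diff) snapshots its `VLLM_ROCM_USE_AITER*` environment variables at import time, so a test that flips those variables must reload the module afterwards. The helper below is illustrative; only the module path and env variable name come from the diff.

```python
import importlib
import sys

def reload_aiter_ops(monkeypatch, enabled: bool) -> None:
    # Set the env var first, then reload so class-level flags are re-evaluated.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if enabled else "0")
    if "vllm._aiter_ops" in sys.modules:
        importlib.reload(sys.modules["vllm._aiter_ops"])
```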

View File

@ -25,6 +25,7 @@ CASES = [
(8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype),
(8, 64, 7168, fp8_dtype),
(8, 128, 128 * 33, fp8_dtype),
(8, 128, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype),
@ -54,8 +55,10 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type):
)
# Run the SiLU V2 kernel
# TODO (varun): use_e8m0 is set to false as the reference impl does
# not handle that case.
y_q, y_s = persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, group_size=group_size
y, tokens_per_expert, group_size=group_size, use_ue8m0=False
)
torch.cuda.synchronize()

View File

@ -4,6 +4,7 @@
import pytest
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (
@ -15,9 +16,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_topk_func,
vllm_topk_softmax,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import (
RMSNorm,
dispatch_rocm_rmsnorm_func,
@ -126,50 +124,39 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
topk_func = dispatch_topk_func()
is_rocm_aiter_moe_enabled.cache_clear()
if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_topk_softmax,
)
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_dispatch(use_rocm_aiter: bool):
topk_func = dispatch_topk_func(use_rocm_aiter)
assert topk_func == rocm_aiter_topk_softmax
if current_platform.is_rocm() and use_rocm_aiter:
assert topk_func == rocm_aiter_ops.topk_softmax
else:
assert topk_func == vllm_topk_softmax
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(
add_residual: bool,
dtype: torch.dtype,
use_rocm_aiter: str,
use_rocm_aiter_norm: str,
monkeypatch,
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
should_use_rocm_aiter = (
current_platform.is_rocm()
and int(use_rocm_aiter)
and int(use_rocm_aiter_norm)
and use_rocm_aiter
and dtype in RMS_NORM_SUPPORTED_DTYPES
)
if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
assert rms_norm_func == rocm_aiter_ops.rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:

View File

@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm.config import StructuredOutputsConfig, VllmConfig
from vllm.config.model import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import StructuredOutputOptions
TOKENIZER = "gpt2"
def test_backend_guidance_rollback_terminated():
# Test that the guidance backend successfully rolls back from a
# terminated state. This can happen with speculative decoding,
# where the draft model proposes EOS and it is verified by the
# guidance backend. In that case we are in a stopped state, but
# it should be reverted in case EOS is not accepted by the target
# model.
vllm_config = VllmConfig(
decoding_config=StructuredOutputsConfig(
backend="guidance",
)
)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
backend = GuidanceBackend(
vllm_config,
tokenizer=tokenizer,
vocab_size=50257,
)
grammar = backend.compile_grammar(
StructuredOutputOptions.JSON, '{"type": "object"}'
)
prompt = tokenizer.encode('{"a": "b"}')
assert len(prompt) > 1
dummy_wrong = tokenizer.encode('{"a"}')
for token in prompt:
assert grammar.accept_tokens("", [token])
assert not grammar.is_terminated()
assert grammar.accept_tokens("", [tokenizer.eos_token_id])
assert grammar.is_terminated()
# Feeding any other token should also be accepted
assert grammar.accept_tokens("", dummy_wrong)
# Rollback is done from where state was terminated, so from '}' not EOS
grammar.rollback(len(prompt) - 1)
assert not grammar.is_terminated()
assert grammar.validate_tokens([tokenizer.eos_token_id]) == []
assert grammar.validate_tokens(dummy_wrong) != dummy_wrong
assert grammar.accept_tokens("", prompt[1:])
assert not grammar.is_terminated()
assert grammar.accept_tokens("", [tokenizer.eos_token_id])
assert grammar.is_terminated()
# Rollback of <= 0 should not change the terminated state
grammar.rollback(0)
assert grammar.is_terminated()
grammar.rollback(-1)
assert grammar.is_terminated()
def test_grammar_bitmask_with_specdec():
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
prompt = tokenizer.encode('{"a": "b"}')
vllm_config = VllmConfig(
model_config=ModelConfig(tokenizer=TOKENIZER),
structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
speculative_config=SpeculativeConfig(model="[ngram]", num_speculative_tokens=3),
)
structured_output_manager = StructuredOutputManager(vllm_config)
for i in range(1, 2):
sampling_params = SamplingParams(
structured_outputs=StructuredOutputsParams(
json='{"type": "object"}',
),
)
sampling_params.structured_outputs._backend = "guidance"
my_req_id = f"my_req_id_{i}"
request = Request(
my_req_id,
prompt_token_ids=prompt[:i],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=tokenizer.eos_token_id,
)
structured_output_manager.grammar_init(request)
def grammar_bitmask(req: Request, tokens: list[int]) -> None:
structured_output_manager.grammar_bitmask(
requests={req.request_id: req},
structured_output_request_ids={req.request_id: 0},
scheduled_spec_decode_tokens={req.request_id: tokens},
)
# At this point we have rolled back, so the grammar should not be terminated
assert not req.structured_output_request.grammar.is_terminated()
# The grammar might not yet be compiled, so we wait for it
while not request.structured_output_request._check_grammar_completion():
continue
assert request.structured_output_request.grammar.accept_tokens(
request.request_id, prompt[:i]
)
grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
grammar_bitmask(
request, prompt[i:] + [tokenizer.eos_token_id] + prompt
) # EOS not the final token
grammar_bitmask(request, prompt[i:]) # EOS not present
grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
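Condensed from the test above, a minimal sketch of the rollback contract this bugfix establishes: accepting EOS terminates the grammar, and a positive `rollback(n)` reverts the termination, so a speculated EOS that the target model rejects does not leave the request stuck. Setup (backend construction, tokenizer, `prompt`) is as in the test; only the call sequence is shown.

```python
# Sketch only; `grammar` is a compiled llguidance grammar built as in the test.
for token in prompt:
    assert grammar.accept_tokens("", [token])
assert grammar.accept_tokens("", [tokenizer.eos_token_id])
assert grammar.is_terminated()

# The speculated EOS was rejected by the target model: revert the termination.
grammar.rollback(len(prompt) - 1)   # counted from the last pre-EOS token
assert not grammar.is_terminated()
```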

vllm/_aiter_ops.py (new file, 941 lines)
View File

@ -0,0 +1,941 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def is_aiter_found() -> bool:
from importlib.util import find_spec
return find_spec("aiter") is not None
# `find_spec` is not torch.compile compatible. Since aiter availability
# may be checked inside forward passes that are torch.compiled, we keep
# this check at module level to avoid torch.compile graph breaks.
IS_AITER_FOUND = is_aiter_found()
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
ROCm AITER package is supported on gfx9 archs.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Checks the platform, device arch and aiter library existence.
from vllm.platforms.rocm import on_gfx9
if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
return func(*args, **kwargs)
else:
# Return None or do nothing if not supported
return None
return wrapper
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def _rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def _rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def _rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def _rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def _rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def _rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
is_softmax = scoring_func == "softmax"
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
is_softmax,
routed_scaling_factor,
)
def _rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def _rocm_aiter_mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
def _rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepares weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def _rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_gemm_w8a8_blockscale_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
def _rocm_aiter_gemm_w8a8_blockscale_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
def _rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from aiter import rms_norm
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rms_norm(x, weight, variance_epsilon)
def _rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def _rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter import rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)
return output, residual_out
def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
class rocm_aiter_ops:
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
_FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
_TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
"""Verifies device specs and availability of aiter main env variable."""
return cls._AITER_ENABLED
@classmethod
@if_aiter_supported
def is_linear_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._LINEAR_ENABLED
@classmethod
@if_aiter_supported
def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._FMOE_ENABLED
@classmethod
@if_aiter_supported
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED
@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MLA_ENABLED
@classmethod
@if_aiter_supported
def is_mha_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._MHA_ENABLED
@classmethod
@if_aiter_supported
def is_pa_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._PG_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_triton_unified_attn_enabled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._TRITON_UNIFIED_ATTN_ENABLED
@classmethod
@if_aiter_supported
def is_fp8bmm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP8BMM_ENABLED
@classmethod
@if_aiter_supported
def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._FP4_GEMM_DYNAMIC_QUANT_ASM
@classmethod
@if_aiter_supported
def is_triton_rotary_embed_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
tags = (
tuple()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=_rocm_aiter_asm_moe_tkw1_impl,
mutates_args=[],
fake_impl=_rocm_aiter_asm_moe_tkw1_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=_rocm_aiter_fused_moe_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=_rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=_rocm_aiter_topk_softmax_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=_rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_biased_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=_rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=_rocm_aiter_grouped_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=_rocm_aiter_mla_decode_fwd_fake,
tags=tags,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=_rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
mutates_args=[],
fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
_OPS_REGISTERED = True
@staticmethod
def rms_norm2d_with_add(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add(
x, residual, weight, variance_epsilon
)
@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
@staticmethod
def gemm_w8a8(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
@staticmethod
def gemm_w8a8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
A, B, As, Bs, output_dtype
)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
quant_method: int = 0,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation_method,
quant_method,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
@staticmethod
def asm_moe_tkw1(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = 0,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale,
fc2_scale,
fc1_smooth_scale,
fc2_smooth_scale,
a16,
per_tensor_quant_scale,
expert_mask,
activation_method,
)
@staticmethod
def topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
@staticmethod
def biased_grouped_topk(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
@staticmethod
def grouped_topk(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> None:
torch.ops.vllm.rocm_aiter_grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
@staticmethod
def triton_fp4_gemm_dynamic_qaunt(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = torch.bfloat16,
x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(
x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
)
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
@staticmethod
def triton_rotary_embed(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
rope_cached_thd_positions_2c_fwd_inplace(
positions,
sin,
cos,
query_,
key_,
rotate_style,
reuse_freqs_front_part=True,
is_nope_first=False,
)
query = query.view(query_shape)
key = key.view(key_shape)
@staticmethod
def triton_fp8_bmm(
X: torch.Tensor,
WQ: torch.Tensor,
w_scale: torch.Tensor,
group_size: int = 128,
bias: torch.Tensor | None = None,
dtype: torch.dtype | None = torch.bfloat16,
splitK: int | None = None,
YQ: torch.Tensor | None = None,
transpose_bm: bool | None = False,
config: dict | None = None,
) -> torch.Tensor:
# ruff: noqa: E501 # isort: skip
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm,
)
return aiter_triton_fp8_bmm(
X,
WQ,
w_scale,
group_size=group_size,
bias=bias,
dtype=dtype,
splitK=splitK,
YQ=YQ,
transpose_bm=transpose_bm,
config=config,
)
@staticmethod
def triton_gemm_a8w8_blockscale(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
@staticmethod
def per_1x128_fp8_quant(
input_2d: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""Only applies quantization method for fp8 data type only."""
from aiter import QuantType, dtypes, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
@staticmethod
def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
@staticmethod
def shuffle_weight(
tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight
return shuffle_weight(tensor, layout=layout)
@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies AITER's shuffle_weight function to each input tensor
and returns them.
Rearranges (shuffles) the input tensor(s) into a specified
block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
rocm_aiter_ops.register_ops_once()
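To make the refactor concrete, a hedged sketch of how call sites elsewhere in this diff consume the new `rocm_aiter_ops` facade: they ask the class whether a feature is enabled (the `if_aiter_supported` decorator makes these checks return `None`, hence falsy, off ROCm/gfx9 or when aiter is absent) and then call the corresponding static wrapper, which dispatches to the registered `torch.ops.vllm.*` custom op. The function below is illustrative and mirrors the `dispatch_topk_func` change later in this diff.

```python
# Illustrative caller; not part of the patch.
from vllm._aiter_ops import rocm_aiter_ops

def pick_topk_softmax(vllm_topk_softmax):
    # is_fused_moe_enabled() returns None (falsy) when AITER is unsupported.
    if rocm_aiter_ops.is_fused_moe_enabled():
        return rocm_aiter_ops.topk_softmax
    return vllm_topk_softmax
```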

View File

@ -1,105 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata(
max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device
) -> tuple[torch.Tensor, ...]:
paged_kv_indices = torch.zeros(
max_batch_size * max_block_per_batch, dtype=torch.int32, device=device
)
paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device)
paged_kv_last_page_lens = torch.full(
(max_batch_size,), block_size, dtype=torch.int32
)
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
def aiter_mla_decode_fwd(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
sm_scale: float,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
logit_cap: float = 0.0,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
max_seqlen_qo,
kv_indptr,
kv_indices,
kv_last_page_lens,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
from aiter.mla import mla_decode_fwd
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
o,
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
def mla_decode_fwd_fake(
q: torch.Tensor,
kv_buffer: torch.Tensor,
o: torch.Tensor,
qo_indptr: torch.Tensor,
max_seqlen_qo: int,
kv_indptr: torch.Tensor | None = None,
kv_indices: torch.Tensor | None = None,
kv_last_page_lens: torch.Tensor | None = None,
sm_scale: float = 1.0,
logit_cap: float = 0.0,
) -> None:
pass
if current_platform.is_rocm():
if is_torch_equal_or_newer("2.7.0"):
tags = ()
else:
tags = ((torch.Tag.needs_fixed_stride_order,),)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=mla_decode_fwd_fake,
tags=tags,
)

View File

@ -608,17 +608,19 @@ class VllmConfig:
)
current_platform.check_and_update_config(self)
assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)
assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1

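A minimal sketch of the validation rule as moved inside the new `decode_context_parallel_size > 1` guard: the check now applies only when DCP is actually in use, and it requires the KV-cache block size to be at least, and a multiple of, the interleave size. Function and argument names below are illustrative.

```python
def check_dcp_interleave(block_size: int, interleave: int, dcp_size: int) -> None:
    # Only enforced when decode context parallelism is enabled.
    if dcp_size > 1:
        assert interleave <= block_size and block_size % interleave == 0, (
            f"block_size ({block_size}) must be >= and divisible by "
            f"dcp_kv_cache_interleave_size ({interleave})"
        )

check_dcp_interleave(block_size=16, interleave=8, dcp_size=2)   # passes
check_dcp_interleave(block_size=16, interleave=32, dcp_size=1)  # skipped: DCP disabled
```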
View File

@ -1375,6 +1375,8 @@ class OpenAIServing:
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
else:
# No tool calls.
return None, content
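A hedged sketch of the normalization this frontend change performs: when the tool-call parser leaves behind content that is empty after stripping whitespace, the serving layer now reports `None` rather than an empty string, so clients do not receive a spurious empty `content` field alongside tool calls. The helper name is illustrative.

```python
def normalize_tool_call_content(content: str | None) -> str | None:
    # Whitespace-only content is treated as "no content".
    if content and content.strip() == "":
        return None
    return content

assert normalize_tool_call_content("  \n") is None
assert normalize_tool_call_content("partial answer") == "partial answer"
assert normalize_tool_call_content(None) is None
```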

View File

@ -109,7 +109,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
@ -926,8 +926,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Whether to use aiter rope.
# By default is disabled.
"VLLM_ROCM_USE_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1")
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
@ -1589,7 +1589,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_AITER_TRITON_GEMM",

View File

@ -23,6 +23,7 @@ def mm_k(
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
USE_GDC: tl.constexpr,
base_k,
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@ -47,32 +48,62 @@ def mm_k(
matrix dtype.
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
base_k: Base offset along K dimension for current SPLIT_K group
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
# Step size along K for each iteration
STEP_K = BLOCK_K * SPLIT_K
# Total number of iterations (compile-time constant)
num_iters = tl.cdiv(K, STEP_K)
for k in range(num_iters):
# Current iteration's global K offset
iter_k = k * STEP_K + base_k
# Check if this iteration is completely valid (no masking needed)
block_end = iter_k + BLOCK_K
if EVEN_K:
# pre-fetech lora weight
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
tiled_b = tl.load(
b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(
a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * SPLIT_K * ak_stride
b_ptr += BLOCK_K * SPLIT_K * bk_stride
# Check if we need element-wise masking
if iter_k >= K:
# Entire block out of range, skip
pass
elif block_end <= K:
# Entire block in range, no masking needed (fast path)
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Partial block, need masking (only last iteration)
k_offsets = tl.arange(0, BLOCK_K)
mask = iter_k + k_offsets < K
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += STEP_K * ak_stride
b_ptr += STEP_K * bk_stride
return accumulator
@ -178,6 +209,7 @@ def do_expand_kernel(
CAST_TYPE,
cur_lora_ptr.dtype.element_ty,
USE_GDC,
base_k=0,
)
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@ -284,6 +316,7 @@ def do_shrink_kernel(
False,
cur_lora_ptr.dtype.element_ty,
False, # USE_GDC is always False in shrink kernel
base_k=pid_sk * BLOCK_K,
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
if USE_GDC:

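For orientation, a NumPy reference model of the indexing scheme the reworked `mm_k` loop implements: each split-K program starts at its own `base_k = pid_sk * BLOCK_K`, advances by `BLOCK_K * SPLIT_K` per iteration, skips blocks that start past `K`, and clips (masks) only the final partial block. This models the address math, not the Triton kernel itself.

```python
import numpy as np

def mm_k_reference(A, B, BLOCK_K=16, SPLIT_K=2, pid_sk=0):
    """Partial GEMM for one split-K program: A is (M, K), B is (K, N)."""
    M, K = A.shape
    acc = np.zeros((M, B.shape[1]), dtype=np.float32)
    base_k = pid_sk * BLOCK_K
    step_k = BLOCK_K * SPLIT_K
    for it in range((K + step_k - 1) // step_k):   # num_iters = cdiv(K, STEP_K)
        k0 = it * step_k + base_k                  # this iteration's global K offset
        if k0 >= K:
            continue                               # whole block out of range, skip
        k1 = min(k0 + BLOCK_K, K)                  # partial tail block is clipped (masked)
        acc += A[:, k0:k1] @ B[k0:k1, :]
    return acc

A, B = np.random.rand(4, 70), np.random.rand(70, 3)
partials = sum(mm_k_reference(A, B, pid_sk=s) for s in range(2))  # one per SPLIT_K program
assert np.allclose(partials, A @ B)
```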
View File

@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
use_ue8m0: bool | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
device=y.device,
)
use_ue8m0 = is_deep_gemm_e8m0_used()
use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index

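A short sketch of how the new `use_ue8m0` argument behaves, based on the one-line default resolution above: passing `None` (the default) keeps the previous behaviour of deferring to `is_deep_gemm_e8m0_used()`, while an explicit boolean overrides it, which is what the updated kernel test uses to match its non-E8M0 reference implementation.

```python
# Illustrative resolution logic only (mirrors the changed line above).
def resolve_use_ue8m0(use_ue8m0, is_deep_gemm_e8m0_used):
    return use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()

assert resolve_use_ue8m0(None, lambda: True) is True    # default: follow DeepGEMM setting
assert resolve_use_ue8m0(False, lambda: True) is False  # explicit override, as in the test
```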
View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional, Union
import torch
@ -91,6 +92,26 @@ def _quant_flags_to_group_shape(
return a_shape, w_shape
# The type of method in top-K routing
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
# Default: Softmax -> TopK
Default = (0,)
# Renormalize: TopK -> Softmax
Renormalize = (1,)
# DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups
# -> Top8 experts from the Top4 groups
DeepSeekV3 = (2,)
# Llama4: Top1 -> Sigmoid
Llama4 = (3,)
# RenormalizeNaive: Softmax -> TopK -> Renormalize
RenormalizeNaive = (4,)
# TopK: TopK (no softmax)
TopK = (5,)
# Unspecified
Unspecified = 6.0
@dataclass
class FusedMoEQuantDesc:
"""

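Since `RoutingMethodType` is an `IntEnum`, its members can be passed anywhere an `int` routing method is expected, for example the flashinfer block-scale MoE entry point changed later in this diff. A small hedged sketch, assuming the enum values shown above:

```python
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType

method = RoutingMethodType.DeepSeekV3      # Sigmoid -> bias add -> grouped top-k
assert isinstance(method, int) and method == 2
```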
View File

@ -215,7 +215,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N, K))
workspace1 = (M_sum, N)
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output)
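To quantify the fix above with a worked example under assumed shapes (M_sum aligned token count, intermediate size N, hidden size K): workspace1 previously reserved `M_sum * max(N, K)` elements even though only an activation of width `N` appears to be stored there, so whenever `K > N` the buffer was over-allocated. The sizes and element width below are hypothetical.

```python
# Hypothetical sizes for illustration only (not taken from the patch).
M_sum, N, K = 4096, 4096, 7168          # K > N triggers the over-allocation
bytes_per_elem = 2                      # assumed activation dtype width

before = M_sum * max(N, K) * bytes_per_elem   # old workspace1: 56.0 MiB
after = M_sum * N * bytes_per_elem            # new workspace1: 32.0 MiB
print(f"saved {(before - after) / 2**20:.1f} MiB per workspace1 buffer")
```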

View File

@ -3,6 +3,7 @@
import torch
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
calculate_tile_tokens_dim,
@ -23,26 +24,24 @@ def flashinfer_fused_moe_blockscale_fp8(
w2_weight_scale_inv: torch.Tensor,
global_num_experts: int,
top_k: int,
num_expert_group: int,
topk_group: int,
num_expert_group: int | None,
topk_group: int | None,
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routed_scaling: float = 1.0,
routing_method_type: int = RoutingMethodType.DeepSeekV3,
routed_scaling: float | None = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
topk_group = topk_group if topk_group is not None else 0
assert top_k <= global_num_experts
assert top_k <= 8
assert topk_group <= 4
assert global_num_experts > num_expert_group
assert global_num_experts % num_expert_group == 0
assert top_k <= 10
assert global_num_experts % 4 == 0
assert top_k < (topk_group * global_num_experts / num_expert_group)
assert block_shape == [128, 128]
# Routing kernel expects #experts <= #threads 256
assert global_num_experts <= 256
# Routing kernel expects #experts <= #threads 512
assert global_num_experts <= 512
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
# NOTE: scales of hidden states have to be transposed!
@ -64,10 +63,8 @@ def flashinfer_fused_moe_blockscale_fp8(
local_expert_offset=expert_offset,
local_num_experts=local_num_experts,
routed_scaling_factor=routed_scaling,
tile_tokens_dim=calculate_tile_tokens_dim(
x.shape[0], top_k, global_num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
tile_tokens_dim=None,
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)
@ -88,6 +85,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
expert_offset: int,
local_num_experts: int,
block_shape: list[int],
routing_method_type: int,
routed_scaling: float = 1.0,
) -> torch.Tensor:
return torch.empty_like(x)

View File

@ -14,6 +14,7 @@ import torch.nn.functional as F
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
@ -55,8 +56,6 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
@ -1089,11 +1088,11 @@ def vllm_topk_softmax(
return topk_weights, topk_indices
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
def dispatch_topk_func(
use_rocm_aiter: bool = False,
) -> Callable[..., tuple[torch.Tensor, ...]]:
if use_rocm_aiter:
return rocm_aiter_ops.topk_softmax
return vllm_topk_softmax
@ -1121,7 +1120,7 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
topk_func = dispatch_topk_func()
topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled())
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)

View File

@ -13,6 +13,7 @@ import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
@ -30,6 +31,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
biased_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton
@ -41,8 +43,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
init_aiter_topK_meta_data,
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
from vllm.model_executor.layers.quantization.base_config import (
@ -92,13 +92,11 @@ else:
return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk,
)
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk_aiter,
)
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
@ -463,7 +461,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@ -620,13 +619,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
# Padding the weight for better performance on ROCm
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
shuffle_weights,
)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -1002,6 +997,7 @@ def determine_expert_map(
global_num_experts: int,
expert_placement_strategy: ExpertPlacementStrategy = "linear",
num_fused_shared_experts: int = 0,
return_expert_mask: bool = False,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
"""
Calculates how many experts should be assigned to each rank for EP and
@ -1064,7 +1060,7 @@ def determine_expert_map(
)
expert_mask = None
if is_rocm_aiter_moe_enabled():
if return_expert_mask:
expert_mask = torch.ones(
(global_num_experts + num_fused_shared_experts + 1,), dtype=torch.int32
)
@ -1218,6 +1214,7 @@ class FusedMoE(CustomOp):
zero_expert_type: str | None = None,
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
):
super().__init__()
@ -1292,14 +1289,18 @@ class FusedMoE(CustomOp):
self.logical_replica_count: torch.Tensor | None = None
# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.aiter_fmoe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
self.num_fused_shared_experts = (
n_shared_experts
if n_shared_experts is not None
and is_rocm_aiter_fusion_shared_expert_enabled()
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
if (
not is_rocm_aiter_fusion_shared_expert_enabled()
not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0
):
raise ValueError(
@ -1346,6 +1347,7 @@ class FusedMoE(CustomOp):
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
@ -1397,6 +1399,24 @@ class FusedMoE(CustomOp):
"Only softmax scoring function is supported for non-grouped topk."
)
# ToDo: Better logic to determine the routing method type
if routing_method_type is not None:
self.routing_method_type = routing_method_type
else:
if scoring_func == "sigmoid":
if self.use_grouped_topk:
self.routing_method_type = RoutingMethodType.DeepSeekV3
elif self.top_k == 1:
self.routing_method_type = RoutingMethodType.Llama4
elif self.scoring_func == "softmax":
self.routing_method_type = (
RoutingMethodType.Renormalize
if not self.renormalize
else RoutingMethodType.RenormalizeNaive
)
else:
self.routing_method_type = RoutingMethodType.TopK
self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
@ -1570,13 +1590,16 @@ class FusedMoE(CustomOp):
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(), dp_size=get_dp_group().world_size
)
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
dp_size=get_dp_group().world_size,
)
def _load_per_tensor_weight_scale(
self,
@ -1753,20 +1776,19 @@ class FusedMoE(CustomOp):
def _init_aiter_shared_experts_topK_buffer(
self, vllm_config: VllmConfig, dp_size: int
):
if is_rocm_aiter_fusion_shared_expert_enabled():
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
if self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
n_shared_experts=self.num_fused_shared_experts,
top_k=self.top_k,
tp_rank=self.ep_rank if self.use_ep else self.tp_rank,
tp_size=self.ep_size if self.use_ep else self.tp_size,
shared_experts_score=1.0,
max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
* dp_size,
is_EP=self.use_ep,
)
self.local_num_experts += self.num_fused_shared_experts
@overload
def weight_loader(
@ -2208,15 +2230,16 @@ class FusedMoE(CustomOp):
elif use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if is_rocm_aiter_moe_enabled():
if not is_rocm_aiter_fusion_shared_expert_enabled():
if rocm_aiter_ops.is_fused_moe_enabled():
if not rocm_aiter_ops.is_fusion_moe_shared_experts_enabled():
assert num_fused_shared_experts == 0
grouped_topk_impl = partial(
grouped_topk_aiter,
rocm_aiter_grouped_topk,
num_fused_shared_experts=num_fused_shared_experts,
)
else:
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states,
gating_output=router_logits,
@ -2448,7 +2471,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
@ -2612,7 +2635,7 @@ class FusedMoE(CustomOp):
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map
if not is_rocm_aiter_moe_enabled()
if not self.rocm_aiter_fmoe_enabled
else self.expert_mask,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,

View File

@ -1,17 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache, lru_cache
from functools import lru_cache
import torch
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):
@ -37,27 +35,6 @@ class ActivationMethod(IntEnum):
GELU = 1
@cache
def is_rocm_aiter_moe_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_MOE
and envs.VLLM_ROCM_USE_AITER
)
@cache
def use_mxfp4_aiter_moe() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS and is_rocm_aiter_moe_enabled()
)
aiter_topK_meta_data = None
@ -114,250 +91,6 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)
def rocm_aiter_asm_moe_tkw1_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = ActivationType(activation_method)
return asm_moe_tkw1(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
per_tensor_quant_scale=per_tensor_quant_scale,
expert_mask=expert_mask,
activation=activation,
)
def rocm_aiter_asm_moe_tkw1_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: torch.Tensor | None = None,
fc2_scale: torch.Tensor | None = None,
fc1_smooth_scale: torch.Tensor | None = None,
fc2_smooth_scale: torch.Tensor | None = None,
a16: bool = False,
per_tensor_quant_scale: torch.Tensor | None = None,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_topk_softmax_impl(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
from aiter import topk_softmax
topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
def rocm_aiter_topk_softmax_fake(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> None:
pass
def rocm_aiter_biased_grouped_topk_impl(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import biased_grouped_topk
biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
routed_scaling_factor,
)
def rocm_aiter_biased_grouped_topk_fake(
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_grouped_topk_impl(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
from aiter import grouped_topk
grouped_topk(
gating_output,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
need_renorm,
scoring_func,
routed_scaling_factor,
)
def rocm_aiter_grouped_topk_fake(
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_expert_group: int,
topk_group: int,
need_renorm: bool,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, # mul to topk_weights
) -> None:
pass
def rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask,
activation,
quant_type,
doweight_stage1,
w1_scale,
w2_scale,
a1_scale,
a2_scale,
)
def rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: torch.Tensor | None = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
op_func=rocm_aiter_asm_moe_tkw1_impl,
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fused_moe_impl,
fake_impl=rocm_aiter_fused_moe_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_topk_softmax",
op_func=rocm_aiter_topk_softmax_impl,
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
fake_impl=rocm_aiter_topk_softmax_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_biased_grouped_topk",
op_func=rocm_aiter_biased_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_biased_grouped_topk_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_grouped_topk",
op_func=rocm_aiter_grouped_topk_impl,
mutates_args=["topk_weights", "topk_ids"],
fake_impl=rocm_aiter_grouped_topk_fake,
)
def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@ -372,7 +105,10 @@ def rocm_aiter_grouped_topk(
) -> tuple[torch.Tensor, torch.Tensor]:
token = hidden_states.shape[0]
device = hidden_states.device
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
assert aiter_topK_meta_data is not None, (
"AITER topK meta data is not initialized. "
"Please ensure that init_aiter_topK_meta_data "
@ -397,7 +133,7 @@ def rocm_aiter_grouped_topk(
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
if e_score_correction_bias is not None:
torch.ops.vllm.rocm_aiter_biased_grouped_topk(
rocm_aiter_ops.biased_grouped_topk(
gating_output,
e_score_correction_bias.to(gating_output.dtype),
topk_weights,
@ -409,7 +145,7 @@ def rocm_aiter_grouped_topk(
)
else:
assert scoring_func == "softmax" or scoring_func == "sigmoid"
torch.ops.vllm.rocm_aiter_grouped_topk(
rocm_aiter_ops.grouped_topk(
gating_output,
topk_weights,
topk_ids,
@ -420,7 +156,10 @@ def rocm_aiter_grouped_topk(
routed_scaling_factor=routed_scaling_factor,
)
if is_rocm_aiter_fusion_shared_expert_enabled() and num_fused_shared_experts > 0:
if (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
and num_fused_shared_experts > 0
):
return total_topk_weights, total_topk_ids
return topk_weights, topk_ids
@ -464,7 +203,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
return rocm_aiter_ops.asm_moe_tkw1(
hidden_states,
w1,
w2,
@ -482,7 +221,9 @@ def rocm_aiter_fused_experts(
else:
quant_method = QuantMethod.NO.value
# quark moe for mxfp4 w_dtype
if quant_config.use_mxfp4_w4a16:
quant_method = QuantMethod.BLOCK_1X32.value
# w8a8 block-scaled
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
assert not apply_router_weight_on_input, (
@ -507,7 +248,7 @@ def rocm_aiter_fused_experts(
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
return torch.ops.vllm.rocm_aiter_fused_moe(
return rocm_aiter_ops.fused_moe(
hidden_states,
w1,
w2,
@ -522,39 +263,3 @@ def rocm_aiter_fused_experts(
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
)
def rocm_aiter_topk_softmax(
topk_weights: torch.Tensor,
topk_indices: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_indices
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
"""
Applies the shuffle_weight function from AITER to each
input tensor and returns the results.
Rearranges (shuffles) the input tensor(s)
into the specified block layout for optimized computation.
Args:
*tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the block sizes used to divide
the tensors during shuffling. Default is (16, 16).
Returns:
A Tuple of shuffled tensors.
"""
from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
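This helper moves onto the rocm_aiter_ops facade in this change; a minimal before/after call-site sketch (signature and default (16, 16) layout assumed unchanged):
# Before: module-level helper in rocm_aiter_fused_moe.py
#   shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, layer.w2_weight.data)
# After: method on the ops facade, as used by the updated quantization methods below
#   shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(layer.w13_weight.data, layer.w2_weight.data)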

View File

@ -28,13 +28,18 @@ class SharedFusedMoE(FusedMoE):
super().__init__(**kwargs)
self._shared_experts = shared_experts
# Disable shared expert overlap if we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
# Disable shared expert overlap if we are using eplb, because of
# correctness issues, or if using flashinfer with DP, since there
# is nothing to be gained in this case. Disabling the overlap
# optimization also prevents the shared experts from being hidden
# from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
and not (
# TODO(wentao): find the root cause and remove this condition
self.enable_eplb
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)

View File

@ -6,18 +6,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER
def rms_norm(
@ -58,80 +53,34 @@ def fused_add_rms_norm(
return x, residual
def rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
def poly_norm(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
import aiter as rocm_aiter
from vllm import _custom_ops as ops
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return output, residual_out
return out
def rocm_aiter_rms_norm_fake(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
return torch.empty_like(x)
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
fake_impl=rocm_aiter_rms_norm_fake,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
def dispatch_rocm_rmsnorm_func(
with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
use_aiter = use_aiter and dtype in [
torch.float16,
torch.bfloat16,
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
return rocm_aiter_ops.rms_norm2d_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
return rocm_aiter_ops.rms_norm
# fall back to CUDA implementation
if with_fused_add:
@ -169,11 +118,14 @@ class RMSNorm(CustomOp):
self.weight = nn.Parameter(self.weight)
if current_platform.is_rocm():
aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype
with_fused_add=False,
dtype=weight_dtype,
use_aiter=aiter_rmsnorm_enabled,
)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype
with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
)
@staticmethod

View File

@ -12,6 +12,7 @@ from compressed_tensors.quantization import ActivationOrdering, QuantizationStra
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@ -582,11 +583,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# cutlass path
self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
@ -829,12 +827,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)

View File

@ -7,12 +7,12 @@ import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@ -61,7 +61,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
if self.weight_block_size is not None:
assert not self.is_static_input_scheme

View File

@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@ -27,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@ -56,7 +58,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
check_aiter_fp8_linear_support,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
@ -369,7 +370,7 @@ class Fp8LinearMethod(LinearMethodBase):
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
@ -869,12 +870,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
shuffle_weights,
)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
# TODO (rob): refactor block quant into separate class.
if self.block_quant:
@ -916,7 +913,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -962,7 +959,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if self.rocm_aiter_moe_enabled:
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@ -1042,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
start += shard_size
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight, layer.w2_weight
)
@ -1226,22 +1223,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
assert activation == "silu", (
f"Expected 'silu' activation but got {activation}"
)
assert scoring_func == "sigmoid", (
f"Expected 'sigmoid' scoring func but got {scoring_func}"
)
if self.block_quant:
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
assert (
renormalize and use_grouped_topk and custom_routing_function is None
)
e_score_correction_bias = (
e_score_correction_bias.to(x.dtype)
if e_score_correction_bias is not None
else None
)
routing_method_type = layer.routing_method_type
return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
routing_logits=router_logits.to(torch.float32),
routing_logits=router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits,
routing_bias=e_score_correction_bias,
x=x,
w13_weight=layer.w13_weight,
@ -1256,6 +1251,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
block_shape=self.weight_block_size,
routing_method_type=routing_method_type,
routed_scaling=routed_scaling_factor,
)
else:

View File

@ -4,54 +4,14 @@
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
fake_impl=rocm_aiter_gemm_w8a8_fake,
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
@ -75,7 +35,7 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+ "installed on ROCm.",
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER):
if not (rocm_aiter_ops.is_linear_enabled()):
return (
False,
"AiterScaledMMLinearKernel is disabled. "
@ -157,6 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
return torch.ops.vllm.rocm_aiter_gemm_w8a8(
x_q, w_q.t(), x_s, w_s, bias, out_dtype
)
return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
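A small layout sketch of why the weight is transposed here (illustrative shapes only; a plain matmul stands in for the AITER kernel):
import torch

M, K, N = 4, 512, 1024              # hypothetical sizes
x_q = torch.randn(M, K)             # activations, [M, K]
w_q = torch.randn(K, N)             # the kernel stores the weight as [K, N]
b_for_gemm = w_q.t()                # gemm_a8w8_CK expects B as [N, K]
ref = x_q @ w_q                     # reference result, [M, N]
assert b_for_gemm.shape == (N, K) and ref.shape == (M, N)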

View File

@ -8,6 +8,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@ -21,10 +22,6 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
use_mxfp4_aiter_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
)
@ -122,7 +119,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if current_platform.is_rocm():
self.use_marlin = False
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
def create_weights(
self,
@ -309,12 +306,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
# Property to determine if AITER is used
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
shuffle_weights,
)
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
@ -458,6 +451,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype
@ -469,13 +463,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"not implemented. Please open an issue."
)
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not (
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
)
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
@ -581,6 +577,17 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
if self.fp4_dtype is not None:
layer.w13_weight = torch.nn.Parameter(
layer.w13_weight.view(self.fp4_dtype),
requires_grad=layer.w13_weight.requires_grad,
)
layer.w2_weight = torch.nn.Parameter(
layer.w2_weight.view(self.fp4_dtype),
requires_grad=layer.w2_weight.requires_grad,
)
torch.cuda.empty_cache()
def get_fused_moe_quant_config(
@ -644,28 +651,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
if not self.emulate:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
aiter_acts = {
ActivationType.No.name.lower(): ActivationType.No,
ActivationType.Silu.name.lower(): ActivationType.Silu,
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
}
assert activation in aiter_acts, (
f"Aiter CK fp4 MoE doesn't support activation {activation}"
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
out = fused_moe(
out = rocm_aiter_fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=aiter_acts[activation],
doweight_stage1=False,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
quant_config=self.moe_quant_config,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts

View File

@ -31,6 +31,13 @@ from .quark_scheme import QuarkScheme
logger = init_logger(__name__)
# TODO: move registration of this custom op to _aiter_ops.py
# (`from vllm._aiter_ops import rocm_aiter_ops`).
# Use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
# for the env checks, which no longer require @cache.
# The Triton kernel is torch.compile compatible and does not
# require direct registration.
# Use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return (

View File

@ -27,20 +27,25 @@ class FlashinferMoeBackend(Enum):
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
from flashinfer import next_positive_power_of_2
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
# with the necessary kernels is released.
tile_tokens_dim = 8
# from flashinfer import next_positive_power_of_2
# # Guess tokens per expert assuming perfect expert distribution first.
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
# # And pad the number to the next power of 2.
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
# # kernel.
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
# A factor considering tokens are not perfectly balanced among experts.
imbalance_factor = 1.3
# Calculate the number of tokens per expert
# assuming perfect distribution.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
# Cap to 8-max_tile_tokens_dim tokens per CTA tile
# as it's the range supported by the kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
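A worked example of the restored heuristic (a sketch, with next_positive_power_of_2 stubbed so it runs without FlashInfer): 256 tokens, top_k=8 and 64 experts give 32 tokens per expert under perfect balance, 41 after the 1.3 imbalance factor, 64 after padding to a power of two, and the 8-64 clamp leaves it there.
def next_positive_power_of_2(x: int) -> int:  # stand-in for the FlashInfer helper
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def tile_tokens_dim_example(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Balanced estimate, scaled by the imbalance factor, padded and clamped as above.
    num_tokens_per_expert = int((num_tokens * top_k) // num_experts * 1.3)
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

assert tile_tokens_dim_example(256, 8, 64) == 64
assert tile_tokens_dim_example(16, 2, 128) == 8   # tiny batches clamp to the minimum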

View File

@ -12,6 +12,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@ -68,78 +69,6 @@ def cutlass_scaled_mm(
)
def rocm_aiter_gemm_w8a8_blockscale_impl(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake(
input_2d: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = input_2d.shape[0]
n = weight.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8_blockscale",
op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
@ -385,14 +314,40 @@ class W8A8BlockFp8LinearOp:
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
input_2d,
weight,
input_scale,
weight_scale,
self.act_quant_group_shape.col,
input_2d.dtype,
)
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
# MI350 case uses triton kernel
if (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
):
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
self.act_quant_group_shape.col,
column_major_scales=False,
use_ue8m0=False,
)
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
# MI300 uses tuned AITER ASM/C++ kernel
else:
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
return rocm_aiter_ops.gemm_w8a8_blockscale(
q_input,
weight,
input_scale,
weight_scale,
input_2d.dtype,
)
def _run_triton(
self,
@ -971,15 +926,6 @@ def requant_weight_ue8m0_inplace(
s_old.copy_(s_requant)
def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm for MI3XX"""
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
)
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
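A rough sketch of the optimization this docstring describes (the pad size and slicing here are placeholders, not necessarily what vLLM uses):
import torch
import torch.nn.functional as F

def pad_weight_sketch(weight: torch.Tensor, num_pad: int = 256) -> torch.Tensor:
    # Over-allocate the last dim, then slice back to the logical width so the
    # returned view keeps the original shape while the allocation is larger,
    # spacing tensors further apart in memory.
    padded = F.pad(weight, (0, num_pad), "constant", 0)
    return padded[..., :-num_pad]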

View File

@ -472,7 +472,7 @@ class Fp8LinearOp:
# Example:
# When the number of tokens is 1, the per-token scale is [[1]],
# while a per-tensor scale is [1] or ().
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
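A shape sketch of the distinction the comment above describes (illustrative tensors only): a per-tensor scale is 0-D or 1-D with a single element, while a per-token scale for a single token is 2-D, so numel() alone cannot disambiguate the activation case but suffices for weight scales.
import torch

per_tensor_0d = torch.tensor(2.0)       # shape (),   numel == 1, dim == 0
per_tensor_1d = torch.tensor([2.0])     # shape (1,), numel == 1, dim == 1
per_token_one = torch.tensor([[2.0]])   # shape (1, 1) for a batch of one token
assert per_tensor_0d.numel() == per_tensor_1d.numel() == 1
assert per_token_one.numel() == 1 and per_token_one.dim() == 2  # numel alone is ambiguous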

View File

@ -4,13 +4,10 @@
import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .rocm_aiter_rope_ops import (
is_rocm_triton_rotary_embedding_enabled,
rocm_aiter_rotary_emb,
)
@CustomOp.register("rotary_embedding")
@ -48,8 +45,8 @@ class RotaryEmbeddingBase(CustomOp):
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embedding_enabled = (
is_rocm_triton_rotary_embedding_enabled()
self.is_rocm_triton_rotary_embed_enabled = (
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
@ -169,9 +166,9 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embedding_enabled:
if self.is_rocm_triton_rotary_embed_enabled:
self._match_cos_sin_cache_dtype(query)
rocm_aiter_rotary_emb(
rocm_aiter_ops.triton_rotary_embed(
positions,
query,
key,

View File

@ -146,6 +146,15 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
key = key_rot
return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
positions: torch.Tensor,

View File

@ -1,94 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_TRITON_ROPE
)
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter.ops.triton.rope as ops
ops.rope_cached_thd_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass
if is_rocm_triton_rotary_embedding_enabled():
direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_rotary_emb(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
cos_sin_cache: torch.Tensor,
head_size: int,
rotary_dim: int,
is_neox_style: bool,
):
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
positions,
sin,
cos,
query_,
key_,
rotate_style,
False,
)
query = query.view(query_shape)
key = key.view(key_shape)

View File

@ -33,6 +33,7 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
@ -50,10 +51,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_fusion_shared_expert_enabled,
is_rocm_aiter_moe_enabled,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@ -294,10 +291,8 @@ class DeepseekV2MoE(nn.Module):
self.physical_expert_start + self.n_local_physical_experts
)
if (
config.n_shared_experts is None
or is_rocm_aiter_fusion_shared_expert_enabled()
):
self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
if config.n_shared_experts is None or self.is_rocm_aiter_moe_enabled:
self.shared_experts = None
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@ -330,14 +325,14 @@ class DeepseekV2MoE(nn.Module):
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
if not is_rocm_aiter_moe_enabled()
if not self.is_rocm_aiter_moe_enabled
else self.routed_scaling_factor,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
n_shared_experts=config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
else None,
)
@ -371,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
if hidden_states.dtype != torch.float16:
if not is_rocm_aiter_moe_enabled():
if not self.is_rocm_aiter_moe_enabled:
final_hidden_states *= self.routed_scaling_factor
elif self.shared_experts is not None:
assert shared_output is not None
@ -1428,6 +1423,9 @@ class DeepseekV2ForCausalLM(
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
)
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@ -1456,7 +1454,7 @@ class DeepseekV2ForCausalLM(
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
if rocm_aiter_moe_shared_expert_enabled
else 0
),
num_redundant_experts=self.num_redundant_experts,
@ -1472,9 +1470,8 @@ class DeepseekV2ForCausalLM(
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fuse_shared_experts_layer = (
is_rocm_aiter_fusion_shared_expert_enabled()
and ("mlp.shared_experts" in name)
is_fuse_shared_experts_layer = rocm_aiter_moe_shared_expert_enabled and (
"mlp.shared_experts" in name
)
for param_name, weight_name, shard_id in stacked_params_mapping:

View File

@ -1435,8 +1435,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
@ -1569,7 +1567,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta

View File

@ -1622,8 +1622,6 @@ class Glm4vForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1754,7 +1752,6 @@ class Glm4vForConditionalGeneration(
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta

View File

@ -625,8 +625,6 @@ class GLM4VForCausalLM(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
@ -758,7 +756,6 @@ class GLM4VForCausalLM(
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta

View File

@ -995,8 +995,6 @@ class SupportsMRoPE(Protocol):
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1012,8 +1010,6 @@ class SupportsMRoPE(Protocol):
image_grid_thw: Image grid dimensions (t, h, w)
video_grid_thw: Video grid dimensions (t, h, w)
second_per_grid_ts: Seconds per grid timestep for videos
context_len: Context length
seq_len: Sequence length
audio_feature_lengths: Audio feature lengths for multimodal models
use_audio_in_video: Whether to use audio in video for interleaving

View File

@ -1630,8 +1630,6 @@ class KeyeForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
@ -1759,6 +1757,5 @@ class KeyeForConditionalGeneration(
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -600,8 +600,6 @@ class KeyeVL1_5ForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
@ -729,6 +727,5 @@ class KeyeVL1_5ForConditionalGeneration(
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -1179,8 +1179,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1293,7 +1291,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -927,8 +927,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1125,7 +1123,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
mrope_position_delta = (
torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
)
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -1118,8 +1118,6 @@ class Qwen2_5_VLForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float],
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1232,7 +1230,6 @@ class Qwen2_5_VLForConditionalGeneration(
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -1240,8 +1240,6 @@ class Qwen2VLForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -1360,7 +1358,6 @@ class Qwen2VLForConditionalGeneration(
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta

View File

@ -43,6 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
self.gate = ReplicatedLinear(

View File

@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@ -1417,8 +1417,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:

View File

@ -1419,8 +1419,6 @@ class Qwen3VLForConditionalGeneration(
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
seq_len: int | None = None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
@ -1519,7 +1517,7 @@ class Qwen3VLForConditionalGeneration(
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions, mrope_position_delta
def get_language_model(self) -> torch.nn.Module:

View File

@ -371,8 +371,6 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
@ -390,7 +388,7 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
video_grid_thw=video_grid_thw,
)
mrope_positions = mrope_positions[:, 0, context_len:seq_len]
mrope_positions = mrope_positions[:, 0]
mrope_position_delta = mrope_position_delta[0].item()
return mrope_positions, mrope_position_delta

View File

@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)
@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
@ -228,19 +233,23 @@ class RocmPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled,
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
if is_aiter_mla_enabled() or block_size == 1
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)
@ -265,12 +274,12 @@ class RocmPlatform(Platform):
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (

View File

@ -198,6 +198,7 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
@ -270,28 +271,15 @@ except ImportError:
flashinfer_available = False
def is_rocm_aiter_fp8bmm_enabled() -> bool:
return (
current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER_FP8BMM
and envs.VLLM_ROCM_USE_AITER
)
if is_rocm_aiter_fp8bmm_enabled():
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501
)
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
def dynamic_per_batched_tensor_quant(
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
):
DTYPE_MAX = torch.finfo(dtype).max
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
scale = DTYPE_MAX / amax
x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
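A quick usage sketch of the helper above (tolerances are arbitrary; requires a PyTorch build with float8_e4m3fn support): quantize with a single per-tensor scale and check the dequantized round trip.
import torch

w = torch.randn(16, 128, 512, dtype=torch.bfloat16)
w_fp8, w_scale = dynamic_per_batched_tensor_quant(w, dtype=torch.float8_e4m3fn)
w_back = w_fp8.to(torch.float32) * w_scale  # returned scale is the reciprocal, so multiply to dequantize
assert torch.allclose(w_back, w.float(), atol=0.1, rtol=0.1)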
logger = init_logger(__name__)
@ -1109,6 +1097,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
@ -1158,7 +1147,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@ -1187,7 +1176,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@ -1196,7 +1185,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@ -1208,10 +1197,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm(
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
# Convert from (B, N, V) to (B, N * V)
@ -1571,7 +1559,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
[self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
W_K = W_UK.transpose(0, 1) # 16 512 128
W_V = W_UV.permute(1, 2, 0) # 16 128 512
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
@ -1600,7 +1588,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_K.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
)
@ -1609,7 +1597,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=torch.bfloat16,
device=self.W_V.device,
)
aiter_triton_fp8_bmm(
rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
)
else:
@ -1958,7 +1946,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
@ -1966,9 +1953,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
if is_rocm_aiter_fp8bmm_enabled():
if self.is_aiter_triton_fp8_bmm_enabled:
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(
decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
decode_q_nope,
self.W_K,
self.W_K_scale,

View File

@ -6,9 +6,8 @@ from typing import ClassVar
import torch
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
@ -22,10 +21,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
@ -284,7 +279,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(
rocm_aiter_ops.mla_decode_fwd(
q,
kv_buffer,
o,

View File

@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
logger = init_logger(__name__)
@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
continue
# Schedule newly needed KV blocks for the request.
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is not None:
# The request can be scheduled.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
if new_blocks is not None:
# The request can be scheduled.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[
preempted_req.request_id
]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
if new_blocks is None:
# Cannot schedule this request.
@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
)
)
)
# Construct the scheduler output.
new_reqs_data = [
@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
)
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
self._update_after_schedule(scheduler_output)
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
def _update_after_schedule(

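The scheduler changes above (and the engine changes below) only wrap existing logic in named scopes via record_function_or_nullcontext, imported from vllm.v1.utils. A minimal sketch of what a helper with that name plausibly does, assuming it picks between torch.profiler.record_function and a no-op context depending on a switch; the environment variable used as the switch here is illustrative, not vLLM's actual condition.

import contextlib
import os

import torch

_TRACE_SCOPES = os.environ.get("EXAMPLE_ENABLE_TRACE_SCOPES", "0") == "1"


def record_function_or_nullcontext_sketch(name: str):
    # When enabled, the scope shows up as a named CPU range in profiler traces;
    # otherwise it costs (almost) nothing.
    if _TRACE_SCOPES:
        return torch.profiler.record_function(name)
    return contextlib.nullcontext()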
View File

@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -315,17 +316,21 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
@ -363,32 +368,49 @@ class EngineCore:
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, so get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, so get any grammar
# output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
@ -408,27 +430,34 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed

View File

@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@ -280,28 +281,32 @@ class LLMEngine:
return []
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
with record_function_or_nullcontext("llm_engine step: get_output"):
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
with record_function_or_nullcontext("llm_engine step: process_outputs"):
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
with record_function_or_nullcontext("llm_engine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
with record_function_or_nullcontext("llm_engine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
return processed_outputs.request_outputs
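With those scopes in place, a standard torch profiler run is enough to see where a step spends its time. A hedged usage sketch; the engine construction is omitted and engine.step() stands in for whatever drives the loop (LLMEngine.step above), so treat the call as schematic.

from torch.profiler import ProfilerActivity, profile


def profile_engine_steps(engine, num_steps: int = 8) -> None:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(num_steps):
            engine.step()  # the "llm_engine step: ..." / "core step: ..." scopes land here
    # Inspect in chrome://tracing or Perfetto, or print an aggregate table.
    prof.export_chrome_trace("engine_steps_trace.json")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))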

View File

@ -111,6 +111,7 @@ class GuidanceGrammar(StructuredOutputGrammar):
vocab_size: int
printed_error: bool = False
terminated: bool = False
rollback_lag: int = 0
def check_error(self):
if not self.printed_error:
@ -127,6 +128,8 @@ class GuidanceGrammar(StructuredOutputGrammar):
"""
if self.ll_tokenizer.eos_token in tokens:
if self.ll_matcher.is_stopped() and not self.terminated:
self.rollback_lag = 1
self.terminated = True
if self.ll_matcher.is_stopped():
@ -163,8 +166,11 @@ class GuidanceGrammar(StructuredOutputGrammar):
return tokens[:num_tokens]
def rollback(self, num_tokens: int) -> None:
self.ll_matcher.rollback(num_tokens)
self.check_error()
if num_tokens > 0:
self.ll_matcher.rollback(num_tokens - self.rollback_lag)
self.terminated = False
self.rollback_lag = 0
self.check_error()
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
# this will automatically return [EOS] mask if the matcher is stopped

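The rollback change above accounts for an off-by-one: when EOS terminates the grammar, the matcher itself never advanced over that token, so a later rollback must skip it (rollback_lag). An illustrative toy follows, with the llguidance matcher replaced by a plain token list so the bookkeeping is visible; it is an analogy, not the real accept_tokens/rollback semantics.

EOS = -1  # stand-in for ll_tokenizer.eos_token


class MatcherStub:
    # Tracks only how many tokens were actually consumed.
    def __init__(self) -> None:
        self.tokens: list[int] = []

    def consume(self, toks: list[int]) -> None:
        self.tokens.extend(toks)

    def rollback(self, n: int) -> None:
        if n > len(self.tokens):
            raise ValueError("cannot roll back more tokens than were consumed")
        if n:
            self.tokens = self.tokens[:-n]


class GrammarSketch:
    def __init__(self) -> None:
        self.matcher = MatcherStub()
        self.terminated = False
        self.rollback_lag = 0

    def accept_tokens(self, tokens: list[int]) -> None:
        if EOS in tokens:
            # EOS terminates the grammar but is not pushed into the matcher,
            # so a later rollback over this step must skip one token.
            self.rollback_lag = 1
            self.terminated = True
            tokens = [t for t in tokens if t != EOS]
        self.matcher.consume(tokens)

    def rollback(self, num_tokens: int) -> None:
        if num_tokens > 0:
            self.matcher.rollback(num_tokens - self.rollback_lag)
            self.terminated = False
            self.rollback_lag = 0


g = GrammarSketch()
g.accept_tokens([7, 8, EOS])  # speculative step that drafted two tokens plus EOS
g.rollback(3)                 # reject all three; the matcher only ever saw two,
                              # so rolling back 3 without the lag would over-rollback
assert g.matcher.tokens == [] and not g.terminated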
View File

@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("Forward"),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
model_output = self._model_forward(
@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**model_kwargs,
)
with record_function_or_nullcontext("Postprocess"):
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.
hidden_states, aux_hidden_states = model_output
@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
sampled_token_ids,
@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
with record_function_or_nullcontext("Bookkeep"):
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
num_nans_in_logits,
logprobs_lists,
@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
with record_function_or_nullcontext("EPLB"):
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
if not self.use_async_scheduling:
return output
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"
):
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
return async_output
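For completeness, a small post-processing sketch that aggregates the renamed "gpu_model_runner: ..." scopes from a chrome trace exported as in the earlier profiling sketch. The file name and the traceEvents/dur field layout follow the chrome trace format; adjust to your setup.

import json
from collections import defaultdict


def summarize_model_runner_scopes(trace_path: str = "engine_steps_trace.json") -> None:
    with open(trace_path) as f:
        events = json.load(f).get("traceEvents", [])
    totals_us: dict[str, float] = defaultdict(float)
    for ev in events:
        name = ev.get("name", "")
        if name.startswith("gpu_model_runner: ") and "dur" in ev:
            totals_us[name] += ev["dur"]  # durations are in microseconds
    for name, dur in sorted(totals_us.items(), key=lambda kv: -kv[1]):
        print(f"{name:45s} {dur / 1000.0:10.2f} ms")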