mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@ -80,6 +80,11 @@ def bench_run(
|
||||
a, score, topk, renormalize=False
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -111,6 +116,10 @@ def bench_run(
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
@ -125,6 +134,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -136,6 +149,10 @@ def bench_run(
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
@ -150,6 +167,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -194,6 +215,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
)
|
||||
@ -231,6 +256,10 @@ def bench_run(
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
"ab_strides1": ab_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides1": c_strides1,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -289,6 +318,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
@ -297,7 +330,7 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
@ -160,6 +160,30 @@ __global__ void shuffleInputRowsKernel(const T* input,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void shuffleInputRowsKernelSlow(const T* input,
|
||||
const int32_t* dst2src_map,
|
||||
T* output, int64_t num_src_rows,
|
||||
int64_t num_dst_rows,
|
||||
int64_t num_cols) {
|
||||
int64_t dest_row_idx = blockIdx.x;
|
||||
int64_t const source_row_idx = dst2src_map[dest_row_idx];
|
||||
|
||||
if (blockIdx.x < num_dst_rows) {
|
||||
// Duplicate and permute rows
|
||||
auto const* source_row_ptr = input + source_row_idx * num_cols;
|
||||
auto* dest_row_ptr = output + dest_row_idx * num_cols;
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
|
||||
for (int elem_index = start_offset; elem_index < num_cols;
|
||||
elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
const torch::Tensor& dst2src_map,
|
||||
torch::Tensor& output_tensor) {
|
||||
@ -173,17 +197,24 @@ void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
int64_t const num_src_rows = input_tensor.size(0);
|
||||
int64_t const num_cols = input_tensor.size(1);
|
||||
|
||||
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
|
||||
"num_cols must be divisible by 128 / "
|
||||
"sizeof(input_tensor.scalar_type()) / 8");
|
||||
|
||||
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
|
||||
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
|
||||
dst2src_map.data_ptr<int32_t>(),
|
||||
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
|
||||
num_dest_rows, num_cols);
|
||||
});
|
||||
if (num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)) {
|
||||
// use slow kernel if num_cols can't be aligned to 128 bits
|
||||
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
|
||||
shuffleInputRowsKernelSlow<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
|
||||
dst2src_map.data_ptr<int32_t>(),
|
||||
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
|
||||
num_dest_rows, num_cols);
|
||||
});
|
||||
} else {
|
||||
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
|
||||
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
|
||||
dst2src_map.data_ptr<int32_t>(),
|
||||
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
|
||||
num_dest_rows, num_cols);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
@ -206,6 +206,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
}
|
||||
@ -439,6 +443,11 @@ def test_run_cutlass_moe_fp8(
|
||||
expert_map[start:end] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
|
||||
torch.float8_e4m3fn,
|
||||
@ -447,8 +456,9 @@ def test_run_cutlass_moe_fp8(
|
||||
func = lambda output: run_cutlass_moe_fp8(
|
||||
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
|
||||
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
|
||||
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
|
||||
per_act_token, per_out_channel, False)
|
||||
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
|
||||
workspace13, workspace2, None, mt.a.dtype, per_act_token,
|
||||
per_out_channel, False)
|
||||
|
||||
workspace13.random_()
|
||||
output_random_workspace = torch.empty(output_shape,
|
||||
|
@ -75,6 +75,7 @@ def pplx_cutlass_moe(
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
intermediate_dim = w2.shape[2]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = hidden_dim # TODO support more cases
|
||||
device = pgi.device
|
||||
@ -123,10 +124,31 @@ def pplx_cutlass_moe(
|
||||
num_local_experts=num_local_experts,
|
||||
num_dispatchers=num_dispatchers)
|
||||
|
||||
ab_strides1 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_local_experts, ),
|
||||
intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_local_experts, ),
|
||||
2 * intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
experts = CutlassExpertsFp8(num_local_experts,
|
||||
out_dtype,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_batched_format=True)
|
||||
|
||||
|
@ -13,8 +13,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
|
||||
problem_sizes1, problem_sizes2, a_map,
|
||||
c_map, global_num_experts, N, K)
|
||||
|
||||
a1q = _fp8_perm(a1q, a_map)
|
||||
a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
|
||||
a1q = ops.shuffle_rows(a1q, a_map)
|
||||
a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
|
||||
if per_act_token else a1q_scale)
|
||||
expert_offsets = expert_offsets[:-1]
|
||||
|
||||
ab_strides1 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((w1.size(0), ),
|
||||
2 * N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((w1.size(0), ),
|
||||
N,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((w1.size(0), ),
|
||||
K,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
if use_batched_format:
|
||||
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
|
||||
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
|
||||
@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
|
||||
else:
|
||||
# We can't do this inplace because output may point to the same tensor
|
||||
# as c3.
|
||||
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
|
||||
output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
|
||||
non_blocking=True)
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
@ -222,6 +210,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
num_dispatchers: Optional[int] = None,
|
||||
use_batched_format: bool = False,
|
||||
@ -238,6 +230,10 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides2 = ab_strides2
|
||||
self.c_strides1 = c_strides1
|
||||
self.c_strides2 = c_strides2
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@property
|
||||
@ -316,7 +312,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
|
||||
self.c_strides2, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
self.use_batched_format)
|
||||
@ -330,6 +327,10 @@ def cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
per_act_token: Optional[bool] = None,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
@ -357,6 +358,17 @@ def cutlass_moe_fp8(
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides1 (torch.Tensor): The output strides for the first gemm.
|
||||
Shape: [num_experts]
|
||||
- c_strides2 (torch.Tensor): The output strides for the second gemm.
|
||||
Shape: [num_experts]
|
||||
- per_act_token (Optional[bool]): Whether the scale is per-token or
|
||||
per-tensor.
|
||||
- activation (str): The activation function to use.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
@ -389,6 +401,10 @@ def cutlass_moe_fp8(
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
|
@ -859,6 +859,21 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
device = layer.w13_weight.device
|
||||
# ab_strides1 and c_strides2 are the same
|
||||
self.ab_strides1_c_strides2 = torch.full((layer.local_num_experts, ),
|
||||
layer.hidden_size,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
self.ab_strides2 = torch.full((layer.local_num_experts, ),
|
||||
layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
self.c_strides1 = torch.full((layer.local_num_experts, ),
|
||||
2 * layer.intermediate_size_per_partition,
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
@ -881,6 +896,10 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_batched_format=use_batched_format,
|
||||
)
|
||||
@ -927,7 +946,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
per_act_token = (
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||
@ -948,6 +968,10 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
Reference in New Issue
Block a user