# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

# `logger` is used in main() when remapping HIP_VISIBLE_DEVICES on ROCm; define
# it explicitly rather than relying on the wildcard import above to provide one.
logger = init_logger(__name__)

FP8_DTYPE = current_platform.fp8_dtype()
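

# Overview: this script benchmarks and (optionally) tunes the Triton fused-MoE
# kernel configurations used by vLLM. For each batch size it either measures the
# config vLLM would pick today, or (with --tune) sweeps a search space of Triton
# parameters and writes the best configs to a JSON file that vLLM can load.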


def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    )
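

# Keys of a single Triton kernel configuration. ROCm-specific keys such as
# "waves_per_eu", "matrix_instr_nonkdim" and "kpack" are added to the config
# dicts at runtime and are not part of this TypedDict.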


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int
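

# benchmark_config measures one kernel configuration: it builds random expert
# weights (fp16/bf16, fp8 or int8 depending on the flags), captures 10 fused-MoE
# invocations in a single CUDA graph, and reports the average latency per
# invocation in microseconds over `num_iters` graph replays.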


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        if use_fp8_w8a8:
            quant_dtype = torch.float8_e4m3fn
        elif use_int8_w8a16:
            quant_dtype = torch.int8
        else:
            quant_dtype = None

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=quant_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_quant_shape,
        )

        with override_config(config):
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=True,
                quant_config=quant_config,
                allow_deep_gemm=use_deep_gemm,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
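

# get_rocm_tuning_space returns the per-parameter candidate values searched on
# ROCm. Compared to the CUDA search space it adds AMD-specific Triton knobs
# (waves_per_eu, matrix_instr_nonkdim, kpack); for fp8 it drops BLOCK_K=16 and
# the fp16-only knobs.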


def get_rocm_tuning_space(use_fp16):
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges
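

# get_configs_compute_bound builds the full Cartesian product of the parameter
# ranges (ROCm or CUDA), then drops combinations that are incompatible with fp8
# block quantization (BLOCK_SIZE_N/K must be multiples of the weight block shape).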


def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    configs: list[BenchmarkConfig] = []

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        for config in configs[:]:
            if (
                config["BLOCK_SIZE_K"] % block_k != 0
                or config["BLOCK_SIZE_N"] % block_n != 0
            ):
                configs.remove(config)
    return configs
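

# The fused MoE layer runs two grouped GEMMs per token batch: one against w1
# (N = shard_intermediate_size, K = hidden_size) and one against w2
# (N = hidden_size, K = shard_intermediate_size // 2). prune_rocm_search_space
# prunes the candidate configs against both GEMM shapes and merges the survivors.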


def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(
        num_tokens * topk, N1, K1, search_space, is_fp16
    )
    pruned_space_2 = prune_rocm_configs(
        num_tokens * topk, N2, K2, search_space, is_fp16
    )
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly if the number of
        # elements per thread is less than 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compared to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and fp8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs
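

# Split-K heuristic borrowed from the ROCm GEMM tuning script referenced above:
# splitting K only pays off when M or N is small but K is large.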


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result
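

# Each BenchmarkWorker is a Ray actor pinned to one GPU. benchmark() measures the
# config that vLLM would currently pick for the given shape (tuned file or
# default), while tune() sweeps a search space and returns the fastest config.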


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
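

# sort_config rewrites a config with its keys in a fixed order (with the
# ROCm-only keys appended when present) so the JSON written by save_configs has
# a consistent key order.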


def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **(
            {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
        ),
        **(
            {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
            if "matrix_instr_nonkdim" in config
            else {}
        ),
        **({"kpack": config["kpack"]} if "kpack" in config else {}),
    }
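

# save_configs writes the tuned configs as a JSON object keyed by batch size,
# plus the Triton version used for tuning, under the filename that vLLM's fused
# MoE kernel looks up at runtime (derived from E, the post-silu_and_mul
# intermediate size, the dtype string, and the block-quant shape).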


def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    dtype_str = _get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")
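

# Block-quantized checkpoints (e.g. fp8 models with block-wise weight scales)
# advertise their block shape via quantization_config["weight_block_size"];
# for other models this falls back to the provided default.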


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value
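

# main() reads the model's HF config to derive the expert count (E), top-k and
# intermediate size, shards the intermediate size for tensor/expert parallelism,
# then fans the per-batch-size work out to one Ray worker per visible GPU,
# either tuning (--tune) or benchmarking the existing configs.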


def main(args: argparse.Namespace):
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif config.architectures[0] in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
    elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM",):
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")
        if use_deep_gemm:
            raise ValueError(
                "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
                "kernels. Please remove the flag."
            )
        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")
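

# Example invocations (illustrative only; the script name and model are
# placeholders, adjust them to your setup):
#   Tune fp8 MoE configs for Mixtral with 2-way tensor parallelism:
#     python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
#         --tp-size 2 --dtype fp8_w8a8 --tune
#   Benchmark the configs vLLM currently ships instead of tuning:
#     python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2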
if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    args = parser.parse_args()

    main(args)