[Misc] Add unit tests for MoE ModularKernel combinations + Profiling utility (#20449)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Commit 53fa457391 (parent 6fb162447b), committed via GitHub.
tests/kernels/moe/modular_kernel_tools/__init__.py (new file, 0 lines)
tests/kernels/moe/modular_kernel_tools/cli_args.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

from .common import Config
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
                         MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)


def make_config_arg_parser(description: str):

    def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
        for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
            if pf.__name__ == s:
                return pf
        raise ValueError(
            f"Cannot find a PrepareFinalize type that matches {s}")

    def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
        for fe in MK_FUSED_EXPERT_TYPES:
            if fe.__name__ == s:
                return fe
        raise ValueError(f"Cannot find a FusedExperts type that matches {s}")

    def to_quant_torch_dtype(s: str) -> torch.dtype:
        if s == "torch.float8_e4m3fn":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported quant type {s}")

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument(
        "--world-size",
        type=int,
        default=2,
        help="Number of ranks that participate in all2all",
    )
    parser.add_argument(
        "--pf-type",
        type=to_pf_class_type,
        required=True,
        help=("Choose a PrepareFinalize type: "
              f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
    )
    parser.add_argument(
        "--experts-type",
        type=to_experts_class_type,
        required=True,
        help=("Choose a FusedExperts type: "
              f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
    )
    parser.add_argument(
        "-m",
        nargs="+",
        type=int,
        default=[64],
        help="num tokens per rank",
    )
    parser.add_argument(
        "-k",
        type=int,
        default=7168,
        help="hidden-size",
    )
    parser.add_argument(
        "-n",
        type=int,
        default=1024,
        help="N dimension of the first fused-moe matmul",
    )
    parser.add_argument("--num-experts",
                        type=int,
                        default=32,
                        help="Global num experts")
    parser.add_argument("--topk",
                        nargs="+",
                        type=int,
                        default=[4, 1],
                        help="num topk")
    parser.add_argument(
        "--fused-moe-chunk-size",
        nargs="+",
        type=int,
        help="Fused moe chunk size used for the non-batched fused experts impl."
    )

    # Quant args
    parser.add_argument("--quant-dtype",
                        type=to_quant_torch_dtype,
                        help="Quant datatype")
    parser.add_argument("--per-token-quantized-activations",
                        action='store_true',
                        help=("The input activations must be per-token "
                              "quantized"))
    parser.add_argument("--per-channel-quantized-weights",
                        action="store_true",
                        help="The weights must be per-channel quantized.")
    parser.add_argument("--block-shape",
                        nargs="+",
                        type=int,
                        help="Quantization block shape")

    # Torch trace profile generation args
    parser.add_argument("--torch-trace-dir-path",
                        type=str,
                        default=None,
                        help="Get torch trace for single execution")

    return parser


def _validate_args(args: argparse.Namespace):

    if args.quant_dtype is not None:
        assert args.quant_dtype == torch.float8_e4m3fn
        if args.block_shape is not None:
            assert len(args.block_shape) == 2, (
                f"block shape must have 2 elements. got {args.block_shape}")

    if args.pf_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
        assert args.world_size == 1, (
            "Single GPU objects need world size set to 1")

    if args.torch_trace_dir_path is not None:
        from pathlib import Path
        assert Path(args.torch_trace_dir_path).is_dir(), (
            f"Please create {args.torch_trace_dir_path}")


def make_config(args: argparse.Namespace) -> Config:

    _validate_args(args)

    quant_config = None
    if args.quant_dtype is not None:
        quant_config = FusedMoEQuantConfig(
            quant_dtype=args.quant_dtype,
            per_act_token_quant=args.per_token_quantized_activations,
            per_out_ch_quant=args.per_channel_quantized_weights,
            block_shape=args.block_shape)

    return Config(
        Ms=args.m,
        K=args.k,
        N=args.n,
        E=args.num_experts,
        topks=args.topk,
        dtype=torch.bfloat16,  # hard-code
        quant_config=quant_config,
        prepare_finalize_type=args.pf_type,
        fused_experts_type=args.experts_type,
        fused_moe_chunk_size=args.fused_moe_chunk_size,
        world_size=args.world_size,
        torch_trace_dir_path=args.torch_trace_dir_path)
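For reference, a minimal usage sketch (not part of this commit) showing how make_config_arg_parser and make_config compose. It assumes vLLM is installed and the repository root is on PYTHONPATH so the test package is importable:

from tests.kernels.moe.modular_kernel_tools.cli_args import (
    make_config, make_config_arg_parser)

parser = make_config_arg_parser(description="demo")
# Single-GPU combination, so --world-size must be 1 (see _validate_args).
args = parser.parse_args([
    "--world-size", "1",
    "--pf-type", "MoEPrepareAndFinalizeNoEP",
    "--experts-type", "TritonExperts",
])
config = make_config(args)  # returns a validated Config (see common.py)
print(config.describe())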
tests/kernels/moe/modular_kernel_tools/common.py (new file, 641 lines)
@@ -0,0 +1,641 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch

import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
                                                        TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
                    make_quant_fp8_weights, per_token_cast_to_fp8)

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)


def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
    if t is None:
        return f"{name} : None"
    else:
        return f"{name} : {t.shape} {t.dtype} {t.device}"


@dataclass
class Config:
    Ms: Union[list[int], int]
    K: int
    N: int
    E: int
    topks: Union[list[int], int]
    dtype: torch.dtype
    quant_config: Optional[FusedMoEQuantConfig]

    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute

    fused_moe_chunk_size: Optional[int]
    world_size: int

    torch_trace_dir_path: Optional[str] = None

    def describe(self) -> str:
        s = ""
        s += "== Config: \n"
        s += f" world_size={self.world_size} \n"
        s += f" PF={self.prepare_finalize_type.__name__} \n"
        s += f" FE={self.fused_experts_type.__name__} \n"
        s += f" topk={self.topks} \n"
        s += f" dtype={self.dtype} \n"
        s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
        s += " Quant: \n"
        if self.quant_config is not None:
            s += f"   q_dtype={self.quant_dtype} \n"
            s += f"   q_block_shape={self.quant_block_shape} \n"
            s += f"   q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
            s += f"   q_per_act_token={self.is_per_act_token_quant} \n"
        else:
            s += "   quant=None \n"
        return s

    @property
    def M(self) -> int:
        assert isinstance(self.Ms, int)
        return self.Ms

    @property
    def quant_dtype(self) -> Optional[torch.dtype]:
        if self.quant_config is None:
            return None
        return self.quant_config.quant_dtype

    @property
    def is_per_act_token_quant(self) -> bool:
        if self.quant_config is None:
            return False
        return self.quant_config.per_act_token_quant

    @property
    def is_per_tensor_act_quant(self) -> bool:
        if self.quant_config is None:
            return False
        return (not self.is_per_act_token_quant
                and self.quant_block_shape is None)

    @property
    def is_per_out_ch_quant(self) -> bool:
        if self.quant_config is None:
            return False
        return self.quant_config.per_out_ch_quant

    @property
    def quant_block_shape(self) -> Optional[list[int]]:
        if self.quant_config is None:
            return None
        return self.quant_config.block_shape

    @property
    def topk(self) -> int:
        assert isinstance(self.topks, int)
        return self.topks

    @property
    def topk_ids_dtype(self) -> Optional[torch.dtype]:
        topk_ids_dtype = None
        if self.prepare_finalize_type == PplxPrepareAndFinalize:
            topk_ids_dtype = torch.uint32
        elif self.prepare_finalize_type in [
                DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]:
            topk_ids_dtype = torch.int64
        return topk_ids_dtype

    @property
    def num_local_experts(self) -> int:
        return self.E // self.world_size

    def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
        """
        Make env data for vllm launch.
        """
        vllm_config = VllmConfig()
        vllm_config.parallel_config.data_parallel_size = self.world_size
        vllm_config.parallel_config.enable_expert_parallel = True

        env_dict = {
            "VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
            "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
        }
        if self.fused_moe_chunk_size is not None:
            env_dict.update(
                {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
        return vllm_config, env_dict

    def is_fp8_block_quantized(self):
        return (self.quant_dtype == torch.float8_e4m3fn
                and self.quant_block_shape is not None)

    def is_batched_prepare_finalize(self):
        return self.prepare_finalize_type in [
            PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]

    def is_batched_fused_experts(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
            NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
        ]

    def is_standard_fused_experts(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def is_fe_16bit_supported(self):
        return self.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            NaiveBatchedExperts, TritonExperts
        ]

    def is_fe_fp8_supported(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonExperts,
            BatchedTritonOrDeepGemmExperts,
            CutlassExpertsFp8,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_block_fp8_supported(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            BatchedTritonOrDeepGemmExperts,
            DeepGemmExperts,
            TritonExperts,
            TritonOrDeepGemmExperts,
            BatchedTritonExperts,
            NaiveBatchedExperts,
        ]

    def is_fe_supports_chunking(self):
        return self.fused_experts_type in [
            CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
            TritonExperts
        ]

    def needs_deep_gemm(self):
        return self.fused_experts_type in [
            BatchedDeepGemmExperts,
            DeepGemmExperts,
        ]

    def needs_pplx(self):
        return self.prepare_finalize_type in [PplxPrepareAndFinalize]

    def needs_deep_ep(self):
        return self.prepare_finalize_type in [
            DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
        ]

    def all2all_backend(self):
        if self.needs_pplx():
            return "pplx"
        if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
            return "deepep_high_throughput"
        if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
            return "deepep_low_latency"
        return "naive"

    def needs_all2all(self):
        return self.prepare_finalize_type in [
            PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
            DeepEPLLPrepareAndFinalize
        ]

    def is_valid(self):
        # Check prepare-finalize and fused-experts compatibility
        if self.is_batched_prepare_finalize():
            if not self.is_batched_fused_experts():
                return False
        else:
            if not self.is_standard_fused_experts():
                return False

        use_chunking = self.fused_moe_chunk_size is not None
        if use_chunking and not self.is_fe_supports_chunking():
            return False

        # Check quantization sanity
        if (int(self.is_per_act_token_quant) +
                int(self.is_per_tensor_act_quant) +
                int(self.quant_block_shape is not None)) > 1:
            # invalid quant config
            return False

        # Check bf16 / fp16 support
        is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
        if is_16bit and not self.is_fe_16bit_supported():
            return False

        # Check fp8 support
        is_fp8 = self.quant_dtype == torch.float8_e4m3fn
        if is_fp8 and not self.is_fe_fp8_supported():
            return False

        # Check fp8 block quantization support
        is_block_quantized = self.quant_block_shape is not None
        if is_block_quantized and not is_fp8:
            return False
        if is_block_quantized and not self.is_fe_block_fp8_supported():
            return False

        # deep_gemm only works with block-quantized weights
        if self.needs_deep_gemm() and not is_block_quantized:
            return False

        # Check dependencies
        if self.needs_deep_ep() and not has_deep_ep():
            return False
        if self.needs_deep_gemm() and not has_deep_gemm():
            return False
        if self.needs_pplx() and not has_pplx():  # noqa: SIM103
            return False

        return True


@dataclass
class WeightTensors:
    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: Optional[torch.Tensor]
    w2_scale: Optional[torch.Tensor]

    def describe(self):
        s = ""
        s += "== Weight Tensors: \n"
        s += f' - {_describe_tensor(self.w1, "w1")} \n'
        s += f' - {_describe_tensor(self.w2, "w2")} \n'
        s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
        s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
        return s

    def to_current_device(self):
        self.w1 = self.w1.to(device=torch.cuda.current_device())
        self.w2 = self.w2.to(device=torch.cuda.current_device())
        is_quantized = self.w1.dtype == torch.float8_e4m3fn
        if is_quantized:
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            self.w1_scale = self.w1_scale.to(
                device=torch.cuda.current_device())
            self.w2_scale = self.w2_scale.to(
                device=torch.cuda.current_device())

    def slice_weights(self, rank: int,
                      num_local_experts: int) -> "WeightTensors":
        s = rank * num_local_experts
        e = s + num_local_experts
        w1 = self.w1[s:e, :, :]
        w2 = self.w2[s:e, :, :]
        is_quantized = self.w1.dtype == torch.float8_e4m3fn
        w1_scale, w2_scale = (None, None)
        if is_quantized:
            assert self.w1_scale is not None
            assert self.w2_scale is not None
            w1_scale = self.w1_scale[s:e, :, :]
            w2_scale = self.w2_scale[s:e, :, :]
        return WeightTensors(w1, w2, w1_scale, w2_scale)

    @staticmethod
    def make(config: Config) -> "WeightTensors":

        if config.quant_dtype is None:
            # just make normal dtype weights
            w1, w2 = make_non_quant_weights(e=config.E,
                                            n=config.N,
                                            k=config.K,
                                            dtype=config.dtype)
            return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)

        assert config.quant_dtype == torch.float8_e4m3fn
        if not config.is_fp8_block_quantized():
            w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
                e=config.E,
                n=config.N,
                k=config.K,
                per_out_channel_quant=config.is_per_out_ch_quant,
            )
            return WeightTensors(w1=w1,
                                 w2=w2,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale)

        assert config.quant_block_shape is not None
        w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
            e=config.E,
            n=config.N,
            k=config.K,
            block_size=config.quant_block_shape,
        )
        return WeightTensors(w1=w1,
                             w2=w2,
                             w1_scale=w1_scale,
                             w2_scale=w2_scale)


@dataclass
class RankTensors:
    hidden_states: torch.Tensor
    hidden_states_scale: Optional[torch.Tensor]

    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor]

    quant_config: Optional[FusedMoEQuantConfig]

    def describe(self):
        s = ""
        s += "== Rank Tensors: \n"
        s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
        s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
        s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
        s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
        s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
        return s

    @staticmethod
    def make_hidden_states(
            config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Return hidden_states
        """
        m, k, dtype = (config.M, config.K, config.dtype)
        a = (torch.randn(
            (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)

        if config.quant_dtype is None:
            return a, None

        # We dequant and use that as hidden_states so the tests are stable.
        # Quantizing and dequantizing yield slightly different results
        # depending on the hardware. Here we quantize and dequantize
        # first - so further quantize and dequantize operations will yield
        # the same values.
        if config.is_per_tensor_act_quant:
            a_q, a_scales = ops.scaled_fp8_quant(
                a, use_per_token_if_dynamic=False)
            return a_q.float().mul(a_scales).to(dtype), a_scales

        if config.is_per_act_token_quant:
            a_q, a_scales = ops.scaled_fp8_quant(a,
                                                 use_per_token_if_dynamic=True)
            return a_q.float().mul(a_scales).to(dtype), None

        assert config.quant_block_shape is not None
        block_k = config.quant_block_shape[1]
        a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
        return a_q.float().view(
            (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None

    @staticmethod
    def make(config: Config, pgi: ProcessGroupInfo):

        dtype = config.dtype
        topk, m, _ = (config.topk, config.M, config.K)
        hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
            config)

        num_local_experts, global_num_experts = (config.num_local_experts,
                                                 config.E)
        score = torch.randn((m, global_num_experts),
                            device="cuda",
                            dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
                                               False)
        topk_ids = topk_ids.to(config.topk_ids_dtype)

        # distribute topk_ids evenly
        for mi in range(m):
            topk_ids[mi] = torch.randperm(config.E)[:topk]
        topk_ids = topk_ids.to(device=torch.cuda.current_device())

        expert_map = None
        if config.world_size > 1:
            expert_map = torch.full((global_num_experts, ),
                                    fill_value=-1,
                                    dtype=torch.int32)
            s = pgi.rank * num_local_experts
            e = s + num_local_experts
            expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
            expert_map = expert_map.to(device=torch.cuda.current_device(),
                                       dtype=torch.int32)

        return RankTensors(
            hidden_states=hidden_states,
            hidden_states_scale=hidden_states_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_config=config.quant_config,
        )


def reference_moe_impl(config: Config, weights: WeightTensors,
                       rank_tensors: RankTensors) -> torch.Tensor:

    return torch_experts(a=rank_tensors.hidden_states,
                         w1=weights.w1,
                         w2=weights.w2,
                         topk_weight=rank_tensors.topk_weights,
                         topk_ids=rank_tensors.topk_ids,
                         global_num_experts=config.E,
                         expert_map=None,
                         w1_scale=weights.w1_scale,
                         w2_scale=weights.w2_scale,
                         a1_scale=rank_tensors.hidden_states_scale,
                         quant_dtype=config.quant_dtype,
                         per_act_token_quant=config.is_per_act_token_quant,
                         block_shape=config.quant_block_shape,
                         apply_router_weights_on_input=config.topk == 1)


def make_fused_experts(
        config: Config, moe: FusedMoEConfig,
        num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:

    use_fp8 = config.quant_dtype == torch.float8_e4m3fn
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": use_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": config.quant_block_shape,
        "per_act_token_quant": config.is_per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if config.fused_experts_type == BatchedDeepGemmExperts:
        kwargs = batch_kwargs | {
            "block_shape": config.quant_block_shape,
            "per_act_token_quant": config.is_per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        experts = BatchedDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making BatchedTritonExperts {kwargs} ...")
        experts = BatchedTritonExperts(**kwargs)
    elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        experts = DeepGemmExperts()
    elif config.fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        experts = TritonExperts(**kwargs)
    elif config.fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = quant_kwargs | deepgemm_kwargs
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        experts = TritonOrDeepGemmExperts(**kwargs)
    elif config.fused_experts_type == NaiveBatchedExperts:
        kwargs = batch_kwargs | quant_kwargs
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        experts = NaiveBatchedExperts(**kwargs)
    elif config.fused_experts_type == CutlassExpertsFp8:
        use_batched_format = config.is_batched_prepare_finalize()
        num_experts = (moe.num_local_experts
                       if use_batched_format else moe.num_experts)
        kwargs = {
            "max_experts_per_worker": num_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": config.is_per_act_token_quant,
            "per_out_ch_quant": config.is_per_out_ch_quant,
            "block_shape": config.quant_block_shape,
            "num_dispatchers": num_dispatchers,
            "use_batched_format": use_batched_format
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        experts = CutlassExpertsFp8(**kwargs)

    return experts


def make_modular_kernel(config: Config,
                        vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:

    def next_power_of_2(x):
        import math
        if x == 0:
            return 1
        return 2**math.ceil(math.log2(x))

    # make moe config
    moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
        tp_size_=get_tensor_model_parallel_world_size(),
        dp_size_=get_dp_group().world_size,
        vllm_parallel_config=vllm_config.parallel_config,
    )
    moe = FusedMoEConfig(
        num_experts=config.E,
        experts_per_token=config.topk,
        hidden_dim=config.K,
        num_local_experts=config.num_local_experts,
        moe_parallel_config=moe_parallel_config,
        in_dtype=config.dtype,
        quant_config=config.quant_config,
        max_num_tokens=next_power_of_2(config.M),
    )

    # make modular kernel
    prepare_finalize = None
    if config.needs_all2all():
        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
        assert prepare_finalize is not None
    else:
        prepare_finalize = MoEPrepareAndFinalizeNoEP()

    fused_experts = make_fused_experts(config, moe,
                                       prepare_finalize.num_dispatchers())

    modular_kernel = mk.FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=fused_experts)

    return modular_kernel


def run_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> torch.Tensor:
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    mk = make_modular_kernel(config, vllm_config)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states.clone(
        ),  # impls might update the tensor in place
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": rank_weights.w1_scale,
        "w2_scale": rank_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1,
    }
    out = mk.forward(**mk_kwargs)

    return out
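A short illustrative sketch (not part of this commit): constructing a Config directly and filtering with is_valid() before doing any work. Note that is_valid() and topk_ids_dtype reference the pplx / DeepEP classes unconditionally, so this assumes those optional packages are importable.

import torch

from tests.kernels.moe.modular_kernel_tools.common import Config
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

cfg = Config(Ms=[32], K=7168, N=2048, E=32, topks=[4],
             dtype=torch.bfloat16, quant_config=None,
             prepare_finalize_type=MoEPrepareAndFinalizeNoEP,
             fused_experts_type=TritonExperts,
             fused_moe_chunk_size=None, world_size=1)
# bf16 TritonExperts with the no-EP prepare/finalize is a valid pairing.
assert cfg.is_valid()
print(cfg.describe())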
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py (new file, 173 lines)
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from enum import Enum
from itertools import product
from typing import Optional

import torch
from tqdm import tqdm

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform

from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
                     run_modular_kernel)
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
                         MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config


class Result(Enum):
    PASS = 1
    FAIL = 2
    SKIP = 3


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)

        # modular kernel out
        mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                    rank_tensors)

        with set_current_vllm_config(vllm_config):
            ref_out = reference_moe_impl(cfgx, weights, rank_tensors)

        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)


def make_feature_matrix(csv_file_path: str):

    from dataclasses import asdict

    import pandas as pd

    def add_to_results(config: Config,
                       success: Result,
                       results_df: Optional[pd.DataFrame] = None):
        config_dict = asdict(config)
        config_dict['prepare_finalize_type'] = config_dict[
            'prepare_finalize_type'].__name__
        config_dict['fused_experts_type'] = config_dict[
            'fused_experts_type'].__name__
        config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
        quant_config_dict = config_dict['quant_config']
        del config_dict['quant_config']
        if quant_config_dict is None:
            quant_config = FusedMoEQuantConfig(None)
            quant_config_dict = asdict(quant_config)

        config_dict |= quant_config_dict
        result_dict = config_dict | {'success': success.name}

        result_df = pd.DataFrame([result_dict])
        if results_df is None:
            results_df = result_df
        else:
            results_df = pd.concat([results_df, result_df], ignore_index=True)

        return results_df

    Ms = [64]
    Ks = [7168]  # hidden sizes
    Ns = [2048]
    TOPKs = [[4, 1]]
    Es = [32]
    DTYPEs = [torch.bfloat16]
    PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
    FE_TYPES = MK_FUSED_EXPERT_TYPES
    Q_TYPES = MK_QUANT_CONFIGS

    combinations = list(
        product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))

    results_df: Optional[pd.DataFrame] = None
    for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
            combinations):  # noqa: E501
        config = Config(Ms=[m],
                        K=k,
                        N=n,
                        E=e,
                        topks=topks,
                        dtype=dtype,
                        prepare_finalize_type=pf_type,
                        fused_experts_type=experts_type,
                        quant_config=quant_config,
                        world_size=2,
                        fused_moe_chunk_size=None)

        success = None
        if config.is_valid():
            print(f"Running config : {config.describe()} ...")
            try:
                weights: WeightTensors = WeightTensors.make(config)
                vllm_config, env_dict = config.make_env_data()
                parallel_launch_with_config(config.world_size, rank_worker,
                                            vllm_config, env_dict, config,
                                            weights)
                success = Result.PASS
            except Exception as _:
                success = Result.FAIL
        else:
            success = Result.SKIP

        results_df = add_to_results(config, success, results_df)

    if results_df is not None:
        results_df.to_csv(f"{csv_file_path}")


if __name__ == '__main__':
    import argparse
    from pathlib import Path
    parser = argparse.ArgumentParser(description=(
        "Make ModularKernel feature matrix \n"
        "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix "  # noqa: E501
        "-f ./feature_matrices/feature_matrix.csv"))

    parser.add_argument("-f",
                        "--feature-matrix-csv-file-path",
                        type=str,
                        required=True,
                        help="File name to generate a .csv file")
    args = parser.parse_args()

    csv_path = args.feature_matrix_csv_file_path
    assert csv_path.endswith(
        '.csv'), f"Need a file path ending with .csv, got {csv_path}"
    assert Path(csv_path).parent.is_dir(
    ), f"Cannot find parent directory for {Path(csv_path).parent}"

    make_feature_matrix(args.feature_matrix_csv_file_path)
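A hedged sketch (not part of this commit) of how the generated CSV might be inspected; the column names follow the fields serialized in add_to_results above, and the path matches the example in the parser description:

import pandas as pd

df = pd.read_csv("./feature_matrices/feature_matrix.csv")
# Count PASS / FAIL / SKIP results per (prepare-finalize, fused-experts) pairing.
summary = (df.groupby(["prepare_finalize_type", "fused_experts_type"])["success"]
             .value_counts()
             .unstack(fill_value=0))
print(summary)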
tests/kernels/moe/modular_kernel_tools/mk_objects.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx

if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
if has_deep_ep():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
        DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
    ]

MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]

MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
                                 MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)

MK_FUSED_EXPERT_TYPES = [
    BatchedDeepGemmExperts,
    BatchedTritonExperts,
    NaiveBatchedExperts,
    BatchedTritonOrDeepGemmExperts,
    CutlassExpertsFp8,
    DeepGemmExperts,
    TritonOrDeepGemmExperts,
    TritonExperts,
]

MK_QUANT_CONFIGS = [
    None,
    # per-channel / per-column weights and per-tensor activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=True,
                        per_act_token_quant=False,
                        block_shape=None),
    # per-channel / per-column weights and per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=True,
                        per_act_token_quant=True,
                        block_shape=None),
    # per-tensor weights and per-tensor activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=False,
                        block_shape=None),
    # per-tensor weights and per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=True,
                        block_shape=None),
    # block-quantized weights and 128 block per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=False,
                        block_shape=[128, 128]),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]
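For reference, a small sketch (not part of this commit) that enumerates every (prepare-finalize, fused-experts) pairing the lists above expose; this is the same product the feature-matrix and pytest drivers iterate over:

from itertools import product

from tests.kernels.moe.modular_kernel_tools.mk_objects import (
    MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)

for pf, fe in product(MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES):
    print(f"{pf.__name__:35} x {fe.__name__}")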
tests/kernels/moe/modular_kernel_tools/parallel_utils.py (new file, 138 lines)
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Any, Callable, Optional

import torch
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (init_distributed_environment,
                              initialize_model_parallel)
from vllm.utils import get_open_port

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
                     local_rank: int):

    import tempfile
    temp_file = tempfile.mkstemp()[1]

    with set_current_vllm_config(vllm_config):
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=f"file://{temp_file}",
            local_rank=local_rank,
            backend="nccl",
        )

        initialize_model_parallel(
            tensor_model_parallel_size=vllm_config.parallel_config.
            tensor_parallel_size,
            pipeline_model_parallel_size=vllm_config.parallel_config.
            pipeline_parallel_size,
        )
        cpu_group = torch.distributed.new_group(list(range(world_size)),
                                                backend="gloo")
    return cpu_group


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
                                 P], None],
    vllm_config: Optional[VllmConfig],
    env_dict: Optional[dict],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    if env_dict is not None:
        os.environ.update(env_dict)

    cpu_group = None
    if vllm_config is not None:
        cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            vllm_config,
            cpu_group,
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch_with_config(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
    vllm_config: VllmConfig,
    env_dict: dict[Any, Any],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            worker,
            vllm_config,
            env_dict,
        ) + args,
        nprocs=world_size,
        join=True,
    )
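An illustrative sketch (not part of this commit) of the worker signature parallel_launch_with_config expects, mirroring how rank_worker is declared in the sibling modules: the launcher prepends (ProcessGroupInfo, VllmConfig, cpu_group) to the user arguments. The hello_worker name is hypothetical, and the snippet assumes two visible CUDA devices.

from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo, parallel_launch_with_config)
from vllm.config import VllmConfig


def hello_worker(pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group,
                 message: str):
    print(f"rank {pgi.rank}/{pgi.world_size} on {pgi.device}: {message}")


if __name__ == "__main__":
    vllm_config = VllmConfig()
    # Mirror Config.make_env_data(): one data-parallel rank per process.
    vllm_config.parallel_config.data_parallel_size = 2
    vllm_config.parallel_config.enable_expert_parallel = True
    parallel_launch_with_config(2, hello_worker, vllm_config, {}, "hello")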
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
from itertools import product
from typing import Any, Callable

import torch

from vllm.config import VllmConfig
from vllm.platforms import current_platform

from .common import Config, RankTensors, WeightTensors, make_modular_kernel
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config


def do_profile(fn: Callable,
               fn_kwargs: dict[Any, Any],
               pgi: ProcessGroupInfo,
               config: Config,
               num_warmups: int = 5):
    for _ in range(num_warmups):
        fn(**fn_kwargs)

    with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            with_stack=True,
            record_shapes=True,
    ) as tprof:
        fn(**fn_kwargs)
        torch.cuda.synchronize(torch.cuda.current_device())

    # TODO (varun): Add a descriptive trace file name
    tprof.export_chrome_trace(
        f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")


def profile_modular_kernel(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    config: Config,
    weights: WeightTensors,
    rank_tensors: RankTensors,
) -> None:
    assert isinstance(config.Ms, int)
    assert isinstance(config.topks, int)

    # weights for rank
    rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

    # make modular kernel
    mk = make_modular_kernel(config, vllm_config)

    mk_kwargs = {
        "hidden_states": rank_tensors.hidden_states,
        "w1": rank_weights.w1,
        "w2": rank_weights.w2,
        "topk_weights": rank_tensors.topk_weights,
        "topk_ids": rank_tensors.topk_ids,
        "expert_map": rank_tensors.expert_map,
        "w1_scale": rank_weights.w1_scale,
        "w2_scale": rank_weights.w2_scale,
        "a1_scale": rank_tensors.hidden_states_scale,
        "global_num_experts": config.E,
        "apply_router_weight_on_input": config.topk == 1,
    }

    do_profile(mk.forward, mk_kwargs, pgi, config)


def rank_worker(
    pgi: ProcessGroupInfo,
    vllm_config: VllmConfig,
    cpu_group,
    config: Config,
    weights: WeightTensors,
):
    current_platform.seed_everything(pgi.rank)

    # sanity check
    from vllm import envs
    if config.fused_moe_chunk_size is not None:
        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)

    # get weights to this device
    weights.to_current_device()

    Ms = config.Ms
    assert isinstance(Ms, list)
    TOPKs = config.topks
    assert isinstance(TOPKs, list)

    for m, topk in product(Ms, TOPKs):
        print(f"Running m={m}, topk={topk} ...")
        # override m and topk
        cfgx = copy.deepcopy(config)
        cfgx.Ms = m
        cfgx.topks = topk

        # inputs for rank
        rank_tensors = RankTensors.make(cfgx, pgi)
        profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)


def run(config: Config):
    weights: WeightTensors = WeightTensors.make(config)
    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights)


if __name__ == '__main__':
    from .cli_args import make_config, make_config_arg_parser
    parser = make_config_arg_parser(description=(
        "Run single prepare-finalize & fused-experts combination test \n"
        "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel "  # noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    assert args.torch_trace_dir_path is not None, (
        "Please pass in a directory to store torch traces")
    config = make_config(args)

    run(config)
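For reference, a hedged sketch (not part of this commit) that drives the profiler programmatically rather than through the CLI. It assumes two GPUs, the pplx package, and an existing ./torch_traces directory (the trace-dir argument is required by the assert above):

from tests.kernels.moe.modular_kernel_tools.cli_args import (
    make_config, make_config_arg_parser)
from tests.kernels.moe.modular_kernel_tools.profile_modular_kernel import run

parser = make_config_arg_parser(description="profile sketch")
args = parser.parse_args([
    "--pf-type", "PplxPrepareAndFinalize",
    "--experts-type", "BatchedTritonExperts",
    "--torch-trace-dir-path", "./torch_traces",
])
run(make_config(args))  # writes chrome traces under ./torch_traces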
tests/kernels/moe/modular_kernel_tools/utils.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import torch

import vllm._custom_ops as ops


def per_token_cast_to_fp8(
        x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    pad_size = (block_size - (n % block_size)) % block_size
    x = torch.nn.functional.pad(x,
                                (0, pad_size), value=0) if pad_size > 0 else x
    x_view = x.view(m, -1, block_size)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(
        x: torch.Tensor, block_size_k: int,
        block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (
            int(math.ceil(m / block_size_k)) * block_size_k,
            int(math.ceil(n / block_size_n)) * block_size_n,
        ),
        dtype=x.dtype,
        device=x.device,
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_size_k,
                           x_padded.size(1) // block_size_k, block_size_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
    return x_scaled_sub, scales


def make_non_quant_weights(
    e: int,
    n: int,
    k: int,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2
    """
    device = torch.cuda.current_device()
    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
    return w1, w2


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return weights w1, w2, w1_scale, w2_scale
    """
    dtype = torch.bfloat16
    device = torch.cuda.current_device()

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
    k_tiles_w1 = (k + block_k - 1) // block_k
    n_tiles_w2 = (k + block_n - 1) // block_n
    k_tiles_w2 = (n + block_k - 1) // block_k

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)

    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
                       device=device,
                       dtype=torch.float32)
    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
                       device=device,
                       dtype=torch.float32)

    assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
                          (k + (block_k - 1)) // block_k)
    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]

    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size_k=block_k,
                                               block_size_n=block_n)

    return w1, w2, w1_s, w2_s


def make_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Return w1, w2, w1_scale, w2_scale
    """
    q_dtype = torch.float8_e4m3fn

    w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)

    # w1 -> w1_q, w2 -> w2_q
    w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
    w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

    n_b_scales = 2 * n if per_out_channel_quant else 1
    k_b_scales = k if per_out_channel_quant else 1
    w1_scale = torch.empty((e, n_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((e, k_b_scales, 1),
                           device="cuda",
                           dtype=torch.float32)

    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
    return w1_q, w2_q, w1_scale, w2_scale
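A small numerical sanity sketch (not part of this commit) for per_token_cast_to_fp8: dequantize the 128-wide blocks with the returned scales and compare against the input. Assumes a CUDA device and a torch build with float8_e4m3fn support.

import torch

from tests.kernels.moe.modular_kernel_tools.utils import per_token_cast_to_fp8

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16) / 15.0
x_q, x_s = per_token_cast_to_fp8(x, block_size=128)
# x_q: (4, 256) float8_e4m3fn data, x_s: (4, 2) float32 per-block scales.
x_dq = (x_q.view(4, -1, 128).float() * x_s.unsqueeze(-1)).view(4, 256)
torch.testing.assert_close(x_dq.to(x.dtype), x, atol=3e-2, rtol=3e-2)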
@@ -4,7 +4,6 @@
 DeepEP test utilities
 """
 import dataclasses
-import importlib
 import os
 import traceback
 from typing import Callable, Optional
@@ -15,10 +14,9 @@ from torch.multiprocessing import (
     spawn)  # pyright: ignore[reportPrivateImportUsage]
 from typing_extensions import Concatenate, ParamSpec
 
-from vllm.utils import get_open_port
+from vllm.utils import get_open_port, has_deep_ep
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
214
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
214
tests/kernels/moe/test_modular_kernel_combinations.py
Normal file
@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
||||
reference_moe_impl,
|
||||
run_modular_kernel)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
||||
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
||||
parallel_launch_with_config)
|
||||
|
||||
# TODO (varun): These requirements are very strict and could be relaxed.
|
||||
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())
|
||||
|
||||
meets_package_requirements = pytest.mark.skipif(
|
||||
not has_all_packages,
|
||||
reason="Requires deep_ep & deep_gemm & pplx packages",
|
||||
)
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
|
||||
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
|
||||
|
||||
|
||||
def run(config: Config):
    assert config.is_valid()
    print(f"Testing config \n{config.describe()} ...")

    weights: WeightTensors = WeightTensors.make(config)

    vllm_config, env_dict = config.make_env_data()
    parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
                                env_dict, config, weights)


Ms = [32, 64]
Ks = [7168]  # hidden sizes
Ns = [2048]
TOPKs = [4, 1]
Es = [32]
DTYPEs = [torch.bfloat16]
FUSED_MOE_CHUNK_SIZEs = [None, 16]


def is_nyi_config(config: Config) -> bool:
    # We know these configs to be legitimate, but they still fail.

    if (config.fused_experts_type in [
            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
            TritonExperts, TritonOrDeepGemmExperts
    ]):
        # The triton kernels expect both per-act-token-quant and
        # per-out-ch-quant or neither.
        unsupported_quant_config = ((config.is_per_act_token_quant +
                                     config.is_per_out_ch_quant) == 1)
        return unsupported_quant_config

    # cutlass kernels don't support expert_maps yet.
    return config.fused_experts_type == CutlassExpertsFp8


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
def test_modular_kernel_combinations_multigpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):

    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )
    if not config.is_valid():
        pytest.skip(f"Test config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Test config {config} is nyi. Skipping ...")

    print(f"{config.describe()}")
    run(config)


@pytest.mark.parametrize("k", Ks)
@pytest.mark.parametrize("n", Ns)
@pytest.mark.parametrize("e", Es)
@pytest.mark.parametrize("dtype", DTYPEs)
@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS)
@pytest.mark.parametrize(
    "combination",
    product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
        k: int, n: int, e: int, dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
        fused_moe_chunk_size: Optional[int], world_size: int):
    config = Config(
        Ms=Ms,
        K=k,
        N=n,
        E=e,
        topks=TOPKs,
        dtype=dtype,
        quant_config=quant_config,
        prepare_finalize_type=combination[0],
        fused_experts_type=combination[1],
        fused_moe_chunk_size=fused_moe_chunk_size,
        world_size=world_size,
    )

    if not config.is_valid():
        pytest.skip(f"Test config {config} is not valid. Skipping ...")

    if is_nyi_config(config):
        pytest.skip(f"Test config {config} is nyi. Skipping ...")

    run(config)


if __name__ == '__main__':
    # Ability to test individual PrepareAndFinalize and FusedExperts combinations.
    from .modular_kernel_tools.cli_args import (make_config,
                                                make_config_arg_parser)
    parser = make_config_arg_parser(description=(
        "Run a single prepare-finalize & fused-experts combination test. "
        "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations "  # noqa: E501
        "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
    ))
    args = parser.parse_args()
    config = make_config(args)

    run(config)
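
For orientation, the block below is a minimal sketch (not part of the patch) of driving the same harness programmatically, mirroring the __main__ path above. The field values, the quant_config=None choice, and the use of the first entries of the type lists are illustrative assumptions.

# Illustrative only: a hand-built Config run, assuming this sits alongside the
# module above so `run` and `is_nyi_config` are in scope. Values are arbitrary.
import torch

from .modular_kernel_tools.common import Config
from .modular_kernel_tools.mk_objects import (
    MK_FUSED_EXPERT_TYPES, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)


def run_one_small_case():
    config = Config(
        Ms=[32],            # tokens per rank
        K=7168,             # hidden size
        N=2048,             # N dimension of the first fused-moe matmul
        E=32,               # global number of experts
        topks=[4],
        dtype=torch.bfloat16,
        quant_config=None,  # assumption: unquantized bf16 path
        prepare_finalize_type=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES[0],
        fused_experts_type=MK_FUSED_EXPERT_TYPES[0],
        fused_moe_chunk_size=None,
        world_size=1,
    )
    if config.is_valid() and not is_nyi_config(config):
        run(config)
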
@ -1072,6 +1072,7 @@ def torch_experts(
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
    apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
    assert (global_num_experts == -1
            or (global_num_experts == w1.shape[0] and expert_map is None)
@ -1081,11 +1082,17 @@ def torch_experts(
    M, K = a.shape
    topk = topk_ids.shape[1]

    if apply_router_weights_on_input:
        assert topk == 1
        a = a * topk_weight.to(a.dtype)

    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
    if a1_scale:
        assert not per_act_token_quant and block_shape is None
    a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype,
                                           per_act_token_quant, block_shape)

    num_experts = w1.shape[0]
@ -1104,6 +1111,7 @@ def torch_experts(
            tmp2 = SiluAndMul()(tmp1)
            out[mask] = tmp2 @ w2[i].transpose(0, 1)
        elif block_shape is not None:
            # block quantized
            assert (a_scale is not None and w1_scale is not None
                    and w2_scale is not None)
            tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
@ -1121,15 +1129,27 @@ def torch_experts(
            assert (a_scale is not None and w1_scale is not None
                    and w2_scale is not None)
            scales = a_scale if a_scale.numel() == 1 else a_scale[mask]

            tmp1 = a[mask].to(f32) * scales
            w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
            tmp1 = tmp1 @ w1_dq
            tmp2 = SiluAndMul()(tmp1)
            tmp1 = (tmp1 @ w1_dq).to(out.dtype)

            tmp2 = SiluAndMul()(tmp1).to(out.dtype)

            tmp2, b_scale = moe_kernel_quantize_input(
                tmp2, a2_scale, quant_dtype, per_act_token_quant,
                block_shape)
            assert b_scale is not None

            tmp2 = tmp2.to(f32) * b_scale
            w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
            out[mask] = (tmp2 @ w2_dq).to(out.dtype)

    return (out.view(M, -1, w2.shape[1]).to(f32) *
            topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)
    if apply_router_weights_on_input:
        return out
    else:
        return (out.view(M, -1, w2.shape[1]).to(f32) *
                topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype)


def torch_moe(a: torch.Tensor,
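
The updated reference path above now round-trips the intermediate activations (tmp2) through quantization and multiplies by the returned scale before the second matmul. As a self-contained illustration of that scale bookkeeping only, here is a plain-torch per-tensor fp8 round trip; it is not vLLM's moe_kernel_quantize_input, and it assumes a PyTorch build with float8_e4m3fn support.

# Standalone sketch of a per-tensor fp8 quantize/dequantize round trip,
# mirroring the "quantize, then multiply back by the scale" pattern above.
# Not code from the patch.
import torch


def quantize_per_tensor_fp8(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale for the whole tensor so that max |x| maps to the fp8 max.
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_q, scale


def dequantize_fp8(x_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Analogous to the `tmp2.to(f32) * b_scale` step in torch_experts.
    return x_q.to(torch.float32) * scale


x = torch.randn(16, 128)
x_q, s = quantize_per_tensor_fp8(x)
x_dq = dequantize_fp8(x_q, s)
print(f"max round-trip error: {(x - x_dq).abs().max().item():.5f}")
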
@ -240,8 +240,7 @@ class DeviceCommunicatorBase:
            if module.__class__.__name__ == "FusedMoE"
        ]
        for module in moe_modules:
            module.quant_method.init_prepare_finalize(module.moe_config,
                                                      module.quant_config)
            module.quant_method.init_prepare_finalize(module.moe_config)

    def dispatch(
            self, hidden_states: torch.Tensor,
@ -37,7 +37,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                block_shape=block_shape,
                per_act_token_quant=per_act_token_quant,
            ))
        self.allow_deep_gemm = allow_deep_gemm

        self.batched_triton_experts = BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
@ -81,13 +81,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    def init_prepare_finalize(self, moe: FusedMoEConfig,
                              quant_config: Optional[QuantizationConfig]):
    @staticmethod
    def maybe_make_prepare_finalize(
            moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        self.moe = moe

        prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None

        if moe.use_pplx_kernels:
@ -160,8 +159,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                               and moe.quant_config.block_shape
                               == DEEPEP_QUANT_BLOCK_SHAPE)

            # Note (varun): Whether to use FP8 dispatch or not needs some
            # profiling. Turning it off for now.
            prepare_finalize = DeepEPLLPrepareAndFinalize(
                handle,
                max_tokens_per_rank=moe.max_num_tokens,
@ -169,11 +166,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                use_fp8_dispatch=use_fp8_dispatch,
            )

        return prepare_finalize

    def init_prepare_finalize(self, moe: FusedMoEConfig):
        self.moe = moe
        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
            self.moe)

        self.topk_indices_dtype = None
        if prepare_finalize is not None:
            logger.debug("%s", prepare_finalize.__class__.__name__)
            self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
            experts = self.select_gemm_impl(prepare_finalize, moe)
            experts = self.select_gemm_impl(prepare_finalize, self.moe)
            self.fused_experts = FusedMoEModularKernel(
                prepare_finalize,
                experts,
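
To summarize the refactor above: prepare/finalize construction is now a static helper keyed only on the FusedMoEConfig, and init_prepare_finalize no longer takes a quantization config. The sketch below restates the resulting call flow; the names come from the hunks above, but the surrounding wiring and scope of FusedMoEMethodBase / FusedMoEModularKernel are assumptions rather than verbatim library code.

# Rough sketch of the post-refactor flow, based on the hunks above.
# `quant_method` is assumed to be a FusedMoEMethodBase subclass instance and
# `moe_config` a FusedMoEConfig; both classes are those shown in this diff.
def wire_up_modular_kernel(quant_method, moe_config):
    prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe_config)
    if prepare_finalize is None:
        return  # no all2all backend selected; keep the default fused-experts path

    experts = quant_method.select_gemm_impl(prepare_finalize, moe_config)
    quant_method.fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )
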
@ -7,7 +7,8 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
    DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
    deep_gemm_block_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

@ -44,8 +45,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )
        self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
                                and use_fp8_w8a8)

        self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
                                self.block_shape == deep_gemm_block_shape())

        self.deep_gemm_expert = DeepGemmExperts(
        ) if self.allow_deep_gemm else None