[Feature] Add async tensor parallelism using compilation pass (#17882)
Signed-off-by: cascade812 <cascade812@outlook.com>
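The pass added here is gated by the new PassConfig.enable_async_tp flag and is exercised end to end through the engine's compilation config. A minimal sketch of the knob, using only names that appear in this diff (it mirrors the config shape built by test_async_tp_pass_correctness below; treat it as an illustration, not a documented recipe):

import json

# Compilation config as built in the correctness test; passed to the server
# CLI as --compilation_config '<json>' together with --tensor-parallel-size 2.
compilation_config = {
    "level": 3,
    "compile_sizes": [2, 4, 8],
    "splitting_ops": [],
    "pass_config": {"enable_async_tp": True},
}
print(json.dumps(compilation_config))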
@@ -316,6 +316,7 @@ steps:
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
+  - pytest -v -s compile/test_async_tp.py
 
 - label: PyTorch Fullgraph Smoke Test # 9min
   mirror_hardwares: [amdexperimental, amdproduction]
@@ -5,6 +5,8 @@ from typing import Callable, Union
 
 from torch import fx
 
+from vllm.compilation.fx_utils import (find_specified_fn,
+                                       find_specified_fn_maybe)
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import get_current_vllm_config
 
@@ -44,3 +46,19 @@ class TestBackend:
         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
+
+    def check_before_ops(self, ops,
+                         find_fn=find_specified_fn,
+                         find_fn_maybe=find_specified_fn_maybe,
+                         ops_fully_replaced=True):
+        for op in ops:
+            find_fn(self.graph_pre_pass.nodes, op)
+            if ops_fully_replaced:
+                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
+
+    def check_after_ops(self, ops,
+                        find_fn=find_specified_fn,
+                        find_fn_maybe=find_specified_fn_maybe):
+        for op in ops:
+            find_fn(self.graph_post_pass.nodes, op)
+            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
tests/compile/test_async_tp.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0

import json

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                         PassConfig, VllmConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test,
                     multi_gpu_test)
from .backend import TestBackend

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class TestMMRSModel(torch.nn.Module):

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.gate_proj = torch.nn.Parameter(torch.empty(
            (self.hidden_size * 2, hidden_size)),
                                            requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

    def forward(self, hidden_states):
        """
        Forward pass implementing the mm + reduce scatter in the FX graph
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)
        reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
        return reduce_scatter

    def ops_in_model_before(self):
        return [torch.ops.vllm.reduce_scatter.default]

    def ops_in_model_after(self):
        return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]


class TestAGMMModel(torch.nn.Module):

    def __init__(self, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight = torch.nn.Parameter(torch.empty(
            (hidden_size, hidden_size)),
                                         requires_grad=False)
        # Initialize weights
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states):
        """
        Forward pass implementing the mm + all gather in the FX graph
        """
        # Reshape input
        view = hidden_states.reshape(-1, self.hidden_size)
        all_gather = tensor_model_parallel_all_gather(view, dim=0)
        permute = self.weight.permute(1, 0)
        mm = torch.mm(all_gather, permute)
        return mm

    def ops_in_model_before(self):
        return [torch.ops.vllm.all_gather.default]

    def ops_in_model_after(self):
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: type[torch.nn.Module],
                               batch_size: int, seq_len: int,
                               hidden_size: int, dtype: torch.dtype):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
                                    args=(num_processes, test_model,
                                          batch_size, seq_len, hidden_size,
                                          dtype),
                                    nprocs=nprocs)

    run_torch_spawn(async_tp_pass_on_test_model, num_processes)


def async_tp_pass_on_test_model(local_rank: int, world_size: int,
                                test_model_cls: type[torch.nn.Module],
                                batch_size: int, seq_len: int,
                                hidden_size: int, dtype: torch.dtype):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # configure vllm config for AsyncTPPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_async_tp=True))
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    async_tp_pass = AsyncTPPass(vllm_config)
    backend = TestBackend(async_tp_pass)

    model = test_model_cls(hidden_size)

    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype,
                                requires_grad=False)

    compiled_model = torch.compile(model, backend=backend)
    compiled_model(hidden_states)

    # In pre-nodes, all gather or reduce scatter should exist,
    # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
    backend.check_before_ops(model.ops_in_model_before(),
                             ops_fully_replaced=False)

    # In post-nodes, fused_matmul_reduce_scatter or
    # fused_all_gather_matmul should exist
    backend.check_after_ops(model.ops_in_model_after())


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False, True])
def test_async_tp_pass_correctness(
    model_id: str,
    tp_size: int,
    async_tp_enabled: bool,
    distributed_backend: str,
    eager_mode: bool,
    num_gpus_available: int,
):
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
    model_info.check_transformers_version(on_fail="skip")
    model_info.check_available_online(on_fail="skip")

    pp_size = 1
    if num_gpus_available < tp_size:
        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

    common_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "8",
    ]
    if eager_mode:
        common_args.append("--enforce-eager")

    compilation_config = {
        'level': 3,
        'compile_sizes': [2, 4, 8],
        'splitting_ops': [],
        'pass_config': {
            'enable_async_tp': async_tp_enabled
        },
    }

    async_tp_env = tp_env = {
        "VLLM_USE_V1": "1",
    }

    async_tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        distributed_backend,
        "--compilation_config",
        json.dumps(compilation_config),
    ]

    tp_args = [
        *common_args,
        "--tensor-parallel-size",
        str(tp_size),
        "--distributed-executor-backend",
        "mp",
    ]

    compare_two_settings(model_id,
                         async_tp_args,
                         tp_args,
                         async_tp_env,
                         tp_env,
                         method="generate")
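For orientation, the algebra the two toy models above exercise (notation mine, not from the diff): with p tensor-parallel ranks, a per-rank row shard x_r, and rank-local operands X_r, W_r, a sum (or averaged) reduce-scatter satisfies

$$\operatorname{AG}(x)\,W=\begin{bmatrix}x_0 W\\ \vdots\\ x_{p-1} W\end{bmatrix},\qquad \big(\operatorname{RS}_0(X_r W_r)\big)_b=\sum_{r=0}^{p-1}(X_r)_{[b]}\,W_r,$$

so each output block depends only on the matching row block of each rank's operand. That is what lets fused_all_gather_matmul and fused_matmul_reduce_scatter overlap the communication for one block with the GEMM for another, which is the point of replacing the unfused all_gather/reduce_scatter nodes.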
@@ -29,6 +29,10 @@ class TestModel(torch.nn.Module):
         self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        self.key = QuantKey(dtype=FP8_DTYPE,
+                            static=static,
+                            per_tensor=static,
+                            symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         else:
@@ -59,6 +63,15 @@ class TestModel(torch.nn.Module):
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3
 
+    def ops_in_model_before(self):
+        return [QUANT_OPS[self.key]]
+
+    def ops_in_model_after(self):
+        return [
+            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
+            FUSED_OPS[FusedRMSQuantKey(self.key, True)]
+        ]
+
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
 
     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
     # In pre-nodes, fp8 quant should be there and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+    backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
+                             find_auto_fn_maybe)
 
     # In post-nodes, fused kernels should be there and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
+                            find_auto_fn_maybe)
@@ -5,9 +5,7 @@ import torch
 
 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
-                                       find_specified_fn,
-                                       find_specified_fn_maybe, is_func)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
@@ -21,17 +19,6 @@ from vllm.utils import update_environment_variables
 from ..utils import multi_gpu_test
 from .backend import TestBackend
 
-OPS_IN_MODEL_BEFORE = [
-    torch.ops.vllm.all_reduce.default,
-]
-
-OPS_IN_MODEL_AFTER = [
-    torch.ops.vllm.reduce_scatter.default,
-    torch.ops.vllm.all_gather.default,
-]
-
-OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
-
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -78,6 +65,18 @@ class TestModel(torch.nn.Module):
 
         return norm_output, residual_output
 
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_reduce.default]
+
+    def ops_in_model_after(self):
+        return [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+
+    def ops_in_model(self):
+        return [torch.ops._C.fused_add_rms_norm.default]
+
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("batch_size", [8])
@@ -156,26 +155,16 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model_func = torch.compile(model, backend=backend_func)
     compiled_model_func(hidden_states, residual)
 
-    # Check substitution worked
-    pre_nodes = backend_no_func.graph_pre_pass.nodes
-    post_nodes = backend_no_func.graph_post_pass.nodes
-
     # In pre-nodes, all reduce should be there,
     # reduce scatter and all gather should not
-    for op in OPS_IN_MODEL_BEFORE:
-        find_specified_fn(pre_nodes, op)
-    for op in OPS_IN_MODEL_AFTER:
-        assert find_specified_fn_maybe(pre_nodes, op) is None
+    backend_no_func.check_before_ops(model.ops_in_model_before())
 
     # In post-nodes, reduce scatter and all gather should be there,
     # all reduce should not
-    for op in OPS_IN_MODEL_AFTER:
-        find_specified_fn(post_nodes, op)
-    for op in OPS_IN_MODEL_BEFORE:
-        assert find_specified_fn_maybe(post_nodes, op) is None
+    backend_no_func.check_after_ops(model.ops_in_model_after())
 
     # check if the functionalization pass is applied
-    for op in OPS_IN_MODEL:
+    for op in model.ops_in_model():
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
         assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                   op) is None  # noqa: E501
@@ -183,7 +172,7 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     # make sure the ops were all de-functionalized
     found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in OPS_IN_MODEL:
+        for op in model.ops_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in OPS_IN_MODEL)
+    assert all(found[op] for op in model.ops_in_model())
vllm/compilation/collective_fusion.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

from vllm.config import VllmConfig
from vllm.distributed import get_tp_group
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.logger import init_logger

from .vllm_inductor_pass import VllmInductorPass

logger = init_logger(__name__)


class BasePattern:

    def __init__(self, dtype: torch.dtype, device: str):
        self.dtype = dtype
        self.device = device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()


class GEMMReduceScatterPattern(BasePattern):

    def get_inputs(self):
        mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
        mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor):
            mm = torch.ops.aten.mm.default(mul, mm_weight)
            reduce_scatter = torch.ops.vllm.reduce_scatter.default(
                mm,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)
            return reduce_scatter

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor):
            gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

            return gemm_rs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AllGatherGEMMPattern(BasePattern):

    def get_inputs(self):
        x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)

        return [x, weight]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            all_gather = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name)

            return torch.ops.aten.mm.default(all_gather, weight)

        def replacement(
                x: torch.Tensor,
                weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AsyncTPPass(VllmInductorPass):

    def __init__(self, config: VllmConfig):
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        enable_symm_mem_for_group(get_tp_group().device_group.group_name)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass")
        GEMMReduceScatterPattern(self.model_dtype,
                                 self.device).register(self.patterns)

        AllGatherGEMMPattern(self.model_dtype,
                             self.device).register(self.patterns)

    def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
        # only do replace for specific shapes
        tp_size = get_tensor_model_parallel_world_size()
        return shape is not None and shape % tp_size == 0

    def __call__(self, graph: fx.Graph):
        self.begin()
        self.dump_graph(graph, "before_async_tp_pass")
        count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", count)
        self.dump_graph(graph, "after_async_tp_pass")
        self.end_and_log()
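AsyncTPPass is a thin wrapper around Inductor's pattern matcher: each pattern/replacement pair is traced once with pm.fwd_only and registered into a PatternMatcherPass, and __call__ simply applies the accumulated patterns to the post-grad FX graph. A standalone, single-process sketch of that mechanism, using an unrelated mm+add -> addmm rewrite so it runs without GPUs or a process group (pattern-matcher internals can shift between PyTorch releases, so treat this as illustrative, not as the async-TP patterns themselves):

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx.experimental.proxy_tensor import make_fx

patterns = PatternMatcherPass(pass_name="demo_pass")

def pattern(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    return torch.ops.aten.add.Tensor(torch.ops.aten.mm.default(x, w), b)

def replacement(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    # addmm(b, x, w) == x @ w + b, so the rewrite is semantics-preserving
    return torch.ops.aten.addmm.default(b, x, w)

example_inputs = [torch.randn(16, 4), torch.randn(4, 4), torch.randn(4)]
pm.register_replacement(pattern, replacement, example_inputs, pm.fwd_only,
                        patterns)

def fn(x, w, b):
    return torch.mm(x, w) + b

gm = make_fx(fn)(*example_inputs)   # ATen-level FX graph, like a post-grad graph
count = patterns.apply(gm.graph)    # rewrite matching subgraphs in place
gm.recompile()
print(f"replaced {count} pattern(s)")  # expect 1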
@@ -6,6 +6,7 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 
 from .activation_quant_fusion import ActivationQuantFusionPass
+from .collective_fusion import AsyncTPPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@@ -54,6 +55,8 @@ class PostGradPassManager(CustomGraphPass):
 
         if self.pass_config.enable_sequence_parallelism:
             self.passes += [SequenceParallelismPass(config)]
+        if self.pass_config.enable_async_tp:
+            self.passes += [AsyncTPPass(config)]
 
         self.fix_functionalization = FixFunctionalizationPass(config)
 
@@ -243,24 +243,25 @@ class SequenceParallelismPass(VllmInductorPass):
             pass_name="sequence_parallelism_pass")
         for epsilon in [1e-5, 1e-6]:
             EmbeddingAllReduceRMSNormPattern(
-                epsilon, self.dtype, self.device).register(self.patterns)
+                epsilon, self.model_dtype, self.device).register(self.patterns)
 
-            MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
+            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                           self.device).register(self.patterns)
 
-            LastAllReduceRMSNormPattern(epsilon, self.dtype,
+            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)
             # WARNING: This is a hack to clear the pattern matcher cache
             # and allow multiple values of epsilon.
             torch._inductor.pattern_matcher._seen_patterns.clear()
 
     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         # only do replace for specific shapes
         tp_size = get_tensor_model_parallel_world_size()
         return shape is not None and shape % tp_size == 0
 
     def __call__(self, graph: fx.Graph):
         self.begin()
         self.dump_graph(graph, "before_sequence_parallelism_pass")
         count = self.patterns.apply(graph)
         logger.debug("Replaced %s patterns", count)
         self.dump_graph(graph, "after_sequence_parallelism_pass")
         self.end_and_log()
@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):
 
     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
-        self.dtype = config.model_config.dtype if config.model_config else None
+        self.model_dtype = config.model_config.dtype if config.model_config \
+            else None
         self.device = config.device_config.device if config.device_config \
             else None
         self.pass_name = self.__class__.__name__
@@ -3652,6 +3652,8 @@ class PassConfig:
     """Whether to enable the custom no-op elimination pass."""
     enable_sequence_parallelism: bool = False
     """Whether to enable sequence parallelism."""
+    enable_async_tp: bool = False
+    """Whether to enable async TP."""
 
     def uuid(self):
         """
@@ -3661,7 +3663,8 @@ class PassConfig:
         compilation.
         """
         include = {
-            "enable_fusion", "enable_noop", "enable_sequence_parallelism"
+            "enable_fusion", "enable_noop", "enable_sequence_parallelism",
+            "enable_async_tp"
         }
         dict_ = {k: v for k, v in asdict(self).items() if k in include}
         return InductorPass.hash_dict(dict_)
@@ -4274,6 +4277,12 @@ class VllmConfig:
 
         if self.compilation_config is None:
             self.compilation_config = CompilationConfig()
+
+        # async tp is built on top of sequence parallelism
+        # and requires it to be enabled.
+        if self.compilation_config.pass_config.enable_async_tp:
+            self.compilation_config.pass_config.enable_sequence_parallelism = \
+                True
         if self.compilation_config.pass_config.enable_sequence_parallelism:
             self.compilation_config.custom_ops.append("+rms_norm")
         if envs.VLLM_USE_V1 and self.model_config is not None and \
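A hedged sketch of the dependency encoded just above: enabling async TP flips enable_sequence_parallelism during VllmConfig.__post_init__ (all names come from this diff; constructing a bare VllmConfig mirrors the unit test):

from vllm.config import CompilationConfig, PassConfig, VllmConfig

cfg = VllmConfig(compilation_config=CompilationConfig(
    pass_config=PassConfig(enable_async_tp=True)))

# Async TP is layered on top of sequence parallelism, so __post_init__
# enables it (and, per the pre-existing branch above, the "+rms_norm"
# custom op is registered as well).
assert cfg.compilation_config.pass_config.enable_sequence_parallelism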
@@ -120,7 +120,7 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group.reduce_scatter(tensor, dim)
+    return group._reduce_scatter_out_place(tensor, dim)
 
 
 def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
@@ -136,7 +136,7 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
     group = _groups[group_name]()
     if group is None:
         raise ValueError(f"Group {group_name} is destroyed.")
-    return group.all_gather(tensor, dim)
+    return group._all_gather_out_place(tensor, dim)
 
 
 def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
@@ -161,6 +161,7 @@ if supports_custom_op():
         op_func=reduce_scatter,
         mutates_args=[],
         fake_impl=reduce_scatter_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
     direct_register_custom_op(
@@ -168,6 +169,7 @@
         op_func=all_gather,
         mutates_args=[],
        fake_impl=all_gather_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
 
 
@@ -367,6 +369,16 @@ class GroupCoordinator:
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
 
+        if self.use_custom_op_call:
+            return torch.ops.vllm.all_gather(input_,
+                                             dim,
+                                             world_size,
+                                             group_name=self.unique_name)
+        else:
+            return self._all_gather_out_place(input_, dim)
+
+    def _all_gather_out_place(self, input_: torch.Tensor,
+                              dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
 
     def reduce_scatter(self,
@@ -379,6 +391,16 @@ class GroupCoordinator:
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
 
+        if self.use_custom_op_call:
+            return torch.ops.vllm.reduce_scatter(input_,
+                                                 dim,
+                                                 world_size,
+                                                 group_name=self.unique_name)
+        else:
+            return self._reduce_scatter_out_place(input_, dim)
+
+    def _reduce_scatter_out_place(self, input_: torch.Tensor,
+                                  dim: int) -> torch.Tensor:
         return self.device_communicator.reduce_scatter(input_, dim)
 
     def gather(self,
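The GroupCoordinator change above is what makes the collectives visible to the pass: with use_custom_op_call, all-gather and reduce-scatter enter the compiled graph as single torch.ops.vllm.* call_function nodes rather than opaque communicator calls, so a pass can locate them by target. A hedged sketch of that lookup (find_op is my name; the tests rely on the find_specified_fn helpers for the same kind of search):

from typing import Optional

from torch import fx


def find_op(graph: fx.Graph, target) -> Optional[fx.Node]:
    """Return the first call_function node with the given target, else None."""
    for node in graph.nodes:
        if node.op == "call_function" and node.target == target:
            return node
    return None

# e.g. find_op(graph, torch.ops.vllm.all_gather.default) on a post-grad graph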