# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

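"""Tests for the custom QuickReduce (quick allreduce) path on ROCm.

Covers allreduce under CUDA graph capture, the eager path, and a
regression check for hangs when the input shape changes frequently
at tensor parallel sizes 4 and 8.
"""
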
import multiprocessing
import random

import pytest
import ray
import torch
import torch.distributed as dist

from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce  # noqa
from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform

from ..utils import (
    ensure_model_parallel_initialized,
    init_test_distributed_environment,
    multi_process_parallel,
)

torch.manual_seed(42)
random.seed(44)
# Sizes over 8 MB are large enough to trigger the custom quick allreduce path.
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
# Round each size down to a multiple of 8 elements.
for i, v in enumerate(test_sizes):
    test_sizes[i] -= v % 8


@ray.remote(num_gpus=1, max_calls=1)
def graph_quickreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
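    """Capture quick allreduce in a CUDA graph and compare the replayed
    results against a plain NCCL all_reduce on the same inputs.
    """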
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
        ensure_model_parallel_initialized(tp_size, pp_size)
        group = get_tp_group().device_group

        # A small all_reduce for warmup. Device communicators may be created
        # lazily (e.g. NCCL), so this ensures the communicator is initialized
        # before any real communication happens and the group can be used for
        # graph capture immediately.
        data = torch.zeros(1, device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data

        # Ranks in the first TP group communicate once, ranks in the second
        # group twice, and so on; this demonstrates that each group can
        # communicate independently.
        num_communication = rank // tp_size + 1

        for sz in test_sizes:
            for dtype in [torch.float16, torch.bfloat16]:
                with graph_capture(device=device) as graph_capture_context:
                    inp1 = torch.randint(
                        1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    inp2 = torch.randint(
                        -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    torch.cuda.synchronize()
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph, stream=graph_capture_context.stream):
                        for _ in range(num_communication):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                graph.replay()
                # Loose tolerances account for quick allreduce quantization error.
                torch.testing.assert_close(out1, inp1, atol=2.5, rtol=0.1)
                torch.testing.assert_close(out2, inp2, atol=2.5, rtol=0.1)


@ray.remote(num_gpus=1, max_calls=1)
def eager_quickreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
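    """Call the QuickReduce communicator directly in eager mode and check
    the result against the expected sum across ranks.
    """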
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

        # Sizes over 8 MB are sufficient to exercise the custom quick allreduce.
        sz = 16 * 1024 * 1024
        qr_comm = get_tp_group().device_communicator.qr_comm
        inp = torch.tensor(
            [1.0 * (i % 23) for i in range(sz)], dtype=torch.float16, device=device
        )
        out = qr_comm.quick_all_reduce(inp)
        torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)

        inp = torch.tensor(
            [1.0 * (i % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
        )
        out = qr_comm.quick_all_reduce(inp)
        torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
@pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce])
def test_custom_quick_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pipeline_parallel_size,
    test_target,
    quant_mode,
):
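    """Run the graph and eager paths for every supported quantization mode."""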
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)

    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)


def qr_variable_input(rank, world_size):
    """
    When tensor parallelism is set to 4 or 8, frequent changes in the
    input shape can cause QuickReduce to hang (this issue has been
    observed with the gpt_oss model).
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    qr_max_size = None  # in MB; None uses the default maximum buffer size
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
    ranks = list(range(world_size))
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    cpu_group = torch.distributed.new_group(ranks, backend="nccl")

    # Exchange IPC handles so every rank can open its peers' buffers.
    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=cpu_group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=cpu_group)
    ops.qr_open_handles(_ptr, handles)

    num = 1
    s1 = 1024
    while num < 50000:  # 50000 iterations are sufficient to surface the hang.
        dtype = torch.float16
        # Alternate the second dimension between 1024 and 2048 so the input
        # shape changes on every iteration.
        if num % 2 == 0:
            s2 = 1024
            inp1 = torch.zeros(
                (s1, s2), dtype=dtype, device=torch.cuda.current_device()
            )
        else:
            s2 = 2048
            inp1 = torch.ones(
                (s1, s2), dtype=dtype, device=torch.cuda.current_device()
            )
        result = torch.empty_like(inp1)
        # Quantization levels: FP = 0, INT8 = 1, INT6 = 2, INT4 = 3, NONE = 4
        ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
        try:
            if inp1[0, 0] == 0:
                assert torch.all(result == 0)
            else:
                assert torch.all(result == world_size)
        except AssertionError:
            print("Assertion failed! Allreduce results are incorrect.")
            raise
        num += 1


@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
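    """Regression test: alternating input shapes must not hang QuickReduce."""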
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    multiprocessing.set_start_method("spawn", force=True)
    # 60s is enough for the loop to finish; anything longer is treated as a hang.
    timeout = 60
    processes = []
    for rank in range(tp_size):
        p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size))
        p.start()
        processes.append((rank, p))
    for rank, p in processes:
        p.join(timeout=timeout)
        if p.is_alive():
            # Clean up any workers that are still running before failing.
            for r, proc in processes:
                if proc.is_alive():
                    proc.terminate()
                    proc.join()
            raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")


if __name__ == "__main__":
    test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1)