mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Bugfix][Rocm] fix qr error when different inp shape (#25892)
Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
|
|||||||
__global__ __quickreduce_launch_bounds_two_shot__ static void
|
__global__ __quickreduce_launch_bounds_two_shot__ static void
|
||||||
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
||||||
int rank, uint8_t** dbuffer_list,
|
int rank, uint8_t** dbuffer_list,
|
||||||
uint32_t data_offset, uint32_t flag_color) {
|
uint32_t data_offset, uint32_t flag_color,
|
||||||
|
int64_t data_size_per_phase) {
|
||||||
int block = blockIdx.x;
|
int block = blockIdx.x;
|
||||||
int grid = gridDim.x;
|
int grid = gridDim.x;
|
||||||
|
|
||||||
while (block < num_blocks) {
|
while (block < num_blocks) {
|
||||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
|
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
|
||||||
flag_color);
|
flag_color, data_size_per_phase);
|
||||||
block += grid;
|
block += grid;
|
||||||
flag_color++;
|
flag_color++;
|
||||||
}
|
}
|
||||||
@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
|||||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||||
num_blocks, rank, dbuffer_list, data_offset, \
|
num_blocks, rank, dbuffer_list, data_offset, \
|
||||||
flag_color); \
|
flag_color, this->kMaxProblemSize); \
|
||||||
} else if (world_size == 4) { \
|
} else if (world_size == 4) { \
|
||||||
using LineCodec = __codec<T, 4>; \
|
using LineCodec = __codec<T, 4>; \
|
||||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||||
num_blocks, rank, dbuffer_list, data_offset, \
|
num_blocks, rank, dbuffer_list, data_offset, \
|
||||||
flag_color); \
|
flag_color, this->kMaxProblemSize); \
|
||||||
} else if (world_size == 8) { \
|
} else if (world_size == 8) { \
|
||||||
using LineCodec = __codec<T, 8>; \
|
using LineCodec = __codec<T, 8>; \
|
||||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||||
num_blocks, rank, dbuffer_list, data_offset, \
|
num_blocks, rank, dbuffer_list, data_offset, \
|
||||||
flag_color); \
|
flag_color, this->kMaxProblemSize); \
|
||||||
}
|
}
|
||||||
|
|
||||||
enum QuickReduceQuantLevel {
|
enum QuickReduceQuantLevel {
|
||||||
|
@ -553,13 +553,12 @@ struct AllReduceTwoshot {
|
|||||||
int const rank, // rank index
|
int const rank, // rank index
|
||||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||||
uint32_t const data_offset, // offset to start of the data buffer
|
uint32_t const data_offset, // offset to start of the data buffer
|
||||||
uint32_t flag_color) {
|
uint32_t flag_color, int64_t data_size_per_phase) {
|
||||||
// Topology
|
// Topology
|
||||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||||
uint8_t* rank_buffer = buffer_list[rank];
|
uint8_t* rank_buffer = buffer_list[rank];
|
||||||
Codec codec(thread, rank);
|
Codec codec(thread, rank);
|
||||||
int block_id = blockIdx.x;
|
int block_id = blockIdx.x;
|
||||||
int grid_size = gridDim.x;
|
|
||||||
// --------------------------------------------------------
|
// --------------------------------------------------------
|
||||||
// Read input into registers
|
// Read input into registers
|
||||||
int32x4_t tA[kAtoms];
|
int32x4_t tA[kAtoms];
|
||||||
@ -588,12 +587,10 @@ struct AllReduceTwoshot {
|
|||||||
// rank responsible for this segment.
|
// rank responsible for this segment.
|
||||||
uint32_t comm_data0_offset =
|
uint32_t comm_data0_offset =
|
||||||
data_offset + block_id * Codec::kTransmittedTileSize;
|
data_offset + block_id * Codec::kTransmittedTileSize;
|
||||||
uint32_t comm_data1_offset =
|
uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
|
||||||
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
|
||||||
|
|
||||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||||
uint32_t comm_flags1_offset =
|
uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
|
||||||
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
|
||||||
|
|
||||||
for (int r = 0; r < kWorldSize; r++) {
|
for (int r = 0; r < kWorldSize; r++) {
|
||||||
int32x4_t* send_buffer =
|
int32x4_t* send_buffer =
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -8,6 +9,7 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
|
||||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -134,3 +136,88 @@ def test_custom_quick_allreduce(
|
|||||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||||
|
|
||||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||||
|
|
||||||
|
|
||||||
|
def qr_variable_input(rank, world_size):
    """
    Worker entry point that stresses QuickReduce with alternating input shapes.

    When the tensor parallelism is set to 4 or 8, frequent changes
    in the input shape can cause QuickReduce to hang (this issue
    has been observed with the gpt_oss model).

    Args:
        rank: Index of this worker; also selects the ``cuda:{rank}`` device.
        world_size: Number of workers taking part in the all-reduce.

    Raises:
        AssertionError: If an all-reduce result is not the expected
            all-zeros / all-``world_size`` tensor.
    """
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    qr_max_size = None  # MB
    _ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
    ranks = list(range(world_size))
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    # Despite its name, this group is created with the NCCL backend; it is
    # only used to exchange the QuickReduce handles between the workers.
    cpu_group = torch.distributed.new_group(ranks, backend="nccl")

    handle = ops.qr_get_handle(_ptr)
    world_size = dist.get_world_size(group=cpu_group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=cpu_group)
    ops.qr_open_handles(_ptr, handles)

    # FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
    quant_level = 3
    dtype = torch.float16
    s1 = 1024
    for num in range(1, 50000):  # 50000 is sufficient to identify issues.
        # Alternate between two shapes on every call; even iterations reduce
        # zeros, odd iterations reduce ones.
        reduces_zeros = num % 2 == 0
        if reduces_zeros:
            s2 = 1024
            inp1 = torch.zeros(
                (s1, s2), dtype=dtype, device=torch.cuda.current_device()
            )
        else:
            s2 = 2048
            inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device())
        result = torch.empty_like(inp1)
        ops.qr_all_reduce(_ptr, inp1, result, quant_level, cast_bf2half=True)
        # The expected value follows from the iteration parity, so there is
        # no need to read an input element back from the device to pick it.
        expected = 0 if reduces_zeros else world_size
        try:
            assert torch.all(result == expected)
        except AssertionError:
            print("Assertion failed! Allreduce results are incorrect.")
            raise

    # Tear the process group down on the normal exit path only: after a
    # failure the peers never reach this point, and the process exits anyway.
    # NOTE(review): the QuickReduce state behind `_ptr` is not released here;
    # confirm whether `ops` exposes a matching destroy call.
    dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
    """
    Run ``qr_variable_input`` in ``tp_size`` spawned processes and fail if any
    of them hangs or exits with a non-zero status.

    Args:
        tp_size: Number of worker processes to start (one per rank).
        pipeline_parallel_size: Only used to compute the number of GPUs
            required before the test is skipped.

    Raises:
        RuntimeError: If a worker is still alive after the timeout, or if a
            worker terminated with a non-zero exit code.
    """
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Use a private "spawn" context instead of forcing the interpreter-wide
    # start method, so other tests in the same session are not affected.
    ctx = multiprocessing.get_context("spawn")
    # 60s is enough
    timeout = 60
    processes = []
    for rank in range(tp_size):
        p = ctx.Process(target=qr_variable_input, args=(rank, tp_size))
        p.start()
        processes.append((rank, p))

    for rank, p in processes:
        p.join(timeout=timeout)
        if p.is_alive():
            # Snapshot workers that already died on their own before we start
            # terminating the rest: the survivors may merely be waiting on a
            # crashed peer, and that is the more useful thing to report.
            crashed = {
                r: proc.exitcode
                for r, proc in processes
                if proc.exitcode not in (None, 0)
            }
            for _, proc in processes:
                if proc.is_alive():
                    proc.terminate()
                    proc.join(timeout=10)
                    if proc.is_alive():
                        # SIGTERM was not enough; do not let cleanup block.
                        proc.kill()
                        proc.join()
            if crashed:
                raise RuntimeError(
                    "QuickReduce worker(s) exited abnormally "
                    f"(rank -> exit code): {crashed}; remaining workers were "
                    f"still running after {timeout} seconds."
                )
            raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")

    # A worker whose assertion failed exits with a non-zero status; without
    # this check the test would pass even though the results were wrong.
    failed = {rank: p.exitcode for rank, p in processes if p.exitcode != 0}
    if failed:
        raise RuntimeError(
            f"QuickReduce worker(s) exited abnormally (rank -> exit code): {failed}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Direct invocation (without pytest): run the variable-input check once
    # with four workers and a single pipeline stage.
    test_custom_quick_allreduce_variable_input(4, 1)
|
||||||
|
Reference in New Issue
Block a user