[Bugfix][Rocm] fix qr error when different inp shape (#25892)

Signed-off-by: Haoyang Li <lihaoyang0109@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
haoyangli-amd
2025-10-14 01:04:21 +08:00
committed by GitHub
parent a1b2d658ee
commit 134f70b3ed
3 changed files with 96 additions and 11 deletions

View File

@ -22,13 +22,14 @@ template <typename AllReduceKernel, typename T>
__global__ __quickreduce_launch_bounds_two_shot__ static void __global__ __quickreduce_launch_bounds_two_shot__ static void
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
int rank, uint8_t** dbuffer_list, int rank, uint8_t** dbuffer_list,
uint32_t data_offset, uint32_t flag_color) { uint32_t data_offset, uint32_t flag_color,
int64_t data_size_per_phase) {
int block = blockIdx.x; int block = blockIdx.x;
int grid = gridDim.x; int grid = gridDim.x;
while (block < num_blocks) { while (block < num_blocks) {
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
flag_color); flag_color, data_size_per_phase);
block += grid; block += grid;
flag_color++; flag_color++;
} }
@ -41,21 +42,21 @@ allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} else if (world_size == 4) { \ } else if (world_size == 4) { \
using LineCodec = __codec<T, 4>; \ using LineCodec = __codec<T, 4>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} else if (world_size == 8) { \ } else if (world_size == 8) { \
using LineCodec = __codec<T, 8>; \ using LineCodec = __codec<T, 8>; \
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \ using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \ hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
num_blocks, rank, dbuffer_list, data_offset, \ num_blocks, rank, dbuffer_list, data_offset, \
flag_color); \ flag_color, this->kMaxProblemSize); \
} }
enum QuickReduceQuantLevel { enum QuickReduceQuantLevel {

View File

@ -553,13 +553,12 @@ struct AllReduceTwoshot {
int const rank, // rank index int const rank, // rank index
uint8_t** __restrict__ buffer_list, // communication buffers uint8_t** __restrict__ buffer_list, // communication buffers
uint32_t const data_offset, // offset to start of the data buffer uint32_t const data_offset, // offset to start of the data buffer
uint32_t flag_color) { uint32_t flag_color, int64_t data_size_per_phase) {
// Topology // Topology
int thread = threadIdx.x + threadIdx.y * kWavefront; int thread = threadIdx.x + threadIdx.y * kWavefront;
uint8_t* rank_buffer = buffer_list[rank]; uint8_t* rank_buffer = buffer_list[rank];
Codec codec(thread, rank); Codec codec(thread, rank);
int block_id = blockIdx.x; int block_id = blockIdx.x;
int grid_size = gridDim.x;
// -------------------------------------------------------- // --------------------------------------------------------
// Read input into registers // Read input into registers
int32x4_t tA[kAtoms]; int32x4_t tA[kAtoms];
@ -588,12 +587,10 @@ struct AllReduceTwoshot {
// rank responsible for this segment. // rank responsible for this segment.
uint32_t comm_data0_offset = uint32_t comm_data0_offset =
data_offset + block_id * Codec::kTransmittedTileSize; data_offset + block_id * Codec::kTransmittedTileSize;
uint32_t comm_data1_offset = uint32_t comm_data1_offset = data_size_per_phase + comm_data0_offset;
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
uint32_t comm_flags1_offset = uint32_t comm_flags1_offset = (data_offset / 2) + comm_flags0_offset;
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
for (int r = 0; r < kWorldSize; r++) { for (int r = 0; r < kWorldSize; r++) {
int32x4_t* send_buffer = int32x4_t* send_buffer =

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random import random
import pytest import pytest
@ -8,6 +9,7 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.distributed.parallel_state import get_tp_group, graph_capture
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -134,3 +136,88 @@ def test_custom_quick_allreduce(
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
def qr_variable_input(rank, world_size):
"""
When the tensor parallelism is set to 4 or 8, frequent changes
in the input shape can cause QuickReduce to hang (this issue
has been observed with the gpt_oss model).
"""
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
qr_max_size = None # MB
_ptr = ops.init_custom_qr(rank, world_size, qr_max_size)
ranks = []
for i in range(world_size):
ranks.append(i)
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=cpu_group)
ops.qr_open_handles(_ptr, handles)
num = 1
s1 = 1024
while num < 50000: # 50000 is sufficient to identify issues.
dtype = torch.float16
if num % 2 == 0:
s2 = 1024
inp1 = torch.zeros(
(s1, s2), dtype=dtype, device=torch.cuda.current_device()
)
else:
s2 = 2048
inp1 = torch.ones((s1, s2), dtype=dtype, device=torch.cuda.current_device())
result = torch.empty_like(inp1)
# FP = 0 INT8 = 1 INT6 = 2 INT4 = 3 NONE = 4
ops.qr_all_reduce(_ptr, inp1, result, 3, cast_bf2half=True)
try:
if inp1[0, 0] == 0:
assert torch.all(result == 0)
else:
assert torch.all(result == world_size)
except AssertionError:
print("Assertion failed! Allreduce results are incorrect.")
raise
num += 1
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
@pytest.mark.parametrize("tp_size", [4, 8])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
def test_custom_quick_allreduce_variable_input(tp_size, pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multiprocessing.set_start_method("spawn", force=True)
# 60s is enough
timeout = 60
processes = []
for rank in range(tp_size):
p = multiprocessing.Process(target=qr_variable_input, args=(rank, tp_size))
p.start()
processes.append((rank, p))
for rank, p in processes:
p.join(timeout=timeout)
if p.is_alive():
for r, proc in processes:
if proc.is_alive():
proc.terminate()
proc.join()
raise RuntimeError(f"QuickReduce hang detected after {timeout} seconds!")
if __name__ == "__main__":
test_custom_quick_allreduce_variable_input(tp_size=4, pipeline_parallel_size=1)