mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
509 lines
17 KiB
Python
509 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
"""
|
|
Benchmark script for device communicators:
|
|
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
|
and SymmMemCommunicator (multimem, two-shot).
|
|
|
|
For NCCL symmetric memory, you need to set the environment variables
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise NCCL does
not use the fast NVLS implementation for all-reduce.
|
|
|
|
Usage:
|
|
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
|
|
|
Example:
|
|
torchrun --nproc_per_node=2 benchmark_device_communicators.py
|
|
--sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
from collections.abc import Callable
|
|
from contextlib import nullcontext
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
|
from vllm.distributed.device_communicators.pynccl import (
|
|
PyNcclCommunicator,
|
|
register_nccl_symmetric_ops,
|
|
)
|
|
from vllm.distributed.device_communicators.pynccl_allocator import (
|
|
set_graph_pool_id,
|
|
)
|
|
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import FlexibleArgumentParser
|
|
|
|
# Module-level logger (created via vLLM's init_logger); used below for
# per-rank status, warning, and error messages.
logger = init_logger(__name__)
|
|
|
|
# Sequence lengths swept when --sequence-lengths is not given on the command line.
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]

# Every benchmarked tensor has shape (seq_len, HIDDEN_SIZE) and this dtype.
HIDDEN_SIZE = 8_192
BENCHMARK_DTYPE = torch.bfloat16

# How many allreduce calls are recorded into a single CUDA graph; measured
# replay times are divided by this to report per-operation latency.
CUDA_GRAPH_CAPTURE_CYCLES = 10
|
|
|
|
|
|
class CommunicatorBenchmark:
|
|
"""Benchmark class for testing device communicators."""
|
|
|
|
def __init__(
|
|
self,
|
|
rank: int,
|
|
world_size: int,
|
|
device: torch.device,
|
|
cpu_group: ProcessGroup,
|
|
sequence_lengths: list[int],
|
|
):
|
|
self.rank = rank
|
|
self.world_size = world_size
|
|
self.device = device
|
|
self.cpu_group = cpu_group
|
|
|
|
# Calculate max_size_override based on largest sequence length
|
|
max_seq_len = max(sequence_lengths)
|
|
max_tensor_elements = max_seq_len * HIDDEN_SIZE
|
|
self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
|
|
|
|
# Initialize communicators
|
|
self.custom_allreduce = None
|
|
self.pynccl_comm = None
|
|
self.symm_mem_comm = None
|
|
self.symm_mem_comm_multimem = None
|
|
self.symm_mem_comm_two_shot = None
|
|
|
|
self._init_communicators()
|
|
|
|
def _init_communicators(self):
|
|
"""Initialize all available communicators."""
|
|
try:
|
|
self.custom_allreduce = CustomAllreduce(
|
|
group=self.cpu_group,
|
|
device=self.device,
|
|
max_size=self.max_size_override,
|
|
)
|
|
if not self.custom_allreduce.disabled:
|
|
logger.info("Rank %s: CustomAllreduce initialized", self.rank)
|
|
else:
|
|
logger.info("Rank %s: CustomAllreduce disabled", self.rank)
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
|
|
)
|
|
self.custom_allreduce = None
|
|
|
|
try:
|
|
self.pynccl_comm = PyNcclCommunicator(
|
|
group=self.cpu_group, device=self.device
|
|
)
|
|
if not self.pynccl_comm.disabled:
|
|
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
|
|
register_nccl_symmetric_ops(self.pynccl_comm)
|
|
else:
|
|
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
|
|
self.pynccl_comm = None
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
|
|
)
|
|
self.pynccl_comm = None
|
|
|
|
# Initialize variants for SymmMemCommunicator
|
|
try:
|
|
self.symm_mem_comm_multimem = SymmMemCommunicator(
|
|
group=self.cpu_group,
|
|
device=self.device,
|
|
force_multimem=True,
|
|
max_size_override=self.max_size_override,
|
|
)
|
|
if not self.symm_mem_comm_multimem.disabled:
|
|
logger.info(
|
|
"Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
|
|
)
|
|
else:
|
|
self.symm_mem_comm_multimem = None
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
|
|
self.rank,
|
|
e,
|
|
)
|
|
self.symm_mem_comm_multimem = None
|
|
|
|
try:
|
|
self.symm_mem_comm_two_shot = SymmMemCommunicator(
|
|
group=self.cpu_group,
|
|
device=self.device,
|
|
force_multimem=False,
|
|
max_size_override=self.max_size_override,
|
|
)
|
|
if not self.symm_mem_comm_two_shot.disabled:
|
|
logger.info(
|
|
"Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
|
|
)
|
|
else:
|
|
self.symm_mem_comm_two_shot = None
|
|
except Exception as e:
|
|
logger.warning(
|
|
"Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
|
|
self.rank,
|
|
e,
|
|
)
|
|
self.symm_mem_comm_two_shot = None
|
|
|
|
def benchmark_allreduce(
|
|
self, sequence_length: int, num_warmup: int, num_trials: int
|
|
) -> dict[str, float]:
|
|
"""Benchmark allreduce operations for all available communicators."""
|
|
|
|
results = {}
|
|
|
|
# Define communicators with their benchmark functions
|
|
communicators = []
|
|
|
|
if self.custom_allreduce is not None:
|
|
comm = self.custom_allreduce
|
|
# CustomAllreduce one-shot
|
|
communicators.append(
|
|
(
|
|
"ca_1stage",
|
|
lambda t, c=comm: c.custom_all_reduce(t),
|
|
lambda t, c=comm: c.should_custom_ar(t),
|
|
comm.capture(),
|
|
"1stage", # env variable value
|
|
)
|
|
)
|
|
# CustomAllreduce two-shot
|
|
communicators.append(
|
|
(
|
|
"ca_2stage",
|
|
lambda t, c=comm: c.custom_all_reduce(t),
|
|
lambda t, c=comm: c.should_custom_ar(t),
|
|
comm.capture(),
|
|
"2stage", # env variable value
|
|
)
|
|
)
|
|
|
|
if self.pynccl_comm is not None:
|
|
comm = self.pynccl_comm
|
|
communicators.append(
|
|
(
|
|
"pynccl",
|
|
lambda t, c=comm: c.all_reduce(t),
|
|
lambda t: True, # Always available if initialized
|
|
nullcontext(),
|
|
None, # no env variable needed
|
|
)
|
|
)
|
|
communicators.append(
|
|
(
|
|
"pynccl-symm",
|
|
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
|
|
lambda t: True, # Always available if initialized
|
|
nullcontext(),
|
|
None, # no env variable needed
|
|
)
|
|
)
|
|
|
|
if self.symm_mem_comm_multimem is not None:
|
|
comm = self.symm_mem_comm_multimem
|
|
communicators.append(
|
|
(
|
|
"symm_mem_multimem",
|
|
lambda t, c=comm: c.all_reduce(t),
|
|
lambda t, c=comm: c.should_use_symm_mem(t),
|
|
nullcontext(),
|
|
None, # no env variable needed
|
|
)
|
|
)
|
|
|
|
if self.symm_mem_comm_two_shot is not None:
|
|
comm = self.symm_mem_comm_two_shot
|
|
communicators.append(
|
|
(
|
|
"symm_mem_two_shot",
|
|
lambda t, c=comm: c.all_reduce(t),
|
|
lambda t, c=comm: c.should_use_symm_mem(t),
|
|
nullcontext(),
|
|
None, # no env variable needed
|
|
)
|
|
)
|
|
|
|
# Benchmark each communicator
|
|
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
|
|
# Set environment variable if needed
|
|
if env_var is not None:
|
|
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
|
|
else:
|
|
# Clear the environment variable to avoid interference
|
|
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
|
|
|
|
latency = self.benchmark_allreduce_single(
|
|
sequence_length,
|
|
allreduce_fn,
|
|
should_use_fn,
|
|
context,
|
|
num_warmup,
|
|
num_trials,
|
|
)
|
|
if latency is not None:
|
|
results[name] = latency
|
|
|
|
return results
|
|
|
|
def benchmark_allreduce_single(
|
|
self,
|
|
sequence_length: int,
|
|
allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
|
|
should_use_fn: Callable[[torch.Tensor], bool],
|
|
context,
|
|
num_warmup: int,
|
|
num_trials: int,
|
|
) -> float | None:
|
|
"""Benchmark method with CUDA graph optimization."""
|
|
try:
|
|
# Create test tensor (2D: sequence_length x hidden_size)
|
|
tensor = torch.randn(
|
|
sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
|
|
)
|
|
if not should_use_fn(tensor):
|
|
return None
|
|
|
|
torch.cuda.synchronize()
|
|
stream = torch.cuda.Stream()
|
|
with torch.cuda.stream(stream):
|
|
graph_input = tensor.clone()
|
|
|
|
# Warmup before capture
|
|
for _ in range(3):
|
|
allreduce_fn(graph_input)
|
|
|
|
# Capture the graph using context manager
|
|
with context:
|
|
graph = torch.cuda.CUDAGraph()
|
|
graph_pool = torch.cuda.graph_pool_handle()
|
|
set_graph_pool_id(graph_pool)
|
|
with torch.cuda.graph(graph, pool=graph_pool):
|
|
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
|
allreduce_fn(graph_input)
|
|
|
|
torch.cuda.synchronize()
|
|
for _ in range(num_warmup):
|
|
graph.replay()
|
|
torch.cuda.synchronize()
|
|
|
|
torch.cuda.synchronize()
|
|
start_time = time.perf_counter()
|
|
|
|
for _ in range(num_trials):
|
|
graph.replay()
|
|
torch.cuda.synchronize()
|
|
|
|
end_time = time.perf_counter()
|
|
|
|
# Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
|
|
return (
|
|
(end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
|
|
)
|
|
|
|
except Exception as e:
|
|
logger.error("CUDA graph benchmark failed: %s", e)
|
|
raise RuntimeError(
|
|
f"CUDA graph benchmark failed for communicator: {e}"
|
|
) from e
|
|
|
|
|
|
def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
|
|
"""Calculate speedup information for a single tensor size."""
|
|
if not comm_results:
|
|
return "N/A"
|
|
|
|
# Find the fastest communicator
|
|
fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
|
|
fastest_time = comm_results[fastest_comm]
|
|
|
|
# Calculate speedup vs PyNccl if available
|
|
if "pynccl" in comm_results:
|
|
pynccl_time = comm_results["pynccl"]
|
|
speedup = pynccl_time / fastest_time
|
|
return f"{fastest_comm} ({speedup:.2f}x)"
|
|
else:
|
|
return f"{fastest_comm} (N/A)"
|
|
|
|
|
|
def print_results(
    results: dict[int, dict[str, float]], sequence_lengths: list[int], world_size: int
):
    """Print benchmark results as a fixed-width table on stdout.

    Args:
        results: Mapping of sequence length to {communicator name: latency in
            ms per allreduce}.
        sequence_lengths: Row order; lengths absent from ``results`` are
            skipped and duplicates are printed once.
        world_size: Number of ranks, shown in the banner.
    """
    # Union of communicator names across all sizes, sorted for stable columns.
    all_comms = sorted(
        {comm for size_results in results.values() for comm in size_results}
    )

    header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
    header += "".join(f"{comm:<20}" for comm in all_comms)
    header += f"{'Best (Speedup vs PyNccl)':<30}"

    # Banner rules are at least as wide as the table so they frame it.
    rule = "=" * max(130, len(header))

    print(f"\n{rule}")
    print("Device Communicator Benchmark Results")
    print(
        f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
        f"Hidden Size: {HIDDEN_SIZE}"
    )
    print(rule)
    print(header)
    print("-" * len(header))

    # dict.fromkeys de-duplicates while preserving the requested order.
    for seq_len in dict.fromkeys(sequence_lengths):
        if seq_len not in results:
            continue
        comm_results = results[seq_len]

        tensor_bytes = seq_len * HIDDEN_SIZE * BENCHMARK_DTYPE.itemsize
        tensor_size_str = f"{tensor_bytes / (1024 * 1024):.2f} MB"
        tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"

        row = f"{tensor_shape:<20}{tensor_size_str:<15}"
        for comm in all_comms:
            if comm in comm_results:
                row += f"{comm_results[comm]:<20.3f}"
            else:
                row += f"{'N/A':<20}"
        row += f"{_calculate_speedup_info(comm_results):<30}"
        print(row)

    print(rule)
    print("All times are in milliseconds (ms) per allreduce operation")
    print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
|
|
|
|
|
|
def main():
    """Run the device-communicator benchmark on this rank.

    Parses CLI options and validates them. Initializes torch.distributed
    (gloo) if it is not already initialized, then binds this process to its
    node-local GPU. Benchmarks every requested sequence length. On rank 0,
    prints the results table and optionally writes JSON. Intended to be
    launched with torchrun.
    """
    parser = FlexibleArgumentParser(description="Benchmark device communicators")

    parser.add_argument(
        "--sequence-lengths",
        type=int,
        nargs="+",
        default=DEFAULT_SEQUENCE_LENGTHS,
        help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
    )

    parser.add_argument(
        "--num-warmup", type=int, default=5, help="Number of warmup iterations"
    )

    parser.add_argument(
        "--num-trials", type=int, default=50, help="Number of benchmark trials"
    )

    parser.add_argument("--output-json", type=str, help="Output results to JSON file")

    args = parser.parse_args()

    # Fail fast with a usage message instead of a ZeroDivisionError or an
    # invalid tensor shape deep inside the benchmark.
    if args.num_trials < 1:
        parser.error("--num-trials must be at least 1")
    if args.num_warmup < 0:
        parser.error("--num-warmup must be non-negative")
    if any(seq_len <= 0 for seq_len in args.sequence_lengths):
        parser.error("--sequence-lengths must all be positive")

    # Initialize distributed; remember whether we created the default group so
    # we only tear down what this function owns.
    owns_default_group = not dist.is_initialized()
    if owns_default_group:
        dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # torchrun exports LOCAL_RANK (this process's index on its node). Indexing
    # GPUs by the global rank would fail on multi-node launches; fall back to
    # it only when LOCAL_RANK is absent.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Separate gloo group handed to the communicators
    cpu_group = dist.new_group(backend="gloo")

    # Disable USE_SYMM_MEM to avoid affecting the max_sizes
    # in symm_mem and custom_all_reduce for benchmark
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    benchmark = CommunicatorBenchmark(
        rank, world_size, device, cpu_group, args.sequence_lengths
    )

    all_results = {}
    for seq_len in args.sequence_lengths:
        if rank == 0:
            logger.info(
                "Benchmarking sequence length: %s (tensor shape: %s x %s)",
                seq_len,
                seq_len,
                HIDDEN_SIZE,
            )

        all_results[seq_len] = benchmark.benchmark_allreduce(
            sequence_length=seq_len,
            num_warmup=args.num_warmup,
            num_trials=args.num_trials,
        )

        # Keep ranks in lockstep between sequence lengths
        dist.barrier()

    # Print results (only rank 0)
    if rank == 0:
        print_results(all_results, args.sequence_lengths, world_size)

        if args.output_json:
            # Attach the speedup summary next to the raw timings
            enhanced_results = {
                seq_len: {
                    "timings": comm_results,
                    "speedup_info": _calculate_speedup_info(comm_results),
                }
                for seq_len, comm_results in all_results.items()
            }

            output_data = {
                "world_size": world_size,
                "dtype": str(BENCHMARK_DTYPE),
                "hidden_size": HIDDEN_SIZE,
                "sequence_lengths": args.sequence_lengths,
                "num_warmup": args.num_warmup,
                "num_trials": args.num_trials,
                "cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
                "results": enhanced_results,
            }

            with open(args.output_json, "w") as f:
                json.dump(output_data, f, indent=2)

            logger.info("Results saved to %s", args.output_json)

    # Cleanup: the sub-group first, then the default group if we created it.
    if cpu_group != dist.group.WORLD:
        dist.destroy_process_group(cpu_group)
    if owns_default_group:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|