mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Fix benchmark_moe.py tuning for CUDA devices (#14164)
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, TypedDict
|
||||
@ -412,7 +413,8 @@ class BenchmarkWorker:
|
||||
hidden_size, search_space,
|
||||
is_fp16, topk)
|
||||
|
||||
with torch.cuda.device(self.device_id):
|
||||
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
|
||||
) else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
|
Reference in New Issue
Block a user