Fix benchmark_moe.py tuning for CUDA devices (#14164)

This commit is contained in:
Michael Goin
2025-03-04 00:11:03 -05:00
committed by GitHub
parent 66233af7b6
commit f78c0be80a

View File

@@ -2,6 +2,7 @@
import argparse
import time
+from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict
@@ -412,7 +413,8 @@ class BenchmarkWorker:
hidden_size, search_space,
is_fp16, topk)
-with torch.cuda.device(self.device_id):
+with torch.cuda.device(self.device_id) if current_platform.is_rocm(
+) else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(