mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
38 lines
1.1 KiB
Python
Executable File
38 lines
1.1 KiB
Python
Executable File
'''Copyright The Microsoft DeepSpeed Team'''
|
|
|
|
import random
|
|
|
|
from .base_tuner import BaseTuner
|
|
|
|
|
|
class RandomTuner(BaseTuner):
    """Explore the search space in random order."""

    def __init__(self, exps: list, resource_manager, metric):
        """Forward all arguments unchanged to the ``BaseTuner`` initializer."""
        super().__init__(exps, resource_manager, metric)

    def next_batch(self, sample_size=1):
        """Draw up to ``sample_size`` experiments at random, without replacement.

        The drawn experiments are removed from ``self.all_exps`` so later
        calls do not return them again.

        Args:
            sample_size: Requested number of experiments. Capped at the
                number of experiments still held in ``self.all_exps``.

        Returns:
            A list with the drawn experiments; empty once ``self.all_exps``
            is exhausted.

        Raises:
            ValueError: Propagated from ``random.sample`` when
                ``sample_size`` is negative.
        """
        pending = self.all_exps
        sample_size = min(sample_size, len(pending))

        # Draw positions, not values: removing by equality would also discard
        # unsampled experiments that merely compare equal to a sampled one,
        # and costs a linear scan of the batch per remaining experiment.
        picked = random.sample(range(len(pending)), sample_size)
        taken = set(picked)

        # Bind a fresh list rather than mutating the existing one in place.
        self.all_exps = [exp for pos, exp in enumerate(pending) if pos not in taken]

        return [pending[pos] for pos in picked]
|
|
|
|
|
|
class GridSearchTuner(BaseTuner):
    """Explore the search space in sequential order."""

    def __init__(self, exps: list, resource_manager, metric):
        """Forward all arguments unchanged to the ``BaseTuner`` initializer."""
        super().__init__(exps, resource_manager, metric)

    def next_batch(self, sample_size=1):
        """Take the next ``sample_size`` experiments from the front of the queue.

        The returned experiments are removed from ``self.all_exps`` so later
        calls continue with the ones that follow.

        Args:
            sample_size: Requested number of experiments. Capped at the
                number of experiments still held in ``self.all_exps``.
                NOTE(review): negative values are not rejected and follow
                Python's negative-slice semantics; callers presumably pass
                non-negative sizes -- confirm.

        Returns:
            A list with the leading experiments in their stored order; empty
            once ``self.all_exps`` is exhausted.
        """
        pending = self.all_exps
        sample_size = min(sample_size, len(pending))

        sampled_batch = pending[:sample_size]

        # Keep exactly the complementary slice. Filtering by equality against
        # the batch would also drop later experiments that compare equal to a
        # returned one, and rescans the batch for every remaining experiment.
        self.all_exps = pending[sample_size:]

        return sampled_batch
|