mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58003 — adds trainer class DdpTrainer; adds trainer class DdpSparseRpcTrainer; adds server class ParameterServerBase; adds server class AverageParameterServer; adds experiment ddp_cpu_sparse_rpc_nccl_allreduce; adds experiment ddp_cuda_sparse_rpc_nccl_allreduce. Quip document: https://fb.quip.com/iQUtAeKIxWpF — Test Plan: imported from OSS. Reviewed By: albanD. Differential Revision: D29379696. Pulled By: gcramer23. fbshipit-source-id: 9cf5fb7398ba2fa3eb694afbddc4ed00d97f205f
41 lines
813 B
Python
41 lines
813 B
Python
from data.DummyData import DummyData
|
|
from models.DummyModel import DummyModel
|
|
from servers.AverageParameterServer import AverageParameterServer
|
|
from trainers.DdpNcclTrainer import DdpNcclTrainer
|
|
from trainers.DdpSparseRpcTrainer import DdpSparseRpcTrainer
|
|
from trainers.DdpTrainer import DdpTrainer
|
|
|
|
# Registry of benchmark trainers: maps each trainer class name (as a string)
# to the trainer class it refers to. Exposed via get_benchmark_trainer_map().
trainer_map = dict(
    DdpNcclTrainer=DdpNcclTrainer,
    DdpTrainer=DdpTrainer,
    DdpSparseRpcTrainer=DdpSparseRpcTrainer,
)

# Registry of benchmark parameter servers: maps each server class name (as a
# string) to the server class. Exposed via get_benchmark_server_map().
server_map = dict(
    AverageParameterServer=AverageParameterServer,
)

# Registry of benchmark models: maps each model class name (as a string) to
# the model class. Exposed via get_benchmark_model_map().
model_map = dict(
    DummyModel=DummyModel,
)

# Registry of benchmark data sources: maps each data class name (as a string)
# to the data class. Exposed via get_benchmark_data_map().
data_map = dict(
    DummyData=DummyData,
)


def get_benchmark_trainer_map():
    """Return the registry that maps trainer class names to trainer classes.

    The module-level ``trainer_map`` dict is returned as-is (no copy), so any
    mutation by the caller is visible to every other user of the registry.
    """
    registry = trainer_map
    return registry


def get_benchmark_server_map():
    """Return the registry that maps server class names to server classes.

    The module-level ``server_map`` dict is returned as-is (no copy), so any
    mutation by the caller is visible to every other user of the registry.
    """
    registry = server_map
    return registry


def get_benchmark_model_map():
    """Return the registry that maps model class names to model classes.

    The module-level ``model_map`` dict is returned as-is (no copy), so any
    mutation by the caller is visible to every other user of the registry.
    """
    registry = model_map
    return registry


def get_benchmark_data_map():
    """Return the registry that maps data class names to data classes.

    The module-level ``data_map`` dict is returned as-is (no copy), so any
    mutation by the caller is visible to every other user of the registry.
    """
    registry = data_map
    return registry