Files
pytorch/benchmarks/distributed/rpc/parameter_server/benchmark_class_helper.py
Garrett Cramer 4ed2d5d9bb ps sparse rpc (#58003)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58003

adds trainer class DdpTrainer
adds trainer class DdpSparseRpcTrainer
adds server class ParameterServerBase
adds server class AverageParameterServer
adds experiment ddp_cpu_sparse_rpc_nccl_allreduce
adds experiment ddp_cuda_sparse_rpc_nccl_allreduce

quip document https://fb.quip.com/iQUtAeKIxWpF

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29379696

Pulled By: gcramer23

fbshipit-source-id: 9cf5fb7398ba2fa3eb694afbddc4ed00d97f205f
2021-06-24 17:21:49 -07:00

41 lines
813 B
Python

from data.DummyData import DummyData
from models.DummyModel import DummyModel
from servers.AverageParameterServer import AverageParameterServer
from trainers.DdpNcclTrainer import DdpNcclTrainer
from trainers.DdpSparseRpcTrainer import DdpSparseRpcTrainer
from trainers.DdpTrainer import DdpTrainer
# Name -> class registries used by the benchmark helpers below.
# Each key is the class name spelled as a string; each value is the class
# object imported at the top of this module. Lookups are by exact key.
# NOTE(review): presumably the keys match names supplied in benchmark
# configuration files — confirm against the callers of the getters.

# Trainer implementations.
trainer_map = {
    "DdpNcclTrainer": DdpNcclTrainer,
    "DdpTrainer": DdpTrainer,
    "DdpSparseRpcTrainer": DdpSparseRpcTrainer,
}

# Parameter-server implementations.
server_map = {
    "AverageParameterServer": AverageParameterServer,
}

# Model implementations.
model_map = {
    "DummyModel": DummyModel,
}

# Data-source implementations.
data_map = {
    "DummyData": DummyData,
}
def get_benchmark_trainer_map():
    """Return the registry mapping trainer names to trainer classes.

    The module-level ``trainer_map`` dict itself is returned, not a copy,
    so any mutation made by a caller is visible to every other caller.
    """
    registry = trainer_map
    return registry
def get_benchmark_server_map():
    """Return the registry mapping server names to parameter-server classes.

    The module-level ``server_map`` dict itself is returned, not a copy,
    so any mutation made by a caller is visible to every other caller.
    """
    registry = server_map
    return registry
def get_benchmark_model_map():
    """Return the registry mapping model names to model classes.

    The module-level ``model_map`` dict itself is returned, not a copy,
    so any mutation made by a caller is visible to every other caller.
    """
    registry = model_map
    return registry
def get_benchmark_data_map():
    """Return the registry mapping data-source names to data classes.

    The module-level ``data_map`` dict itself is returned, not a copy,
    so any mutation made by a caller is visible to every other caller.
    """
    registry = data_map
    return registry