ps sparse rpc (#58003)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58003

Adds trainer class DdpTrainer
Adds trainer class DdpSparseRpcTrainer
Adds server class ParameterServerBase
Adds server class AverageParameterServer
Adds experiment ddp_cpu_sparse_rpc_nccl_allreduce
Adds experiment ddp_cuda_sparse_rpc_nccl_allreduce

Quip document: https://fb.quip.com/iQUtAeKIxWpF

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D29379696

Pulled By: gcramer23

fbshipit-source-id: 9cf5fb7398ba2fa3eb694afbddc4ed00d97f205f
Author: Garrett Cramer
Date: 2021-06-24 17:20:33 -07:00
Committed by: Facebook GitHub Bot
Parent: fadaa52f64
Commit: 4ed2d5d9bb
17 changed files with 1051 additions and 202 deletions
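The rewritten launcher drops the BenchmarkConfigurations object and derives everything from flat command-line counts, with ranks laid out as CPU trainers, then CUDA trainers, then CPU servers, then CUDA servers, then one master. A minimal runnable sketch of that rank layout, reconstructed from get_name and run_benchmark in the diff below (the counts are illustrative, not from the PR):

# illustrative counts, mirroring the rank layout in get_name below
ntrainer, ncudatrainer, nserver, ncudaserver = 4, 2, 2, 1
world_size = ntrainer + ncudatrainer + nserver + ncudaserver + 1  # +1 for the master

def name_for(rank):
    t_count = ntrainer + ncudatrainer
    s_count = nserver + ncudaserver
    if rank < t_count:
        return f"trainer{rank}"   # ranks 0-5; CUDA trainers are ranks 4 and 5
    elif rank < t_count + s_count:
        return f"server{rank}"    # ranks 6-8; the CUDA server is rank 8
    return "master"               # rank 9

print([name_for(r) for r in range(world_size)])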


@@ -4,72 +4,100 @@ import json
 import os
 from pathlib import Path

 import torch
+import torch.distributed as c10d
 import torch.distributed.rpc as rpc
 import torch.multiprocessing as mp
 from torch.distributed.rpc import TensorPipeRpcBackendOptions
+from torch.futures import wait_all
 from torch.utils.data import DataLoader

 from benchmark_class_helper import (get_benchmark_data_map,
                                     get_benchmark_model_map,
-                                    get_benchmark_ps_map,
+                                    get_benchmark_server_map,
                                     get_benchmark_trainer_map)
-from BenchmarkConfigurations import BenchmarkConfigurations
 from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter

-USE_CUDA_RPC = "use_cuda_rpc"

-def get_name(rank, configs):
-    t_count = configs.trainer_count
-    ps_count = configs.ps_count
+def get_name(rank, args):
+    t_count = args.ntrainer + args.ncudatrainer
+    s_count = args.nserver + args.ncudaserver
     if rank < t_count:
         return f"trainer{rank}"
-    elif rank < (t_count + ps_count):
-        return f"ps{rank}"
+    elif rank < (t_count + s_count):
+        return f"server{rank}"
     else:
         return "master"

-def get_parameter_server_rank(rank, config):
-    # rank mod parameter server count to get parameter server number
-    # add trainer_count to get parameter server rank
-    rank_mod_ps_count = rank % config.ps_count
-    return rank_mod_ps_count + config.trainer_count
+def get_server_rank(args, rank):
+    s_offset = args.ntrainer + args.ncudatrainer
+    tps = args.ntrainer // args.nserver
+    return rank // tps + s_offset

-def get_ps_rref(parameter_server_rank, config):
-    ps_config = config.ps_config
-    ps = get_benchmark_ps_map()[str(ps_config["ps_class"])]
+def get_cuda_server_rank(args, rank):
+    s_offset = args.ntrainer + args.ncudatrainer + args.nserver
+    t_index = rank - args.ntrainer
+    ctps = args.ncudatrainer // args.ncudaserver
+    return t_index // ctps + s_offset
+
+def get_server_rref(server_rank, args, extra_args):
+    server = get_benchmark_server_map()[str(args.server)]
     name = get_name(
-        parameter_server_rank,
-        config
+        server_rank,
+        args
     )
-    ps_args = ps_config["configurations"].values()
-    ps_trainer_count = config.trainer_count / ps_config.ps_count
-    rem = config.trainer_count % ps_config.ps_count
-    if parameter_server_rank - config.trainer_count < rem:
-        ps_trainer_count += 1
+    if extra_args is not None:
+        server_args = extra_args.values()
+    else:
+        server_args = []
+    if server_rank >= args.ntrainer + args.ncudatrainer + args.nserver:
+        trainer_count = args.ncudatrainer / args.ncudaserver
+        use_cuda_rpc = True
+    else:
+        trainer_count = args.ntrainer / args.nserver
+        use_cuda_rpc = False
     return rpc.remote(
         name,
-        ps,
+        server,
         args=(
-            parameter_server_rank,
-            ps_trainer_count,
-            *ps_args,
+            server_rank,
+            trainer_count,
+            use_cuda_rpc,
+            *server_args,
         ),
     )

 def run_trainer(
-    config, model, data, rank, ps_rref
+    args, extra_args, model, data, rank, server_rref
 ):
-    trainer_config = config.trainer_config
-    trainer_class = get_benchmark_trainer_map()[str(trainer_config["trainer_class"])]
-    trainer_args = trainer_config["configurations"].values()
+    trainer_class = get_benchmark_trainer_map()[str(args.trainer)]
+    if extra_args is not None:
+        trainer_args = extra_args.values()
+    else:
+        trainer_args = []
+    trainer_count = args.ntrainer + args.ncudatrainer
+    store = c10d.FileStore(args.filestore, trainer_count)
+    if args.backend == "gloo":
+        process_group = c10d.ProcessGroupGloo(
+            store, rank, trainer_count
+        )
+    elif args.backend == "nccl":
+        process_group = c10d.ProcessGroupNCCL(
+            store, rank, trainer_count
+        )
+    use_cuda_rpc = rank >= args.ntrainer
     trainer = trainer_class(
         rank,
-        config.trainer_count,
-        ps_rref,
+        args.ntrainer + args.ncudatrainer,
+        process_group,
+        use_cuda_rpc,
+        server_rref,
+        args.backend,
+        args.epochs,
         *trainer_args
     )
     trainer.train(model, data)
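The two rank helpers added above assign trainers to servers in contiguous blocks: each CPU server handles ntrainer // nserver consecutive CPU trainers, and each CUDA server handles ncudatrainer // ncudaserver consecutive CUDA trainers, offset past the trainer ranks. A small runnable sketch of that arithmetic (counts are illustrative, not from the PR):

# mirrors get_server_rank / get_cuda_server_rank from the hunk above
ntrainer, ncudatrainer, nserver, ncudaserver = 4, 2, 2, 1

def server_rank(rank):
    s_offset = ntrainer + ncudatrainer   # servers start at rank 6
    tps = ntrainer // nserver            # 2 CPU trainers per CPU server
    return rank // tps + s_offset

def cuda_server_rank(rank):
    s_offset = ntrainer + ncudatrainer + nserver  # CUDA servers start at rank 8
    ctps = ncudatrainer // ncudaserver            # 2 CUDA trainers per CUDA server
    return (rank - ntrainer) // ctps + s_offset

print([server_rank(r) for r in range(ntrainer)])                     # [6, 6, 7, 7]
print([cuda_server_rank(r) for r in range(ntrainer, ntrainer + 2)])  # [8, 8]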
@@ -77,48 +105,44 @@ def run_trainer(
     return [rank, metrics]

-def call_trainers(config, model, train_data, parameter_server_rrefs):
+def call_trainers(args, extra_args, model, train_data, server_rrefs):
     futs = []
-    for trainer_rank in range(0, config.trainer_count):
+    for trainer_rank in range(0, args.ntrainer + args.ncudatrainer):
         trainer_name = get_name(
             trainer_rank,
-            config
+            args
         )
-        ps_rref = None
-        if parameter_server_rrefs:
-            ps_rank = get_parameter_server_rank(trainer_rank, config)
-            ps_rref = parameter_server_rrefs[ps_rank]
+        server_rref = None
+        if server_rrefs:
+            if trainer_rank >= args.ntrainer:
+                server_rank = get_cuda_server_rank(args, trainer_rank)
+            else:
+                server_rank = get_server_rank(args, trainer_rank)
+            server_rref = server_rrefs[server_rank]
         fut = rpc.rpc_async(
             trainer_name,
             run_trainer,
             args=(
-                config,
+                args,
+                extra_args,
                 copy.deepcopy(model),
                 train_data[trainer_rank],
                 trainer_rank,
-                ps_rref,
+                server_rref,
             ),
-            timeout=config.rpc_async_timeout
+            timeout=args.rpc_timeout
         )
         futs.append(fut)
     return futs

 def benchmark_warmup(
-    config, model, data, parameter_server_rrefs
+    args, extra_args, model, data, server_rrefs
 ):
-    if config.ps_count > 0:
-        ps_config = config.ps_config
-        ps = get_benchmark_ps_map()[str(ps_config["ps_class"])]
-    futs = call_trainers(config, model, data, parameter_server_rrefs)
-    for fut in futs:
-        fut.wait()
-    for ps_rref in parameter_server_rrefs.values():
-        rpc.rpc_sync(
-            ps_rref.owner(),
-            ps.reset_state,
-            args=(ps_rref,)
-        )
+    futs = call_trainers(args, extra_args, model, data, server_rrefs)
+    wait_all(futs)
+    for server_rref in server_rrefs.values():
+        server_rref.rpc_sync().reset_state(server_rref)
     print("benchmark warmup done\n")
@@ -126,84 +150,88 @@ def split_list(arr, n):
     return [arr[i::n] for i in range(n)]

-def run_master(rank, model, data, config, rpc_backend_options):
-    world_size = config.trainer_count + config.ps_count + 1
+def get_server_metrics(server_rrefs):
+    rank_metrics = []
+    for rank, server_rref in server_rrefs.items():
+        metrics = server_rref.rpc_sync().get_metrics(server_rref)
+        rank_metrics.append([rank, metrics])
+    return rank_metrics
+
+def run_master(rank, model, data, args, extra_configs, rpc_backend_options):
+    world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
     rpc.init_rpc(
         get_name(
             rank,
-            config
+            args
         ),
         rank=rank,
         world_size=world_size,
         rpc_backend_options=rpc_backend_options
     )
-    parameter_server_rrefs = {}
+    server_rrefs = {}
     for i in range(
-        config.trainer_count, world_size - 1
+        args.ntrainer + args.ncudatrainer, world_size - 1
     ):
-        parameter_server_rrefs[i] = get_ps_rref(i, config)
+        server_rrefs[i] = get_server_rref(i, args, extra_configs["server_config"])

     train_data = split_list(
-        list(DataLoader(data, batch_size=config.batch_size)),
-        config.trainer_count
+        list(DataLoader(data, batch_size=args.batch_size)),
+        args.ntrainer + args.ncudatrainer
     )
     # warmup run the benchmark
     benchmark_warmup(
-        config, model, train_data, parameter_server_rrefs
+        args, extra_configs["trainer_config"], model, train_data, server_rrefs
     )
     # run the benchmark
     trainer_futs = call_trainers(
-        config, model, train_data, parameter_server_rrefs
+        args, extra_configs["trainer_config"], model, train_data, server_rrefs
     )
     # collect metrics and print
     metrics_printer = ProcessedMetricsPrinter()
-    rank_metrics_list = [fut.wait() for fut in trainer_futs]
+    rank_metrics_list = wait_all(trainer_futs)
     metrics_printer.print_metrics("trainer", rank_metrics_list)
+    rank_metrics_list = get_server_metrics(server_rrefs)
+    metrics_printer.print_metrics("parameter server", rank_metrics_list)

-def run_benchmark(rank, model, data, config):
-    world_size = config.trainer_count + config.ps_count + 1
-    os.environ['MASTER_ADDR'] = config.master_addr
-    os.environ['MASTER_PORT'] = config.master_port
-    rpc_backend_options = TensorPipeRpcBackendOptions()
-    rpc_backend_options.init_method = config.rpc_init_method
+def run_benchmark(rank, model, data, args, config):
+    torch.manual_seed(args.torch_seed)
+    torch.cuda.manual_seed_all(args.cuda_seed)
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.deterministic = True
+
+    world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
+    os.environ['MASTER_ADDR'] = args.master_addr
+    os.environ['MASTER_PORT'] = args.master_port
+    rpc_backend_options = TensorPipeRpcBackendOptions(rpc_timeout=args.rpc_timeout)
     if rank == world_size - 1:
-        # master = [trainer_count + parameter_server_count, trainer_count + parameter_server_count]
-        run_master(rank, model, data, config, rpc_backend_options)
-    elif rank >= config.trainer_count:
-        # parameter_servers = [trainer_count, trainer_count + parameter_server_count)
+        # master = [ntrainer + ncudatrainer + nserver + ncudaserver, ntrainer + ncudatrainer + nserver + ncudaserver]
+        run_master(rank, model, data, args, config, rpc_backend_options)
+    elif rank >= args.ntrainer + args.ncudatrainer:
+        # parameter_servers = [ntrainer + ncudatrainer, ntrainer + ncudatrainer + nserver + ncudaserver)
         rpc.init_rpc(
             get_name(
                 rank,
-                config
+                args
             ),
             rank=rank,
             world_size=world_size,
             rpc_backend_options=rpc_backend_options
         )
     else:
-        # trainers = [0, trainer_count)
-        trainer_config = config.trainer_config
-        ps_config = config.ps_config
-        if (USE_CUDA_RPC in trainer_config and
-                trainer_config[USE_CUDA_RPC] and
-                USE_CUDA_RPC in ps_config and
-                ps_config[USE_CUDA_RPC] and
-                config.ps_count > 0):
-            ps_rank = get_parameter_server_rank(rank, config)
-            ps_name = get_name(
-                ps_rank,
-                config
-            )
+        # trainers = [0, ntrainer + ncudatrainer)
+        if rank >= args.ntrainer:
+            server_rank = get_cuda_server_rank(args, rank)
+            server_name = get_name(server_rank, args)
             rpc_backend_options.set_device_map(
-                ps_name,
-                {rank: ps_rank}
+                server_name,
+                {rank: server_rank}
             )
         trainer_name = get_name(
             rank,
-            config
+            args
         )
         rpc.init_rpc(
             trainer_name,
@@ -221,16 +249,18 @@ def get_json_config(file_name, id):
     return json_config

-def load_configurations(args):
+def load_extra_configs(args):
     trainer_config_file = args.trainer_config_path
-    ps_config_file = args.server_config_path
-    benchmark_config = get_json_config(args.benchmark_config_path, args.benchmark)
-    benchmark_config["trainer_config"] = get_json_config(trainer_config_file, args.trainer)
-    if args.server != "None":
-        benchmark_config["ps_config"] = get_json_config(ps_config_file, args.server)
-    else:
-        benchmark_config["ps_config"] = None
-    return BenchmarkConfigurations(**benchmark_config)
+    server_config_file = args.server_config_path
+    configurations = {
+        "trainer_config": None,
+        "server_config": None
+    }
+    if args.trainer is not None and trainer_config_file is not None:
+        configurations["trainer_config"] = get_json_config(trainer_config_file, args.trainer)
+    if args.server is not None and server_config_file is not None:
+        configurations["server_config"] = get_json_config(server_config_file, args.server)
+    return configurations

 def get_data(data_class, data_config):
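load_extra_configs above only loads a JSON entry when both the map key (--trainer / --server) and the matching config path are given, and run_trainer/get_server_rref later splat the entry's values, in order, onto the trainer or server constructor (*trainer_args / *server_args). A sketch of the expected shape, assuming a hypothetical entry (the key and field names are invented for illustration and must match the class's trailing constructor parameters):

import json

# hypothetical server configuration entry; "delay" is an invented field
config_text = '{"AverageParameterServer": {"delay": 0.1}}'
server_config = json.loads(config_text)["AverageParameterServer"]
server_args = list(server_config.values())  # becomes *server_args in get_server_rref
print(server_args)  # [0.1]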
@@ -255,43 +285,106 @@ def load_model(args):
     return get_model(model_config["model_class"], model_config["configurations"])

-def main():
-    parser = argparse.ArgumentParser(description="RPC PS Benchmark")
+def main(args):
+    # CPU and CUDA trainer checks
+    if args.ntrainer > 0 and args.ncudatrainer > 0:
+        assert args.nserver > 0 and args.ncudaserver > 0
+    if args.nserver > 0:
+        assert args.ntrainer > 0
+        assert args.ntrainer % args.nserver == 0
+    if args.ncudaserver > 0:
+        assert args.ncudatrainer > 0
+        assert args.ncudatrainer % args.ncudaserver == 0
+
+    extra_configs = load_extra_configs(args)
+    data = load_data(args)
+    model = load_model(args)
+    world_size = (
+        args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
+    )
+    mp.spawn(
+        run_benchmark,
+        args=(
+            model,
+            data,
+            args,
+            extra_configs,
+        ),
+        nprocs=world_size,
+        join=True
+    )

+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="RPC server Benchmark")
     parser.add_argument(
-        "--benchmark_config_path",
+        "--master_addr",
         type=str,
-        default="configurations/benchmark_configurations.json",
-        help="path to benchmark configuration file"
+        help="IP address of the machine that will host the process with rank 0"
     )
     parser.add_argument(
-        "--data_config_path",
+        "--master_port",
         type=str,
-        default="configurations/data_configurations.json",
-        help="path to data configuration file"
+        help="a free port on the machine that will host the process with rank 0"
     )
     parser.add_argument(
-        "--model_config_path",
+        "--trainer",
         type=str,
-        default="configurations/model_configurations.json",
-        help="path to model configuration file"
+        help="trainer map key to get trainer class for benchmark run"
     )
     parser.add_argument(
-        "--server_config_path",
-        type=str,
-        default="configurations/server_configurations.json",
-        help="path to server configuration file"
+        "--ntrainer",
+        type=int,
+        help="trainer count for benchmark run"
     )
     parser.add_argument(
-        "--trainer_config_path",
-        type=str,
-        default="configurations/trainer_configurations.json",
-        help="path to trainer configuration file"
+        "--ncudatrainer",
+        type=int,
+        help="cudatrainer count for benchmark run"
     )
     parser.add_argument(
-        "--benchmark",
+        "--filestore",
         type=str,
-        help="id for benchmark configuration"
+        help="filestore location for process group"
     )
+    parser.add_argument(
+        "--server",
+        type=str,
+        help="server map key to get server class for benchmark run"
+    )
+    parser.add_argument(
+        "--nserver",
+        type=int,
+        help="server count for benchmark run"
+    )
+    parser.add_argument(
+        "--ncudaserver",
+        type=int,
+        help="cudaserver count for benchmark run"
+    )
+    parser.add_argument(
+        "--rpc_timeout",
+        type=int,
+        help="timeout in seconds to use for RPC"
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        help="distributed communication backend to use for benchmark run"
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        help="epoch count for training"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="number of training examples used in one iteration"
+    )
     parser.add_argument(
         "--data",
@@ -304,35 +397,37 @@ def main():
         help="id for model configuration"
     )
     parser.add_argument(
-        "--server",
+        "--data_config_path",
         type=str,
-        help="id for parameter server configuration"
+        help="path to data configuration file"
     )
     parser.add_argument(
-        "--trainer",
+        "--model_config_path",
         type=str,
-        help="id for trainer configuration"
+        help="path to model configuration file"
     )
+    parser.add_argument(
+        "--server_config_path",
+        type=str,
+        help="path to server configuration file"
+    )
+    parser.add_argument(
+        "--trainer_config_path",
+        type=str,
+        help="path to trainer configuration file"
+    )
+    parser.add_argument(
+        "--torch_seed",
+        type=int,
+        default=0,
+        help="seed for generating random numbers"
+    )
+    parser.add_argument(
+        "--cuda_seed",
+        type=int,
+        default=0,
+        help="seed for generating random numbers for the current GPU"
+    )
     args = parser.parse_args()
     print(f"{args}\n")

-    config = load_configurations(args)
-    data = load_data(args)
-    model = load_model(args)
-    world_size = config.trainer_count + config.ps_count + 1
-    mp.spawn(
-        run_benchmark,
-        args=(
-            model,
-            data,
-            config,
-        ),
-        nprocs=world_size,
-        join=True
-    )
-
-if __name__ == "__main__":
-    main()
+    main(args)
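As a usage sketch, the rewritten entry point is driven entirely by flags. Something like the following would spawn 4 CPU trainers, 1 CPU server, and the master process (the script name, map keys, and --data/--model ids are assumptions for illustration; the flags themselves come from the argparse section above):

import subprocess

# hypothetical invocation; script name and map-key values are assumed
subprocess.run([
    "python", "launcher.py",
    "--master_addr", "127.0.0.1",
    "--master_port", "29500",
    "--ntrainer", "4", "--ncudatrainer", "0",
    "--nserver", "1", "--ncudaserver", "0",
    "--trainer", "DdpSparseRpcTrainer",     # trainer map key (assumed)
    "--server", "AverageParameterServer",   # server map key (assumed)
    "--backend", "gloo",
    "--filestore", "/tmp/filestore",
    "--rpc_timeout", "30",
    "--epochs", "10",
    "--batch_size", "4",
    "--data", "DummyData",                  # data map id (assumed)
    "--model", "DummyModel",                # model map id (assumed)
    "--data_config_path", "configurations/data_configurations.json",
    "--model_config_path", "configurations/model_configurations.json",
], check=True)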