refactor ps benchmark (#60784)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60784
This PR refactors the ps benchmark for modular trainers.

Test Plan: Imported from OSS
Reviewed By: zou3519
Differential Revision: D29697291
Pulled By: gcramer23
fbshipit-source-id: 64579a1f5326d3cd9f32936dcf53bc243d54b71d

Commit 304c02ee44 (parent 7d2ea9a8f7), committed by Facebook GitHub Bot.
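To make the diff below easier to follow: the refactor replaces the single benchmark_class_helper module with small per-package registries (data_map, model_map, server_map, trainer_map, criterion_map, ddp_model_map, hook_state_map, iteration_step_map, preprocess_data_map), and the launcher resolves every pluggable component from the name passed on the command line. The following is a minimal illustrative sketch of that lookup pattern, not code from the commit; it assumes the benchmark package is on the import path.

    # illustrative sketch of the new lookup pattern (names taken from the diff)
    from trainer import trainer_map, criterion_map, preprocess_data_map

    def resolve_components(args):
        # e.g. --trainer="DdpTrainer", --create_criterion="cel",
        #      --preprocess_data="preprocess_dummy_data"
        trainer_class = trainer_map[args.trainer]
        create_criterion = criterion_map[args.create_criterion]
        preprocess_data = preprocess_data_map[args.preprocess_data]
        return trainer_class, create_criterion, preprocess_data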
@ -1,40 +0,0 @@
from data.DummyData import DummyData
from models.DummyModel import DummyModel
from servers.AverageParameterServer import AverageParameterServer
from trainers.DdpNcclTrainer import DdpNcclTrainer
from trainers.DdpSparseRpcTrainer import DdpSparseRpcTrainer
from trainers.DdpTrainer import DdpTrainer

trainer_map = {
    "DdpNcclTrainer": DdpNcclTrainer,
    "DdpTrainer": DdpTrainer,
    "DdpSparseRpcTrainer": DdpSparseRpcTrainer
}

server_map = {
    "AverageParameterServer": AverageParameterServer
}

model_map = {
    "DummyModel": DummyModel
}

data_map = {
    "DummyData": DummyData
}


def get_benchmark_trainer_map():
    return trainer_map


def get_benchmark_server_map():
    return server_map


def get_benchmark_model_map():
    return model_map


def get_benchmark_data_map():
    return data_map
@ -2,19 +2,10 @@
    "DummyData": {
        "data_class": "DummyData",
        "configurations": {
            "max_val": 100,
            "input_samples": 100,
            "input_dim": 100,
            "max_val": 1024,
            "sample_count": 1024,
            "sample_length": 1024,
            "sparsity_percentage": 20
        }
    },
    "DummyData2": {
        "data_class": "DummyData",
        "configurations": {
            "max_val": 100,
            "input_samples": 100,
            "input_dim": 100,
            "sparsity_percentage": 80
        }
    }
}
@ -2,20 +2,22 @@
    "DummyModel": {
        "model_class": "DummyModel",
        "configurations": {
            "num_embeddings": 100,
            "embedding_dim": 100,
            "dense_input_size": 100,
            "dense_output_size": 100,
            "num_embeddings": 1024,
            "embedding_dim": 1024,
            "dense_input_size": 1024,
            "dense_output_size": 1024,
            "dense_layers_count": 8,
            "sparse": false
        }
    },
    "DummyModelSparse": {
        "model_class": "DummyModel",
        "configurations": {
            "num_embeddings": 100,
            "embedding_dim": 100,
            "dense_input_size": 100,
            "dense_output_size": 100,
            "num_embeddings": 1024,
            "embedding_dim": 1024,
            "dense_input_size": 1024,
            "dense_output_size": 1024,
            "dense_layers_count": 8,
            "sparse": true
        }
    }
@ -1,6 +1,7 @@
import random

import numpy as np

import torch
from torch.utils.data import Dataset

@ -10,13 +11,22 @@ class DummyData(Dataset):
    def __init__(
        self,
        max_val: int,
        input_samples: int,
        input_dim: int,
        sample_count: int,
        sample_length: int,
        sparsity_percentage: int
    ):
        r"""
        A data class that generates random data.
        Args:
            max_val (int): the maximum value for an element
            sample_count (int): count of training samples
            sample_length (int): number of elements in a sample
            sparsity_percentage (int): the percentage of
                embeddings used by the input data in each iteration
        """
        self.max_val = max_val
        self.input_samples = input_samples
        self.input_dim = input_dim
        self.input_samples = sample_count
        self.input_dim = sample_length
        self.sparsity_percentage = sparsity_percentage

        def generate_input():
@ -35,9 +45,7 @@ class DummyData(Dataset):
            return torch.from_numpy(np.array(data))

        self.input = generate_input()
        self.target = torch.randint(0, max_val, [input_samples])
        self.start = 0
        self.end = max_val
        self.target = torch.randint(0, max_val, [sample_count])

    def __len__(self):
        return len(self.input)
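As a hedged usage sketch (not part of the commit), the new constructor signature lines up with the updated "DummyData" entry in data_configurations.json; the keyword values below are copied from that file, and the map lookup mirrors what load_data() in launcher.py does.

    from data import data_map

    dummy_data = data_map["DummyData"](
        max_val=1024,
        sample_count=1024,
        sample_length=1024,
        sparsity_percentage=20,
    )
    print(len(dummy_data))  # sample_count items, each a tensor of sample_length elements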
@ -0,0 +1,5 @@
from .DummyData import DummyData

data_map = {
    "DummyData": DummyData
}
@ -1,23 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
cd ..
|
||||
|
||||
python -u launcher.py \
|
||||
--master_addr="localhost" \
|
||||
--master_port="29500" \
|
||||
--trainer="DdpSparseRpcTrainer" \
|
||||
--ntrainer=2 \
|
||||
--ncudatrainer=0 \
|
||||
--filestore="/tmp/tmpn_k_8so02" \
|
||||
--server="AverageParameterServer" \
|
||||
--nserver=1 \
|
||||
--ncudaserver=0 \
|
||||
--rpc_timeout=30 \
|
||||
--backend="nccl" \
|
||||
--epochs=10 \
|
||||
--batch_size=10 \
|
||||
--data="DummyData" \
|
||||
--model="DummyModelSparse" \
|
||||
--data_config_path="configurations/data_configurations.json" \
|
||||
--model_config_path="configurations/model_configurations.json"
|
@ -1,23 +0,0 @@
|
||||
#!/bin/sh
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
cd ..
|
||||
|
||||
python -u launcher.py \
|
||||
--master_addr="localhost" \
|
||||
--master_port="29500" \
|
||||
--trainer="DdpSparseRpcTrainer" \
|
||||
--ntrainer=0 \
|
||||
--ncudatrainer=2 \
|
||||
--filestore="/tmp/tmpn_k_8so02" \
|
||||
--server="AverageParameterServer" \
|
||||
--nserver=0 \
|
||||
--ncudaserver=1 \
|
||||
--rpc_timeout=30 \
|
||||
--backend="nccl" \
|
||||
--epochs=10 \
|
||||
--batch_size=10 \
|
||||
--data="DummyData" \
|
||||
--model="DummyModelSparse" \
|
||||
--data_config_path="configurations/data_configurations.json" \
|
||||
--model_config_path="configurations/model_configurations.json"
|
@ -1,22 +0,0 @@
#!/bin/sh

cd "$(dirname "$0")"
cd ..

python -u launcher.py \
  --master_addr="localhost" \
  --master_port="29500" \
  --trainer="DdpTrainer" \
  --ntrainer=2 \
  --ncudatrainer=0 \
  --filestore="/tmp/tmpn_k_8so02" \
  --nserver=0 \
  --ncudaserver=0 \
  --rpc_timeout=30 \
  --backend="nccl" \
  --epochs=10 \
  --batch_size=10 \
  --data="DummyData" \
  --model="DummyModel" \
  --data_config_path="configurations/data_configurations.json" \
  --model_config_path="configurations/model_configurations.json"
@ -1,9 +1,22 @@
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from data import data_map
|
||||
from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter
|
||||
from models import model_map
|
||||
from server import server_map
|
||||
from trainer import (
|
||||
criterion_map,
|
||||
ddp_hook_map,
|
||||
ddp_model_map,
|
||||
hook_state_map,
|
||||
iteration_step_map,
|
||||
preprocess_data_map,
|
||||
trainer_map,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.distributed.rpc as rpc
|
||||
@ -12,14 +25,15 @@ from torch.distributed.rpc import TensorPipeRpcBackendOptions
|
||||
from torch.futures import wait_all
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from benchmark_class_helper import (get_benchmark_data_map,
|
||||
get_benchmark_model_map,
|
||||
get_benchmark_server_map,
|
||||
get_benchmark_trainer_map)
|
||||
from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter
|
||||
|
||||
|
||||
def get_name(rank, args):
|
||||
r"""
|
||||
A function that gets the name for the rank
|
||||
argument
|
||||
Args:
|
||||
rank (int): process number in the world
|
||||
args (parser): benchmark configurations
|
||||
"""
|
||||
t_count = args.ntrainer + args.ncudatrainer
|
||||
s_count = args.nserver + args.ncudaserver
|
||||
if rank < t_count:
|
||||
@ -31,12 +45,26 @@ def get_name(rank, args):
|
||||
|
||||
|
||||
def get_server_rank(args, rank):
|
||||
r"""
|
||||
A function that gets the server rank for
|
||||
the rank argument.
|
||||
Args:
|
||||
args (parser): benchmark configurations
|
||||
rank (int): trainer rank
|
||||
"""
|
||||
s_offset = args.ntrainer + args.ncudatrainer
|
||||
tps = args.ntrainer // args.nserver
|
||||
return rank // tps + s_offset
|
||||
|
||||
|
||||
def get_cuda_server_rank(args, rank):
|
||||
r"""
|
||||
A function that gets the cudaserver rank for
|
||||
the rank argument.
|
||||
Args:
|
||||
args (parser): benchmark configurations
|
||||
rank (int): trainer rank
|
||||
"""
|
||||
s_offset = args.ntrainer + args.ncudatrainer + args.nserver
|
||||
t_index = rank - args.ntrainer
|
||||
ctps = args.ncudatrainer // args.ncudaserver
|
||||
@ -44,7 +72,14 @@ def get_cuda_server_rank(args, rank):
|
||||
|
||||
|
||||
def get_server_rref(server_rank, args, extra_args):
|
||||
server = get_benchmark_server_map()[str(args.server)]
|
||||
r"""
|
||||
A function that creates a RRef to the server.
|
||||
Args:
|
||||
server_rank (int): process number in the world
|
||||
args (parser): benchmark configurations
|
||||
extra_args (dict): configurations added by the user
|
||||
"""
|
||||
server = server_map[args.server]
|
||||
name = get_name(
|
||||
server_rank,
|
||||
args
|
||||
@ -72,9 +107,19 @@ def get_server_rref(server_rank, args, extra_args):
|
||||
|
||||
|
||||
def run_trainer(
|
||||
args, extra_args, model, data, rank, server_rref
|
||||
args, extra_args, data, rank, server_rref
|
||||
):
|
||||
trainer_class = get_benchmark_trainer_map()[str(args.trainer)]
|
||||
r"""
|
||||
A function that obtains a trainer instance and calls
|
||||
the train method.
|
||||
Args:
|
||||
args (parser): benchmark configurations
|
||||
extra_args (dict): configurations added by the user
|
||||
data (list): training samples
|
||||
rank (int): process number in the world
|
||||
server_rrefs (dict): a dictionary containing server RRefs
|
||||
"""
|
||||
trainer_class = trainer_map[args.trainer]
|
||||
if extra_args is not None:
|
||||
trainer_args = extra_args.values()
|
||||
else:
|
||||
@ -89,15 +134,34 @@ def run_trainer(
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store, rank, trainer_count
|
||||
)
|
||||
elif args.backend == "multi":
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store, rank, trainer_count
|
||||
)
|
||||
if c10d.is_initialized() is False:
|
||||
c10d.init_process_group(backend="gloo", rank=rank, world_size=trainer_count)
|
||||
|
||||
model = load_model(args)
|
||||
preprocess_data = preprocess_data_map[args.preprocess_data]
|
||||
create_criterion = criterion_map[args.create_criterion]
|
||||
create_ddp_model = ddp_model_map[args.create_ddp_model]
|
||||
iteration_step = iteration_step_map[args.iteration_step]
|
||||
hook_state_class = hook_state_map[args.hook_state]
|
||||
hook = ddp_hook_map[args.ddp_hook]
|
||||
# check if this a cudatrainer
|
||||
use_cuda_rpc = rank >= args.ntrainer
|
||||
trainer = trainer_class(
|
||||
rank,
|
||||
args.ntrainer + args.ncudatrainer,
|
||||
process_group,
|
||||
use_cuda_rpc,
|
||||
server_rref,
|
||||
args.backend,
|
||||
args.epochs,
|
||||
preprocess_data,
|
||||
create_criterion,
|
||||
create_ddp_model,
|
||||
hook_state_class,
|
||||
hook,
|
||||
iteration_step,
|
||||
*trainer_args
|
||||
)
|
||||
trainer.train(model, data)
|
||||
@ -105,7 +169,16 @@ def run_trainer(
|
||||
return [rank, metrics]
|
||||
|
||||
|
||||
def call_trainers(args, extra_args, model, train_data, server_rrefs):
|
||||
def call_trainers(args, extra_args, train_data, server_rrefs):
|
||||
r"""
|
||||
A function that starts the trainers. Each trainer is started
|
||||
using an rpc_async request.
|
||||
Args:
|
||||
args (parser): benchmark configurations
|
||||
extra_args (dict): configurations added by the user
|
||||
train_data (list): training samples
|
||||
server_rrefs (dict): a dictionary containing server RRefs
|
||||
"""
|
||||
futs = []
|
||||
for trainer_rank in range(0, args.ntrainer + args.ncudatrainer):
|
||||
trainer_name = get_name(
|
||||
@ -125,7 +198,6 @@ def call_trainers(args, extra_args, model, train_data, server_rrefs):
|
||||
args=(
|
||||
args,
|
||||
extra_args,
|
||||
copy.deepcopy(model),
|
||||
train_data[trainer_rank],
|
||||
trainer_rank,
|
||||
server_rref,
|
||||
@ -137,9 +209,18 @@ def call_trainers(args, extra_args, model, train_data, server_rrefs):
|
||||
|
||||
|
||||
def benchmark_warmup(
|
||||
args, extra_args, model, data, server_rrefs
|
||||
args, extra_args, data, server_rrefs
|
||||
):
|
||||
futs = call_trainers(args, extra_args, model, data, server_rrefs)
|
||||
r"""
|
||||
A function that runs the training algorithm. The goal of this
|
||||
function is to warm up the RPC. The server states are reset afterwards.
|
||||
Args:
|
||||
args (parser): benchmark configurations
|
||||
extra_args (dict): configurations added by the user
|
||||
data (list): training samples
|
||||
server_rrefs (dict): a dictionary containing server RRefs
|
||||
"""
|
||||
futs = call_trainers(args, extra_args, data, server_rrefs)
|
||||
wait_all(futs)
|
||||
for server_rref in server_rrefs.values():
|
||||
server_rref.rpc_sync().reset_state(server_rref)
|
||||
@ -147,10 +228,22 @@ def benchmark_warmup(
|
||||
|
||||
|
||||
def split_list(arr, n):
|
||||
r"""
|
||||
A function that splits a list into n lists
|
||||
Args:
|
||||
arr (list): training samples
|
||||
n (int): number of output lists
|
||||
"""
|
||||
return [arr[i::n] for i in range(n)]
|
||||
|
||||
|
||||
def get_server_metrics(server_rrefs):
|
||||
r"""
|
||||
A function that calls the remote server to obtain metrics
|
||||
collected during the benchmark run.
|
||||
Args:
|
||||
server_rrefs (dict): a dictionary containing server RRefs
|
||||
"""
|
||||
rank_metrics = []
|
||||
for rank, server_rref in server_rrefs.items():
|
||||
metrics = server_rref.rpc_sync().get_metrics(server_rref)
|
||||
@ -158,7 +251,18 @@ def get_server_metrics(server_rrefs):
|
||||
return rank_metrics
|
||||
|
||||
|
||||
def run_master(rank, model, data, args, extra_configs, rpc_backend_options):
|
||||
def run_master(rank, data, args, extra_configs, rpc_backend_options):
|
||||
r"""
|
||||
A function that runs the master process in the world. This function
|
||||
obtains remote references to initialized servers, splits the data,
|
||||
runs the trainers, and prints metrics.
|
||||
Args:
|
||||
rank (int): process number in the world
|
||||
data (list): training samples
|
||||
args (parser): benchmark configurations
|
||||
extra_configs (dict): configurations added by the user
|
||||
rpc_backend_options (rpc): configurations/options for the rpc TODO: fix
|
||||
"""
|
||||
world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
|
||||
rpc.init_rpc(
|
||||
get_name(
|
||||
@ -181,21 +285,30 @@ def run_master(rank, model, data, args, extra_configs, rpc_backend_options):
|
||||
|
||||
# warmup run the benchmark
|
||||
benchmark_warmup(
|
||||
args, extra_configs["trainer_config"], model, train_data, server_rrefs
|
||||
args, extra_configs["trainer_config"], train_data, server_rrefs
|
||||
)
|
||||
# run the benchmark
|
||||
trainer_futs = call_trainers(
|
||||
args, extra_configs["trainer_config"], model, train_data, server_rrefs
|
||||
args, extra_configs["trainer_config"], train_data, server_rrefs
|
||||
)
|
||||
# collect metrics and print
|
||||
metrics_printer = ProcessedMetricsPrinter()
|
||||
rank_metrics_list = wait_all(trainer_futs)
|
||||
metrics_printer.print_metrics("trainer", rank_metrics_list)
|
||||
rank_metrics_list = get_server_metrics(server_rrefs)
|
||||
metrics_printer.print_metrics("parameter server", rank_metrics_list)
|
||||
metrics_printer.print_metrics("server", rank_metrics_list)
|
||||
|
||||
|
||||
def run_benchmark(rank, model, data, args, config):
|
||||
def run_benchmark(rank, args, data):
|
||||
r"""
|
||||
A function that runs the benchmark.
|
||||
Args:
|
||||
rank (int): process number in the world
|
||||
args (parser): configuration args
|
||||
data (list): training samples
|
||||
"""
|
||||
|
||||
config = load_extra_configs(args)
|
||||
|
||||
torch.manual_seed(args.torch_seed)
|
||||
torch.cuda.manual_seed_all(args.cuda_seed)
|
||||
@ -208,7 +321,7 @@ def run_benchmark(rank, model, data, args, config):
|
||||
rpc_backend_options = TensorPipeRpcBackendOptions(rpc_timeout=args.rpc_timeout)
|
||||
if rank == world_size - 1:
|
||||
# master = [ntrainer + ncudatrainer + nserver + ncudaserver, ntrainer + ncudatrainer + nserver + ncudaserver]
|
||||
run_master(rank, model, data, args, config, rpc_backend_options)
|
||||
run_master(rank, data, args, config, rpc_backend_options)
|
||||
elif rank >= args.ntrainer + args.ncudatrainer:
|
||||
# parameter_servers = [ntrainer + ncudatrainer, ntrainer + ncudatrainer + nserver + ncudaserver)
|
||||
rpc.init_rpc(
|
||||
@ -243,13 +356,25 @@ def run_benchmark(rank, model, data, args, config):
|
||||
|
||||
|
||||
def get_json_config(file_name, id):
|
||||
r"""
|
||||
A function that loads a json configuration from a file.
|
||||
Args:
|
||||
file_name (str): name of configuration file to load
|
||||
id (str): configuration that will be loaded
|
||||
"""
|
||||
with open(os.path.join(Path(__file__).parent, file_name), "r") as f:
|
||||
json_config = json.load(f)[id]
|
||||
|
||||
return json_config
|
||||
|
||||
|
||||
def load_extra_configs(args):
|
||||
r"""
|
||||
A function that creates a dictionary that contains any extra configurations
|
||||
set by the user. The dictionary will contain two keys trainer_config and
|
||||
server_config, with default values None.
|
||||
Args:
|
||||
args (parser): launcher configurations
|
||||
"""
|
||||
trainer_config_file = args.trainer_config_path
|
||||
server_config_file = args.server_config_path
|
||||
configurations = {
|
||||
@ -263,30 +388,36 @@ def load_extra_configs(args):
|
||||
return configurations
|
||||
|
||||
|
||||
def get_data(data_class, data_config):
|
||||
data_class = get_benchmark_data_map()[data_class]
|
||||
return data_class(**data_config)
|
||||
|
||||
|
||||
def load_data(args):
|
||||
r"""
|
||||
A function that creates an instance of the data class.
|
||||
Args:
|
||||
args (parser): launcher configurations
|
||||
"""
|
||||
data_config_file = args.data_config_path
|
||||
data_config = get_json_config(data_config_file, args.data)
|
||||
return get_data(data_config["data_class"], data_config["configurations"])
|
||||
|
||||
|
||||
def get_model(model_class, model_config):
|
||||
model_class = get_benchmark_model_map()[model_class]
|
||||
return model_class(**model_config)
|
||||
data_class = data_map[data_config["data_class"]]
|
||||
return data_class(**data_config["configurations"])
|
||||
|
||||
|
||||
def load_model(args):
|
||||
r"""
|
||||
A function that creates an instance of the model class.
|
||||
Args:
|
||||
args (parser): launcher configurations
|
||||
"""
|
||||
model_config_file = args.model_config_path
|
||||
model_config = get_json_config(model_config_file, args.model)
|
||||
return get_model(model_config["model_class"], model_config["configurations"])
|
||||
model_class = model_map[model_config["model_class"]]
|
||||
return model_class(**model_config["configurations"])
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
r"""
|
||||
A function that creates multiple processes to run the benchmark.
|
||||
Args:
|
||||
args (parser): launcher configurations
|
||||
"""
|
||||
# CPU and RPC trainer checks
|
||||
if args.ntrainer > 0 and args.ncudatrainer > 0:
|
||||
assert args.nserver > 0 and args.ncudaserver > 0
|
||||
@ -297,21 +428,17 @@ def main(args):
|
||||
assert args.ncudatrainer > 0
|
||||
assert args.ncudatrainer % args.ncudaserver == 0
|
||||
|
||||
extra_configs = load_extra_configs(args)
|
||||
data = load_data(args)
|
||||
model = load_model(args)
|
||||
|
||||
world_size = (
|
||||
args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
|
||||
)
|
||||
|
||||
data = load_data(args)
|
||||
|
||||
mp.spawn(
|
||||
run_benchmark,
|
||||
args=(
|
||||
model,
|
||||
data,
|
||||
args,
|
||||
extra_configs,
|
||||
data,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True
|
||||
@ -383,7 +510,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of training examples used in one iteration"
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -419,15 +545,44 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--torch_seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="seed passed to torch.manual_seed for random number generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cuda_seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="seed passed to torch.cuda.manual_seed_all for CUDA random number generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocess_data",
|
||||
type=str,
|
||||
help="this function will be used to preprocess data before training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--create_criterion",
|
||||
type=str,
|
||||
help="this function will be used to create the criterion used for model loss calculation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--create_ddp_model",
|
||||
type=str,
|
||||
help="this function will be used to create the ddp model used during training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hook_state",
|
||||
type=str,
|
||||
help="this will be the state class used when registering the ddp communication hook"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_hook",
|
||||
type=str,
|
||||
default="allreduce_hook",
|
||||
help="ddp communication hook"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iteration_step",
|
||||
type=str,
|
||||
help="this will be the function called for each iteration of training"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(f"{args}\n")
|
||||
main(args)
|
||||
|
@ -9,13 +9,24 @@ class DummyModel(nn.Module):
        embedding_dim: int,
        dense_input_size: int,
        dense_output_size: int,
        dense_layers_count: int,
        sparse: bool
    ):
        r"""
        A dummy model with an EmbeddingBag Layer and Dense Layer.
        Args:
            num_embeddings (int): size of the dictionary of embeddings
            embedding_dim (int): the size of each embedding vector
            dense_input_size (int): size of each input sample
            dense_output_size (int): size of each output sample
            dense_layers_count (int): number of dense layers in dense Sequential module
            sparse (bool): if True, gradient w.r.t. weight matrix will be a sparse tensor
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(
            num_embeddings, embedding_dim, sparse=sparse
        )
        self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(10)])
        self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(dense_layers_count)])

    def forward(self, x):
        x = self.embedding(x)
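A hedged construction sketch (not part of the commit): the "DummyModelSparse" entry in model_configurations.json differs from "DummyModel" only by "sparse": true, which makes the EmbeddingBag produce sparse gradients; the keyword values below are copied from that file, and the lookup mirrors load_model() in launcher.py.

    from models import model_map

    sparse_model = model_map["DummyModel"](
        num_embeddings=1024,
        embedding_dim=1024,
        dense_input_size=1024,
        dense_output_size=1024,
        dense_layers_count=8,
        sparse=True,
    )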
@ -0,0 +1,5 @@
from .DummyModel import DummyModel

model_map = {
    "DummyModel": DummyModel
}
@ -0,0 +1,6 @@
from .server import AverageBatchParameterServer, AverageParameterServer

server_map = {
    "AverageParameterServer": AverageParameterServer,
    "AverageBatchParameterServer": AverageBatchParameterServer
}
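A hedged sketch (not from the commit) of how the launcher's get_server_rref() uses this map: the server class is looked up by the --server flag and constructed on a remote worker through RPC. It assumes rpc.init_rpc() has already been called; the worker name and argument values are examples only.

    import torch.distributed.rpc as rpc
    from server import server_map

    rank, trainer_count, use_cuda_rpc = 2, 2, False  # example values only
    server_class = server_map["AverageParameterServer"]
    server_rref = rpc.remote(
        "ps0",  # assumed worker name; the launcher derives the real one via get_name()
        server_class,
        args=(rank, trainer_count, use_cuda_rpc),
    )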
benchmarks/distributed/rpc/parameter_server/server/server.py (new file, 363 lines)
@ -0,0 +1,363 @@
|
||||
import functools
|
||||
import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from metrics.MetricsLogger import MetricsLogger
|
||||
from utils import sparse_rpc_format_to_tensor, sparse_tensor_to_rpc_format
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
|
||||
|
||||
class ParameterServerBase(ABC):
|
||||
|
||||
PARAMETER_SERVER_BATCH_METRIC = "parameter_server_batch_metric"
|
||||
PARAMETER_SERVER_STRAGGLER_METRIC = "parameter_server_straggler_metric"
|
||||
PARAM_INDEX_STRAGGLER = "param_index_straggler"
|
||||
PARAM_INDEX_BATCH = "param_index_batch"
|
||||
|
||||
def __init__(self, rank):
|
||||
r"""
|
||||
Inits ParameterServerBase class.
|
||||
Args:
|
||||
rank (int): worker rank
|
||||
"""
|
||||
self.__metrics_logger = MetricsLogger(rank)
|
||||
|
||||
@abstractmethod
|
||||
def process_gradient(self):
|
||||
r"""
|
||||
A method to be implemented by child class that will process a
|
||||
gradient received by a server.
|
||||
"""
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def average_gradient():
|
||||
r"""
|
||||
A method to be implemented by child class that will average
|
||||
gradients.
|
||||
"""
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def reset_state():
|
||||
r"""
|
||||
A method to be implemented by child class that will reset
|
||||
the server state.
|
||||
"""
|
||||
return
|
||||
|
||||
def record_start(self, type, key, name, cuda=True):
|
||||
r"""
|
||||
A method that records the start event for a metric.
|
||||
Args:
|
||||
type (str): group id for metric
|
||||
key (str): unique id for metric within a group
|
||||
name (str): description of the metric
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
type,
|
||||
key,
|
||||
name,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_end(self, type, key):
|
||||
r"""
|
||||
A method that records the end event for a metric
|
||||
Args:
|
||||
type (str): group id for metric
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
type,
|
||||
key
|
||||
)
|
||||
|
||||
def record_straggler_start(self, key, cuda=True):
|
||||
r"""
|
||||
A helper method that records a straggler metric
|
||||
for the given key. A user should call this when
|
||||
the first gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
self.PARAMETER_SERVER_STRAGGLER_METRIC,
|
||||
key,
|
||||
self.PARAM_INDEX_STRAGGLER,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_straggler_end(self, key):
|
||||
r"""
|
||||
A helper method that records a straggler metric
|
||||
for the given key. A user should call this when
|
||||
the last gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
self.PARAMETER_SERVER_STRAGGLER_METRIC,
|
||||
key
|
||||
)
|
||||
|
||||
def record_batch_start(self, key, cuda=True):
|
||||
r"""
|
||||
A helper method that records a batch metric
|
||||
for the given key. A user should call this when
|
||||
the first gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
self.PARAMETER_SERVER_BATCH_METRIC,
|
||||
key,
|
||||
self.PARAM_INDEX_BATCH,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_batch_end(self, key):
|
||||
r"""
|
||||
A helper method that records a batch metric
|
||||
for the given key. A user should call this when
|
||||
all futures for a param location have had their
|
||||
result set.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
self.PARAMETER_SERVER_BATCH_METRIC,
|
||||
key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def record_method(name, type="method_metric", cuda=True):
|
||||
r"""
|
||||
A decorator that records a metric for the decorated method.
|
||||
Args:
|
||||
name (str): description of the metric
|
||||
type (str): group id for metric
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
def decorator(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, *args):
|
||||
key = time.time()
|
||||
self.__metrics_logger.record_start(type, key, name, cuda)
|
||||
result = function(self, *args)
|
||||
self.__metrics_logger.record_end(type, key)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def get_metrics(server_rref):
|
||||
r"""
|
||||
A staticmethod that returns metrics captured by the __metrics_logger.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
return self.__metrics_logger.get_processed_metrics()
|
||||
|
||||
def clear_metrics(self):
|
||||
r"""
|
||||
A method that clears __metrics_logger recorded metrics.
|
||||
"""
|
||||
return self.__metrics_logger.clear_metrics()
|
||||
|
||||
|
||||
class AverageParameterServer(ParameterServerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank,
|
||||
trainer_count,
|
||||
use_cuda_rpc
|
||||
):
|
||||
r"""
|
||||
A parameter server that averages the gradients
|
||||
from trainers for each training iteration step.
|
||||
Gradients are added as they are received from trainers.
|
||||
When all gradients have been received, the sum is
|
||||
divided by the number of trainers.
|
||||
Args:
|
||||
rank (int): worker rank
|
||||
trainer_count (int): count of trainers sending
|
||||
gradients to the server
|
||||
use_cuda_rpc (bool): indicator for CUDA RPC
|
||||
"""
|
||||
super().__init__(rank)
|
||||
|
||||
self.lock = threading.Lock()
|
||||
self.rank = rank
|
||||
self.trainer_count = trainer_count
|
||||
self.use_cuda_rpc = use_cuda_rpc
|
||||
|
||||
self.batch_number = 0
|
||||
self.futures = {}
|
||||
self.gradient_dict = {}
|
||||
|
||||
@staticmethod
|
||||
def reset_state(server_rref):
|
||||
r"""
|
||||
A method that clears the state of the server.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
self.batch_number = 0
|
||||
self.futures.clear()
|
||||
self.gradient_dict.clear()
|
||||
self.clear_metrics()
|
||||
|
||||
def param_key(self, param_loc):
|
||||
r"""
|
||||
A method that returns an encoded key that represents
|
||||
the current batch and param location.
|
||||
Args:
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
return f"{self.batch_number},{param_loc}"
|
||||
|
||||
def clear_batch_state(self):
|
||||
r"""
|
||||
Clears the current server batch state.
|
||||
"""
|
||||
self.futures.clear()
|
||||
self.gradient_dict.clear()
|
||||
|
||||
def process_gradient(self, gradient, param_loc):
|
||||
r"""
|
||||
Stores the gradient if param_loc is not in gradient_dict.
|
||||
Adds the gradient to param_loc if it is in gradient_dict.
|
||||
Args:
|
||||
gradient (torch.Tensor): tensor sent from trainer
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
if param_loc not in self.gradient_dict:
|
||||
self.record_straggler_start(self.param_key(param_loc))
|
||||
self.record_batch_start(self.param_key(param_loc))
|
||||
self.gradient_dict[param_loc] = gradient
|
||||
else:
|
||||
self.gradient_dict[param_loc] += gradient
|
||||
|
||||
@ParameterServerBase.record_method(name="average computation")
|
||||
def average(self, param_loc):
|
||||
r"""
|
||||
Obtains the tensor at the param_loc in the gradient_dict
|
||||
and then divides by number of trainers.
|
||||
Args:
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
param_loc_avg = self.gradient_dict[param_loc]
|
||||
param_loc_avg = param_loc_avg / (1.0 * self.trainer_count)
|
||||
return param_loc_avg
|
||||
|
||||
@staticmethod
|
||||
@rpc.functions.async_execution
|
||||
def average_gradient(
|
||||
server_rref,
|
||||
received_batch_number,
|
||||
param_loc,
|
||||
gradient
|
||||
):
|
||||
r"""
|
||||
An asynchronous function that will average gradients
|
||||
sent from trainers.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
received_batch_number (int): batch number sent by
|
||||
the trainer
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
gradient (torch.Tensor or list): tensor sent by the trainer
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
if type(gradient) is list:
|
||||
gradient = sparse_rpc_format_to_tensor(gradient)
|
||||
gradient = gradient.cuda(self.rank)
|
||||
fut = torch.futures.Future()
|
||||
with self.lock:
|
||||
if self.batch_number < received_batch_number:
|
||||
self.batch_number = received_batch_number
|
||||
self.clear_batch_state()
|
||||
self.process_gradient(gradient, param_loc)
|
||||
if param_loc not in self.futures:
|
||||
self.futures[param_loc] = []
|
||||
self.futures[param_loc].append(fut)
|
||||
if len(self.futures[param_loc]) == self.trainer_count:
|
||||
self.record_straggler_end(self.param_key(param_loc))
|
||||
param_loc_avg = self.average(param_loc)
|
||||
if not self.use_cuda_rpc:
|
||||
param_loc_avg = param_loc_avg.cpu()
|
||||
if param_loc_avg.is_sparse:
|
||||
param_loc_avg = sparse_tensor_to_rpc_format(param_loc_avg)
|
||||
for cur_fut in self.futures[param_loc]:
|
||||
cur_fut.set_result(param_loc_avg)
|
||||
self.record_batch_end(self.param_key(param_loc))
|
||||
return fut
|
||||
|
||||
|
||||
class AverageBatchParameterServer(AverageParameterServer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank,
|
||||
trainer_count,
|
||||
use_cuda_rpc
|
||||
):
|
||||
r"""
|
||||
A parameter server that averages the gradients
|
||||
from trainers for each training iteration step.
|
||||
Gradients are stored and averaged when a gradient
|
||||
has been received from each trainer for a param
|
||||
location.
|
||||
Args:
|
||||
rank (int): worker rank
|
||||
trainer_count (int): count of trainers sending
|
||||
gradients to the server
|
||||
use_cuda_rpc (bool): indicator for CUDA RPC
|
||||
"""
|
||||
super().__init__(rank, trainer_count, use_cuda_rpc)
|
||||
|
||||
def process_gradient(self, gradient, param_loc):
|
||||
r"""
|
||||
Adds the gradient to param_loc bucket stored in
|
||||
the gradient_dict.
|
||||
Args:
|
||||
gradient (torch.Tensor): tensor sent from trainer
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
if param_loc not in self.gradient_dict:
|
||||
self.record_straggler_start(self.param_key(param_loc))
|
||||
self.record_batch_start(self.param_key(param_loc))
|
||||
self.gradient_dict[param_loc] = []
|
||||
self.gradient_dict[param_loc].append(gradient)
|
||||
|
||||
@ParameterServerBase.record_method(name="average computation")
|
||||
def average(self, param_loc):
|
||||
r"""
|
||||
Sums the gradients at the param_loc then divides by the
|
||||
number of trainers.
|
||||
Args:
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
param_loc_avg = self.gradient_dict[param_loc][0]
|
||||
for gradient in self.gradient_dict[param_loc][1:]:
|
||||
param_loc_avg += gradient
|
||||
param_loc_avg = param_loc_avg / (1.0 * self.trainer_count)
|
||||
return param_loc_avg
|
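A hedged sketch (not part of the commit) of the trainer-side call into this server: average_gradient is decorated with @rpc.functions.async_execution, so an rpc_async call returns a future that completes only after gradients from all trainers for the same param_loc have been combined. The real trainer path goes through utils.process_bucket_with_remote_server, which is not shown in this diff.

    import torch.distributed.rpc as rpc
    from server import AverageParameterServer

    def send_bucket_to_server(server_rref, batch_number, param_loc, grad_tensor):
        # returns a Future holding the averaged gradient for this bucket
        return rpc.rpc_async(
            server_rref.owner(),
            AverageParameterServer.average_gradient,
            args=(server_rref, batch_number, param_loc, grad_tensor),
        )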
@ -1,143 +0,0 @@
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
from utils import sparse_rpc_format_to_tensor, sparse_tensor_to_rpc_format
|
||||
|
||||
from .ParameterServerBase import ParameterServerBase
|
||||
|
||||
|
||||
class AverageParameterServer(ParameterServerBase):
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank,
|
||||
trainer_count,
|
||||
use_cuda_rpc
|
||||
):
|
||||
r"""
|
||||
A parameter server that averages the gradients
|
||||
from trainers for each training iteration step.
|
||||
Gradients are added as they are received from trainers.
|
||||
When all gradients have been received, the sum is
|
||||
divided by the number of trainers.
|
||||
Args:
|
||||
rank (int): worker rank
|
||||
trainer_count (int): count of trainers sending
|
||||
gradients to the server
|
||||
use_cuda_rpc (bool): indicator for CUDA RPC
|
||||
"""
|
||||
super().__init__(rank)
|
||||
|
||||
self.rank = rank
|
||||
self.trainer_count = trainer_count
|
||||
self.use_cuda_rpc = use_cuda_rpc
|
||||
|
||||
self.batch_number = 0
|
||||
self.futures = {}
|
||||
self.gradient_dict = {}
|
||||
|
||||
@staticmethod
|
||||
def reset_state(server_rref):
|
||||
r"""
|
||||
A method that clears the state of the server.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
self.batch_number = 0
|
||||
self.futures.clear()
|
||||
self.gradient_dict.clear()
|
||||
self.clear_metrics()
|
||||
|
||||
def param_key(self, param_loc):
|
||||
r"""
|
||||
A method that returns an encoded key that represents
|
||||
the current batch and param location.
|
||||
Args:
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
return f"{self.batch_number},{param_loc}"
|
||||
|
||||
def clear_batch_state(self):
|
||||
r"""
|
||||
Clears the current server batch state.
|
||||
"""
|
||||
self.futures.clear()
|
||||
self.gradient_dict.clear()
|
||||
|
||||
def process_gradient(self, gradient, param_loc):
|
||||
r"""
|
||||
Stores the gradient if param_loc is not in gradient_dict.
|
||||
Adds the gradient to param_loc if it is in gradient_dict.
|
||||
Args:
|
||||
gradient (torch.Tensor): tensor sent from trainer
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
if param_loc not in self.gradient_dict:
|
||||
self.record_straggler_start(self.param_key(param_loc))
|
||||
self.record_batch_start(self.param_key(param_loc))
|
||||
self.gradient_dict[param_loc] = gradient
|
||||
else:
|
||||
self.gradient_dict[param_loc] += gradient
|
||||
|
||||
@ParameterServerBase.record_method(name="average computation")
|
||||
def average(self, param_loc):
|
||||
r"""
|
||||
Obtains the tensor at the param_loc in the gradient_dict
|
||||
and then divides by number of trainers.
|
||||
Args:
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
"""
|
||||
param_loc_avg = self.gradient_dict[param_loc]
|
||||
param_loc_avg / (1.0 * self.trainer_count)
|
||||
return param_loc_avg
|
||||
|
||||
@staticmethod
|
||||
@rpc.functions.async_execution
|
||||
def average_gradient(
|
||||
server_rref,
|
||||
received_batch_number,
|
||||
param_loc,
|
||||
gradient
|
||||
):
|
||||
r"""
|
||||
An asynchronous function that will average gradients
|
||||
sent from trainers.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
received_batch_number (int): batch number sent by
|
||||
the trainer
|
||||
param_loc (int): bucket location sent by the trainer
|
||||
containing the gradient
|
||||
gradient (torch.Tensor or list): tensor sent by the trainer
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
if type(gradient) is list:
|
||||
gradient = sparse_rpc_format_to_tensor(gradient)
|
||||
gradient = gradient.cuda(self.rank)
|
||||
fut = torch.futures.Future()
|
||||
with self.lock:
|
||||
if self.batch_number < received_batch_number:
|
||||
self.batch_number = received_batch_number
|
||||
self.clear_batch_state()
|
||||
self.process_gradient(gradient, param_loc)
|
||||
if param_loc not in self.futures:
|
||||
self.futures[param_loc] = []
|
||||
self.futures[param_loc].append(fut)
|
||||
if len(self.futures[param_loc]) == self.trainer_count:
|
||||
self.record_straggler_end(self.param_key(param_loc))
|
||||
param_loc_avg = self.average(param_loc)
|
||||
if not self.use_cuda_rpc:
|
||||
param_loc_avg = param_loc_avg.cpu()
|
||||
if param_loc_avg.is_sparse:
|
||||
param_loc_avg = sparse_tensor_to_rpc_format(param_loc_avg)
|
||||
for cur_fut in self.futures[param_loc]:
|
||||
cur_fut.set_result(param_loc_avg)
|
||||
self.record_batch_end(self.param_key(param_loc))
|
||||
return fut
|
@ -1,170 +0,0 @@
|
||||
import functools
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from metrics.MetricsLogger import MetricsLogger
|
||||
|
||||
|
||||
class ParameterServerBase(ABC):
|
||||
|
||||
PARAMETER_SERVER_BATCH_METRIC = "parameter_server_batch_metric"
|
||||
PARAMETER_SERVER_STRAGGLER_METRIC = "parameter_server_straggler_metric"
|
||||
PARAM_INDEX_STRAGGLER = "param_index_straggler"
|
||||
PARAM_INDEX_BATCH = "param_index_batch"
|
||||
|
||||
def __init__(self, rank):
|
||||
r"""
|
||||
Inits ParameterServerBase class.
|
||||
Args:
|
||||
rank (int): worker rank
|
||||
"""
|
||||
self.__metrics_logger = MetricsLogger(rank)
|
||||
|
||||
@abstractmethod
|
||||
def process_gradient(self):
|
||||
r"""
|
||||
A method to be implemented by child class that will process a
|
||||
gradient received by a server.
|
||||
"""
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def average_gradient():
|
||||
r"""
|
||||
A method to be implemented by child class that will average
|
||||
gradients.
|
||||
"""
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def reset_state():
|
||||
r"""
|
||||
A method to be implemented by child class that will reset
|
||||
the server state.
|
||||
"""
|
||||
return
|
||||
|
||||
def record_start(self, type, key, name, cuda=True):
|
||||
r"""
|
||||
A method that records the start event for a metric.
|
||||
Args:
|
||||
type (str): group id for metric
|
||||
key (str): unique id for metric within a group
|
||||
name (str): description of the metric
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
type,
|
||||
key,
|
||||
name,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_end(self, type, key):
|
||||
r"""
|
||||
A method that records the end event for a metric
|
||||
Args:
|
||||
type (str): group id for metric
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
type,
|
||||
key
|
||||
)
|
||||
|
||||
def record_straggler_start(self, key, cuda=True):
|
||||
r"""
|
||||
A helper method that records a straggler metric
|
||||
for the given key. A user should call this when
|
||||
the first gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
self.PARAMETER_SERVER_STRAGGLER_METRIC,
|
||||
key,
|
||||
self.PARAM_INDEX_STRAGGLER,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_straggler_end(self, key):
|
||||
r"""
|
||||
A helper method that records a straggler metric
|
||||
for the given key. A user should call this when
|
||||
the last gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
self.PARAMETER_SERVER_STRAGGLER_METRIC,
|
||||
key
|
||||
)
|
||||
|
||||
def record_batch_start(self, key, cuda=True):
|
||||
r"""
|
||||
A helper method that records a batch metric
|
||||
for the given key. A user should call this when
|
||||
the first gradient for the param location is received.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
self.__metrics_logger.record_start(
|
||||
self.PARAMETER_SERVER_BATCH_METRIC,
|
||||
key,
|
||||
self.PARAM_INDEX_BATCH,
|
||||
cuda
|
||||
)
|
||||
|
||||
def record_batch_end(self, key):
|
||||
r"""
|
||||
A helper method that records a batch metric
|
||||
for the given key. A user should call this when
|
||||
all futures for a param location have had their
|
||||
result set.
|
||||
Args:
|
||||
key (str): unique id for metric within a group
|
||||
"""
|
||||
self.__metrics_logger.record_end(
|
||||
self.PARAMETER_SERVER_BATCH_METRIC,
|
||||
key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def record_method(name, type="method_metric", cuda=True):
|
||||
r"""
|
||||
A decorator that records a metric for the decorated method.
|
||||
Args:
|
||||
name (str): description of the metric
|
||||
type (str): group id for metric
|
||||
cuda (bool): indicator to determine if this is a CUDA metric
|
||||
"""
|
||||
def decorator(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(self, *args):
|
||||
key = time.time()
|
||||
self.__metrics_logger.record_start(type, key, name, cuda)
|
||||
result = function(self, *args)
|
||||
self.__metrics_logger.record_end(type, key)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def get_metrics(server_rref):
|
||||
r"""
|
||||
A staticmethod that returns metrics captured by the __metrics_logger.
|
||||
Args:
|
||||
server_rref (RRef): remote reference to the server
|
||||
"""
|
||||
self = server_rref.local_value()
|
||||
return self.__metrics_logger.get_processed_metrics()
|
||||
|
||||
def clear_metrics(self):
|
||||
r"""
|
||||
A method that clears __metrics_logger recorded metrics.
|
||||
"""
|
||||
return self.__metrics_logger.clear_metrics()
|
@ -1,31 +0,0 @@
import subprocess
from os.path import join
from pathlib import Path

script_dir = join(
    Path(__file__).parent, "experiment_scripts"
)
encoding = 'utf-8'


def run_script(script_name):
    # runs the script and asserts that there are no errors
    p = subprocess.run(
        ["bash", f"{join(script_dir,script_name)}"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )
    error = p.stderr.decode(encoding)
    assert not error


def test_ddp_nccl_allreduce():
    run_script("ddp_nccl_allreduce.sh")


def test_ddp_cpu_sparse_rpc_nccl_allreduce():
    run_script("ddp_cpu_sparse_rpc_nccl_allreduce.sh")


def test_ddp_cuda_sparse_rpc_nccl_allreduce():
    run_script("ddp_cuda_sparse_rpc_nccl_allreduce.sh")
@ -0,0 +1,30 @@
from .criterions import cel
from .ddp_models import basic_ddp_model
from .hook_states import BasicHookState
from .iteration_steps import basic_iteration_step
from .preprocess_data import preprocess_dummy_data
from .trainer import DdpTrainer

criterion_map = {
    "cel": cel
}

ddp_model_map = {
    "basic_ddp_model": basic_ddp_model
}

iteration_step_map = {
    "basic_iteration_step": basic_iteration_step
}

preprocess_data_map = {
    "preprocess_dummy_data": preprocess_dummy_data
}

hook_state_map = {
    "BasicHookState": BasicHookState
}

trainer_map = {
    "DdpTrainer": DdpTrainer
}
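Because every component is resolved through these maps, extending the benchmark is a registration step rather than a launcher change. A hedged sketch follows; the "mse" helper is hypothetical and only mirrors the shape of criterions.cel().

    import torch.nn as nn
    from trainer import criterion_map

    def mse(rank):
        # hypothetical criterion factory, same shape as cel(rank)
        return nn.MSELoss().cuda(rank)

    criterion_map["mse"] = mse  # then selectable via --create_criterion="mse"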
@ -0,0 +1,10 @@
import torch.nn as nn


def cel(rank):
    r"""A function that creates a CrossEntropyLoss
    criterion for training.
    Args:
        rank (int): worker rank
    """
    return nn.CrossEntropyLoss().cuda(rank)
@ -0,0 +1,23 @@
from torch.nn.parallel import DistributedDataParallel as DDP


def basic_ddp_model(self, rank, model, process_group, hook_state, hook):
    r"""
    A function that creates ddp_model and hook_state objects.
    The ddp model is initialized with a single device id and
    the process group. The ddp_model also registers the communication
    hook.
    Args:
        rank (int): worker rank
        model (nn.Module): neural network model
        process_group (ProcessGroup): distributed process group
        hook_state (class): class that will be used to keep track of state
            during training.
        hook (function): ddp communication hook
    """
    ddp_model = DDP(
        model, device_ids=[rank], process_group=process_group
    )
    hook_state = hook_state(self, process_group)
    ddp_model.register_comm_hook(hook_state, hook)
    return ddp_model, hook_state
@ -0,0 +1,28 @@
class BasicHookState:

    def __init__(self, cref, process_group):
        r"""
        A class that holds state information that is needed by the communication hook
        during the training algorithm.
        Args:
            cref (DdpTrainer): reference to the self keyword of the trainer instance
            process_group (ProcessGroup): distributed process group
        """
        self.cref = cref
        self.process_group = process_group
        self.batch_number = -1

    def get_key(self, bucket_index):
        r"""
        A method that returns an encoded key that represents the current batch and
        bucket index.
        Args:
            bucket_index (int): index of the bucket being processed in backward
        """
        return f"{self.batch_number},{bucket_index}"

    def next_batch(self):
        r"""
        A method that increments batch_number by 1.
        """
        self.batch_number += 1
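For context, a hedged sketch of the kind of communication hook that consumes this state object, modeled on the allreduce hooks elsewhere in this diff; it is not code from the commit.

    def allreduce_style_hook(state, bucket):
        # state is a BasicHookState; bucket is the GradBucket passed by DDP
        tensor = bucket.get_tensor() / state.process_group.size()
        key = state.get_key(bucket.get_index())  # used for metric recording in the real hooks
        fut = state.process_group.allreduce([tensor]).get_future()

        def callback(fut):
            return fut.wait()

        return fut.then(callback)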
@ -0,0 +1,23 @@
def basic_iteration_step(self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch):
    r"""
    A function that performs an iteration of training.
    Args:
        ddp_model (nn.Module): distributed data parallel model
        criterion (nn.Module): loss function to measure model
        optimizer (optim.Optimizer): updates model parameters
        hook_state (object): ddp communication hook state object
        epoch (int): index of pass through the data
        index (int): iteration number - 1 in current batch
        batch (list): training examples
    """
    hook_state.next_batch()
    self.record_batch_start(self.epoch_key(epoch, index))
    optimizer.zero_grad()
    self.record_forward_start(self.epoch_key(epoch, index))
    loss = criterion(ddp_model(batch[0]), batch[1])
    self.record_forward_end(self.epoch_key(epoch, index))
    self.record_backward_start(self.epoch_key(epoch, index))
    loss.backward()
    self.record_backward_end(self.epoch_key(epoch, index))
    optimizer.step()
    self.record_batch_end(self.epoch_key(epoch, index))
@ -0,0 +1,12 @@
def preprocess_dummy_data(rank, data):
    r"""
    A function that moves the data from CPU to GPU
    for the DummyData class.
    Args:
        rank (int): worker rank
        data (list): training examples
    """
    for i in range(len(data)):
        data[i][0] = data[i][0].cuda(rank)
        data[i][1] = data[i][1].cuda(rank)
    return data
@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from metrics.MetricsLogger import MetricsLogger
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TrainerBase(ABC):
|
||||
|
||||
@ -175,3 +177,88 @@ class TrainerBase(ABC):
|
||||
A method that clears __metrics_logger recorded metrics.
|
||||
"""
|
||||
return self.__metrics_logger.clear_metrics()
|
||||
|
||||
|
||||
class DdpTrainer(TrainerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_group,
|
||||
use_cuda_rpc,
|
||||
server_rref,
|
||||
backend,
|
||||
epochs,
|
||||
preprocess_data,
|
||||
create_criterion,
|
||||
create_ddp_model,
|
||||
hook_state_class,
|
||||
hook,
|
||||
iteration_step
|
||||
):
|
||||
r"""
|
||||
A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
|
||||
using the process_group implementation.
|
||||
Args:
|
||||
process_group (ProcessGroup): distributed process group
|
||||
use_cuda_rpc (bool): indicator for CUDA RPC
|
||||
server_rref (RRef): remote reference to the server
|
||||
backend (str): distributed communication backend
|
||||
epochs (int): epoch count for training
|
||||
preprocess_data (function): preprocesses data passed
|
||||
to the trainer before starting training
|
||||
create_criterion (function): creates a criterion to calculate loss
|
||||
create_ddp_model (function): creates a ddp model for the trainer
|
||||
hook_state_class (class): class that will be used to keep track of state
|
||||
during training.
|
||||
hook (function): ddp communication hook
|
||||
iteration_step (function): will perform 1 step of training
|
||||
"""
|
||||
super().__init__(process_group.rank())
|
||||
self.process_group = process_group
|
||||
self.use_cuda_rpc = use_cuda_rpc
|
||||
self.server_rref = server_rref
|
||||
self.backend = backend
|
||||
self.epochs = epochs
|
||||
self.preprocess_data = preprocess_data
|
||||
self.create_criterion = create_criterion
|
||||
self.create_ddp_model = create_ddp_model
|
||||
self.hook_state_class = hook_state_class
|
||||
self.hook = hook
|
||||
self.iteration_step = iteration_step
|
||||
|
||||
self.rank = process_group.rank()
|
||||
self.trainer_count = process_group.size()
|
||||
|
||||
def epoch_key(self, epoch, index):
|
||||
r"""
|
||||
A method that returns an encoded key that represents the current epoch and
|
||||
iteration index.
|
||||
Args:
|
||||
epoch (int): epoch index
|
||||
index (int): iteration index
|
||||
"""
|
||||
return f"{epoch},{index}"
|
||||
|
||||
def train(self, model, data):
|
||||
r"""
|
||||
A method that implements the training algorithm.
|
||||
Args:
|
||||
model (nn.Module): neural network model
|
||||
data (list): training examples
|
||||
"""
|
||||
model = model.cuda(self.rank)
|
||||
data = self.preprocess_data(self.rank, data)
|
||||
criterion = self.create_criterion(self.rank)
|
||||
ddp_model, hook_state = self.create_ddp_model(
|
||||
self, self.rank, model, self.process_group, self.hook_state_class, self.hook
|
||||
)
|
||||
optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
|
||||
|
||||
for epoch in range(self.epochs):
|
||||
if epoch % 5 == 0 and self.rank == 0:
|
||||
print(f"train epoch={epoch}")
|
||||
for index, batch in enumerate(data):
|
||||
self.iteration_step(
|
||||
self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
|
||||
)
|
||||
torch.cuda.synchronize(self.rank)
|
@ -1,107 +0,0 @@
import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .DdpTrainerBase import DdpTrainerBase


class DdpNcclTrainer(DdpTrainerBase):

    class HookState:

        def __init__(self, cref, process_group):
            self.cref = cref
            self.process_group = process_group
            self.process_group_size = process_group.size()
            self.param_location = 0
            self.batch_number = -1

        def get_key(self):
            return f"{self.batch_number},{self.param_location}"

        def next_batch_state(self):
            self.param_location = 0
            self.batch_number += 1

    def __init__(self, rank, trainer_count, ps_rref, epochs):
        super().__init__(rank)
        self.rank = rank
        self.trainer_count = trainer_count
        self.epochs = epochs

    @staticmethod
    def hook(state, bucket):
        cref = state.cref
        tensors_count = len(cref.bucket_to_parameters(bucket))
        tensors = [bucket.get_tensor() / state.process_group_size]
        key = state.get_key()
        cref.record_hook_fut_start(key, cref.NCCL_ALLREDUCE)
        fut = state.process_group.allreduce(tensors).get_future()
        state.param_location += tensors_count

        def callback(fut):
            cref.record_hook_fut_end(key)
            return fut.wait()

        return fut.then(callback)

    def train(self, model, data):
        torch.manual_seed(0)
        model = model.cuda(self.rank)
        for i in range(len(data)):
            data[i][0] = data[i][0].cuda(self.rank)
            data[i][1] = data[i][1].cuda(self.rank)
        torch.cuda.synchronize(self.rank)

        process_group_size = self.trainer_count

        store = c10d.FileStore("/tmp/tmpn_k_8so02", process_group_size)

        process_group = c10d.ProcessGroupNCCL(
            store, self.rank, process_group_size
        )

        ddp_model = DDP(
            model, device_ids=[self.rank], process_group=process_group
        )

        hook_state = self.HookState(self, process_group)

        ddp_model.register_comm_hook(hook_state, DdpNcclTrainer.hook)

        criterion = nn.CrossEntropyLoss().cuda(self.rank)

        optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)

        def epoch_key(epoch, index):
            return f"{epoch},{index}"

        for epoch in range(self.epochs):
            for index, batch in enumerate(data):
                hook_state.next_batch_state()
                input, target = batch[0], batch[1]

                self.record_batch_start(epoch_key(epoch, index))

                optimizer.zero_grad()

                self.record_forward_start(epoch_key(epoch, index))

                out = ddp_model(input)

                self.record_forward_end(epoch_key(epoch, index))

                loss = criterion(out, target)

                self.record_backward_start(epoch_key(epoch, index))

                loss.backward()

                self.record_backward_end(epoch_key(epoch, index))

                optimizer.step()

                self.record_batch_end(epoch_key(epoch, index))

        torch.cuda.synchronize(self.rank)
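# --- Illustrative sketch (not part of the original commit): the minimal shape of a
# DDP communication hook like DdpNcclTrainer.hook above, with the benchmark metric
# recording stripped out. It assumes the GradBucket API used in this file
# (bucket.get_tensor(); newer PyTorch releases expose bucket.buffer() instead).
def allreduce_average_hook(process_group, bucket):
    # Average the bucket's flattened gradients across all ranks asynchronously.
    tensors = [bucket.get_tensor() / process_group.size()]
    fut = process_group.allreduce(tensors).get_future()

    def callback(fut):
        # DDP expects the returned future to resolve to the reduced bucket tensor(s).
        return fut.wait()

    return fut.then(callback)

# Registration mirrors the call above, with the process group itself as the state:
#   ddp_model.register_comm_hook(process_group, allreduce_average_hook)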
@ -1,54 +0,0 @@
from utils import process_bucket_with_remote_server

from .DdpTrainer import DdpTrainer


class DdpSparseRpcTrainer(DdpTrainer):

    def __init__(self, rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs):
        r"""
        A trainer that implements a DDP training algorithm using a server and process group
        allreduce. The trainer sends sparse gradients using RPC, and the server averages and
        returns the gradients. The process group uses the backend allreduce implementation
        to average the dense gradients.
        Args:
            rank (int): worker rank
            trainer_count (int): number of trainers in the world
            process_group (ProcessGroup): distributed process group
            use_cuda_rpc (bool): indicator for CUDA RPC
            server_rref (RRef): remote reference to the server
            backend (str): distributed communication backend
            epochs (int): epoch count for training
        """
        super().__init__(rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs)

    @staticmethod
    def hook(state, bucket):
        r"""
        A DDP communication hook that uses the current backend allreduce
        implementation for dense tensors and a server for sparse tensors.
        Args:
            state (object): maintains state during the training process
            bucket (GradBucket): gradient bucket
        """
        tensor = bucket.get_tensor()
        if tensor.is_sparse:
            return process_bucket_with_remote_server(state, bucket)
        else:
            cref = state.cref
            tensor = [tensor / state.process_group.size()]
            key = state.get_key(bucket.get_index())
            cref.record_hook_fut_start(key, f"{cref.backend}_dense_allreduce")
            fut = state.process_group.allreduce(tensor).get_future()

            def callback(fut):
                cref.record_hook_fut_end(key)
                return fut.wait()

            return fut.then(callback)

    def get_hook(self):
        r"""
        Returns DdpSparseRpcTrainer.hook.
        """
        return DdpSparseRpcTrainer.hook
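# --- Illustrative sketch (not part of the original commit): why a gradient bucket
# can be sparse in the hook above. An nn.Embedding created with sparse=True produces
# sparse gradients, which is what routes a bucket to the remote server path.
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, sparse=True)
out = embedding(torch.tensor([1, 2, 2]))
out.sum().backward()
print(embedding.weight.grad.is_sparse)  # True: this gradient takes the RPC path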
@ -1,183 +0,0 @@
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from .DdpTrainerBase import DdpTrainerBase


class DdpTrainer(DdpTrainerBase):

    class HookState:

        def __init__(self, cref, process_group):
            r"""
            A class that holds state information that is needed by the communication hook
            during the training algorithm.
            Args:
                cref (DdpTrainer): reference to the self keyword of the trainer instance
                process_group (ProcessGroup): distributed process group
            """
            self.cref = cref
            self.process_group = process_group
            self.batch_number = -1

        def get_key(self, bucket_index):
            r"""
            A method that returns an encoded key that represents the current batch and
            bucket index.
            Args:
                bucket_index (int): index of the bucket being processed in backward
            """
            return f"{self.batch_number},{bucket_index}"

        def next_batch(self):
            r"""
            A method that increments batch_number by 1.
            """
            self.batch_number += 1

    def __init__(self, rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs):
        r"""
        A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
        using the process_group implementation.
        Args:
            rank (int): worker rank
            trainer_count (int): number of trainers in the world
            process_group (ProcessGroup): distributed process group
            use_cuda_rpc (bool): indicator for CUDA RPC
            server_rref (RRef): remote reference to the server
            backend (str): distributed communication backend
            epochs (int): epoch count for training
        """
        super().__init__(rank)
        self.rank = rank
        self.trainer_count = trainer_count
        self.process_group = process_group
        self.use_cuda_rpc = use_cuda_rpc
        self.server_rref = server_rref
        self.backend = backend
        self.epochs = epochs

    @staticmethod
    def hook(state, bucket):
        r"""
        A DDP communication hook that uses the process_group allreduce implementation.
        Args:
            state (object): maintains state during the training process
            bucket (GradBucket): gradient bucket
        """
        cref = state.cref
        tensors = [bucket.get_tensor() / state.process_group.size()]
        key = state.get_key(bucket.get_index())
        cref.record_hook_fut_start(key, f"{cref.backend}_allreduce")
        fut = state.process_group.allreduce(tensors).get_future()

        def callback(fut):
            cref.record_hook_fut_end(key)
            return fut.wait()

        return fut.then(callback)

    def get_hook(self):
        r"""
        Returns DdpTrainer.hook.
        """
        return DdpTrainer.hook

    def create_ddp_model(self, model):
        r"""
        A method that creates the ddp_model and hook_state objects
        and returns them.
        Args:
            model (nn.Module): neural network model
        """
        ddp_model = DDP(
            model, device_ids=[self.rank], process_group=self.process_group
        )
        hook_state = self.HookState(self, self.process_group)
        ddp_model.register_comm_hook(hook_state, self.get_hook())
        return ddp_model, hook_state

    def create_criterion(self):
        r"""
        A method that creates a criterion for the training
        algorithm.
        """
        return nn.CrossEntropyLoss().cuda(self.rank)

    def create_optimizer(self, parameters, lr):
        r"""
        A method that creates an optimizer for the training
        algorithm.
        Args:
            parameters (iterable): iterable of parameters to optimize
            lr (float): learning rate
        """
        return torch.optim.SGD(parameters, lr)

    def epoch_key(self, epoch, index):
        r"""
        A method that returns an encoded key that represents the current epoch and
        iteration index.
        Args:
            epoch (int): epoch index
            index (int): iteration index
        """
        return f"{epoch},{index}"

    def preprocess_data(self, data):
        r"""
        A method that moves the data from CPU to GPU.
        Args:
            data (list): training examples
        """
        for i in range(len(data)):
            data[i][0] = data[i][0].cuda(self.rank)
            data[i][1] = data[i][1].cuda(self.rank)
        return data

    def iteration_step(self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch):
        r"""
        A method that performs an iteration of training.
        Args:
            ddp_model (nn.Module): distributed data parallel model
            criterion (nn.Module): loss function used to measure the model
            optimizer (optim.Optimizer): updates model parameters
            hook_state (object): ddp communication hook state object
            epoch (int): index of the pass through the data
            index (int): iteration index within the current epoch
            batch (list): training examples
        """
        hook_state.next_batch()
        input, target = batch[0], batch[1]
        self.record_batch_start(self.epoch_key(epoch, index))
        optimizer.zero_grad()
        self.record_forward_start(self.epoch_key(epoch, index))
        out = ddp_model(input)
        self.record_forward_end(self.epoch_key(epoch, index))
        loss = criterion(out, target)
        self.record_backward_start(self.epoch_key(epoch, index))
        loss.backward()
        self.record_backward_end(self.epoch_key(epoch, index))
        optimizer.step()
        self.record_batch_end(self.epoch_key(epoch, index))

    def train(self, model, data):
        r"""
        A method that implements the training algorithm.
        Args:
            model (nn.Module): neural network model
            data (list): training examples
        """
        model = model.cuda(self.rank)
        data = self.preprocess_data(data)
        ddp_model, hook_state = self.create_ddp_model(model)
        criterion = self.create_criterion()
        optimizer = self.create_optimizer(ddp_model.parameters(), 1e-4)

        for epoch in range(self.epochs):
            for index, batch in enumerate(data):
                self.iteration_step(
                    ddp_model, criterion, optimizer, hook_state, epoch, index, batch
                )
        torch.cuda.synchronize(self.rank)
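# --- Illustrative sketch (not part of the original commit): how one rank of the
# benchmark might drive DdpTrainer. Only the constructor and train() signatures come
# from the class above; the store path, Gloo process group, and server_rref plumbing
# shown here are assumptions.
import torch.distributed as c10d

def run_trainer(rank, world_size, model, data, server_rref, epochs):
    store = c10d.FileStore("/tmp/example_store", world_size)  # hypothetical path
    process_group = c10d.ProcessGroupGloo(store, rank, world_size)
    trainer = DdpTrainer(
        rank, world_size, process_group,
        use_cuda_rpc=False, server_rref=server_rref,
        backend="gloo", epochs=epochs,
    )
    trainer.train(model, data)  # data: list of [input, target] CPU tensor pairs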
@ -1,49 +0,0 @@
from abc import abstractmethod

from .TrainerBase import TrainerBase


class DdpTrainerBase(TrainerBase):

    HOOK_FUTURE_METRIC = "hook_future_metric"

    def __init__(self, rank):
        r"""
        Inits DdpTrainerBase class.
        Args:
            rank (int): worker rank
        """
        super().__init__(rank)

    @staticmethod
    @abstractmethod
    def hook(state, bucket):
        r"""
        A method to be implemented by a child class that will implement a DDP
        training algorithm.
        Args:
            state (object): maintains state during the training process
            bucket (GradBucket): gradient bucket
        """
        return

    def record_hook_fut_start(self, key, name, cuda=True):
        r"""
        A helper method that records the start of a hook future metric
        for the given key. A user should call this before
        sending an async request in the DDP communication hook.
        Args:
            key (str): unique id for metric within a group
            name (str): metric name
            cuda (bool): indicator to determine if this is a CUDA metric
        """
        self.record_start(self.HOOK_FUTURE_METRIC, key, name, cuda)

    def record_hook_fut_end(self, key):
        r"""
        A helper method that records the end of a hook future metric
        for the given key. A user should call this in a callback
        attached to the future returned by an async request.
        Args:
            key (str): unique id for metric within a group
        """
        self.record_end(self.HOOK_FUTURE_METRIC, key)
@ -6,7 +6,7 @@ RPC_DENSE = "rpc_dense"

def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper method creates a list containing the indices, values, and size
    A helper function creates a list containing the indices, values, and size
    of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): sparse_coo_tensor represented as a list
@ -17,7 +17,7 @@ def sparse_tensor_to_rpc_format(sparse_tensor):

def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper method creates a sparse_coo_tensor from indices, values, and size.
    A helper function creates a sparse_coo_tensor from indices, values, and size.
    Args:
        sparse_rpc_format (list): sparse_coo_tensor represented as a list
    """
@ -50,14 +50,15 @@ def process_bucket_with_remote_server(state, bucket):
            tensor
        ]
    key = state.get_key(b_index)
    cref.record_hook_fut_start(
    cref.record_start(
        "hook_future_metric",
        key,
        RPC_SPARSE if sparse else RPC_DENSE
    )
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        cref.record_hook_fut_end(key)
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        if type(tensor) is list:
            tensor = sparse_rpc_format_to_tensor(tensor)
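# --- Illustrative sketch (not part of the original commit): the round trip the two
# helpers above describe, converting a coalesced sparse tensor into a plain
# [indices, values, size] list that is easy to send over RPC, and back again.
import torch

sparse = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 2]]), values=torch.tensor([1.0, 3.0]), size=(4,)
).coalesce()
rpc_format = [sparse.indices(), sparse.values(), sparse.size()]
restored = torch.sparse_coo_tensor(rpc_format[0], rpc_format[1], rpc_format[2])
assert torch.equal(restored.to_dense(), sparse.to_dense())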