refactor ps benchmark (#60784)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60784

This PR refactors the parameter server (ps) benchmark to support modular trainers.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D29697291

Pulled By: gcramer23

fbshipit-source-id: 64579a1f5326d3cd9f32936dcf53bc243d54b71d
Authored by: Garrett Cramer on 2021-07-14 13:14:08 -07:00
Committed by: Facebook GitHub Bot
Parent: 7d2ea9a8f7
Commit: 304c02ee44
28 changed files with 838 additions and 923 deletions


@ -1,40 +0,0 @@
from data.DummyData import DummyData
from models.DummyModel import DummyModel
from servers.AverageParameterServer import AverageParameterServer
from trainers.DdpNcclTrainer import DdpNcclTrainer
from trainers.DdpSparseRpcTrainer import DdpSparseRpcTrainer
from trainers.DdpTrainer import DdpTrainer
trainer_map = {
"DdpNcclTrainer": DdpNcclTrainer,
"DdpTrainer": DdpTrainer,
"DdpSparseRpcTrainer": DdpSparseRpcTrainer
}
server_map = {
"AverageParameterServer": AverageParameterServer
}
model_map = {
"DummyModel": DummyModel
}
data_map = {
"DummyData": DummyData
}
def get_benchmark_trainer_map():
return trainer_map
def get_benchmark_server_map():
return server_map
def get_benchmark_model_map():
return model_map
def get_benchmark_data_map():
return data_map


@ -2,19 +2,10 @@
"DummyData": {
"data_class": "DummyData",
"configurations": {
"max_val": 100,
"input_samples": 100,
"input_dim": 100,
"max_val": 1024,
"sample_count": 1024,
"sample_length": 1024,
"sparsity_percentage": 20
}
},
"DummyData2": {
"data_class": "DummyData",
"configurations": {
"max_val": 100,
"input_samples": 100,
"input_dim": 100,
"sparsity_percentage": 80
}
}
}


@ -2,20 +2,22 @@
"DummyModel": {
"model_class": "DummyModel",
"configurations": {
"num_embeddings": 100,
"embedding_dim": 100,
"dense_input_size": 100,
"dense_output_size": 100,
"num_embeddings": 1024,
"embedding_dim": 1024,
"dense_input_size": 1024,
"dense_output_size": 1024,
"dense_layers_count": 8,
"sparse": false
}
},
"DummyModelSparse": {
"model_class": "DummyModel",
"configurations": {
"num_embeddings": 100,
"embedding_dim": 100,
"dense_input_size": 100,
"dense_output_size": 100,
"num_embeddings": 1024,
"embedding_dim": 1024,
"dense_input_size": 1024,
"dense_output_size": 1024,
"dense_layers_count": 8,
"sparse": true
}
}


@ -1,6 +1,7 @@
import random
import numpy as np
import torch
from torch.utils.data import Dataset
@ -10,13 +11,22 @@ class DummyData(Dataset):
def __init__(
self,
max_val: int,
input_samples: int,
input_dim: int,
sample_count: int,
sample_length: int,
sparsity_percentage: int
):
r"""
A data class that generates random data.
Args:
max_val (int): the maximum value for an element
sample_count (int): count of training samples
sample_length (int): number of elements in a sample
sparsity_percentage (int): the percentage of
embeddings used by the input data in each iteration
"""
self.max_val = max_val
self.input_samples = input_samples
self.input_dim = input_dim
self.input_samples = sample_count
self.input_dim = sample_length
self.sparsity_percentage = sparsity_percentage
def generate_input():
@ -35,9 +45,7 @@ class DummyData(Dataset):
return torch.from_numpy(np.array(data))
self.input = generate_input()
self.target = torch.randint(0, max_val, [input_samples])
self.start = 0
self.end = max_val
self.target = torch.randint(0, max_val, [sample_count])
def __len__(self):
return len(self.input)


@ -0,0 +1,5 @@
from .DummyData import DummyData
data_map = {
"DummyData": DummyData
}


@ -1,23 +0,0 @@
#!/bin/sh
cd "$(dirname "$0")"
cd ..
python -u launcher.py \
--master_addr="localhost" \
--master_port="29500" \
--trainer="DdpSparseRpcTrainer" \
--ntrainer=2 \
--ncudatrainer=0 \
--filestore="/tmp/tmpn_k_8so02" \
--server="AverageParameterServer" \
--nserver=1 \
--ncudaserver=0 \
--rpc_timeout=30 \
--backend="nccl" \
--epochs=10 \
--batch_size=10 \
--data="DummyData" \
--model="DummyModelSparse" \
--data_config_path="configurations/data_configurations.json" \
--model_config_path="configurations/model_configurations.json"


@ -1,23 +0,0 @@
#!/bin/sh
cd "$(dirname "$0")"
cd ..
python -u launcher.py \
--master_addr="localhost" \
--master_port="29500" \
--trainer="DdpSparseRpcTrainer" \
--ntrainer=0 \
--ncudatrainer=2 \
--filestore="/tmp/tmpn_k_8so02" \
--server="AverageParameterServer" \
--nserver=0 \
--ncudaserver=1 \
--rpc_timeout=30 \
--backend="nccl" \
--epochs=10 \
--batch_size=10 \
--data="DummyData" \
--model="DummyModelSparse" \
--data_config_path="configurations/data_configurations.json" \
--model_config_path="configurations/model_configurations.json"


@ -1,22 +0,0 @@
#!/bin/sh
cd "$(dirname "$0")"
cd ..
python -u launcher.py \
--master_addr="localhost" \
--master_port="29500" \
--trainer="DdpTrainer" \
--ntrainer=2 \
--ncudatrainer=0 \
--filestore="/tmp/tmpn_k_8so02" \
--nserver=0 \
--ncudaserver=0 \
--rpc_timeout=30 \
--backend="nccl" \
--epochs=10 \
--batch_size=10 \
--data="DummyData" \
--model="DummyModel" \
--data_config_path="configurations/data_configurations.json" \
--model_config_path="configurations/model_configurations.json"


@ -1,9 +1,22 @@
import argparse
import copy
import json
import os
from pathlib import Path
from data import data_map
from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter
from models import model_map
from server import server_map
from trainer import (
criterion_map,
ddp_hook_map,
ddp_model_map,
hook_state_map,
iteration_step_map,
preprocess_data_map,
trainer_map,
)
import torch
import torch.distributed as c10d
import torch.distributed.rpc as rpc
@ -12,14 +25,15 @@ from torch.distributed.rpc import TensorPipeRpcBackendOptions
from torch.futures import wait_all
from torch.utils.data import DataLoader
from benchmark_class_helper import (get_benchmark_data_map,
get_benchmark_model_map,
get_benchmark_server_map,
get_benchmark_trainer_map)
from metrics.ProcessedMetricsPrinter import ProcessedMetricsPrinter
def get_name(rank, args):
r"""
A function that gets the name for the rank
argument
Args:
rank (int): process number in the world
args (parser): benchmark configurations
"""
t_count = args.ntrainer + args.ncudatrainer
s_count = args.nserver + args.ncudaserver
if rank < t_count:
@ -31,12 +45,26 @@ def get_name(rank, args):
def get_server_rank(args, rank):
r"""
A function that gets the server rank for
the rank argument.
Args:
args (parser): benchmark configurations
rank (int): trainer rank
"""
s_offset = args.ntrainer + args.ncudatrainer
tps = args.ntrainer // args.nserver
return rank // tps + s_offset
def get_cuda_server_rank(args, rank):
r"""
A function that gets the cudaserver rank for
the rank argument.
Args:
args (parser): benchmark configurations
rank (int): trainer rank
"""
s_offset = args.ntrainer + args.ncudatrainer + args.nserver
t_index = rank - args.ntrainer
ctps = args.ncudatrainer // args.ncudaserver
@ -44,7 +72,14 @@ def get_cuda_server_rank(args, rank):
def get_server_rref(server_rank, args, extra_args):
server = get_benchmark_server_map()[str(args.server)]
r"""
A function that creates an RRef to the server.
Args:
server_rank (int): process number in the world
args (parser): benchmark configurations
extra_args (dict): configurations added by the user
"""
server = server_map[args.server]
name = get_name(
server_rank,
args
@ -72,9 +107,19 @@ def get_server_rref(server_rank, args, extra_args):
def run_trainer(
args, extra_args, model, data, rank, server_rref
args, extra_args, data, rank, server_rref
):
trainer_class = get_benchmark_trainer_map()[str(args.trainer)]
r"""
A function that obtains a trainer instance and calls
its train method.
Args:
args (parser): benchmark configurations
extra_args (dict): configurations added by the user
data (list): training samples
rank (int): process number in the world
server_rref (RRef): remote reference to the parameter server assigned to this trainer
"""
trainer_class = trainer_map[args.trainer]
if extra_args is not None:
trainer_args = extra_args.values()
else:
@ -89,15 +134,34 @@ def run_trainer(
process_group = c10d.ProcessGroupNCCL(
store, rank, trainer_count
)
elif args.backend == "multi":
process_group = c10d.ProcessGroupNCCL(
store, rank, trainer_count
)
if c10d.is_initialized() is False:
c10d.init_process_group(backend="gloo", rank=rank, world_size=trainer_count)
model = load_model(args)
preprocess_data = preprocess_data_map[args.preprocess_data]
create_criterion = criterion_map[args.create_criterion]
create_ddp_model = ddp_model_map[args.create_ddp_model]
iteration_step = iteration_step_map[args.iteration_step]
hook_state_class = hook_state_map[args.hook_state]
hook = ddp_hook_map[args.ddp_hook]
# check if this is a cudatrainer
use_cuda_rpc = rank >= args.ntrainer
trainer = trainer_class(
rank,
args.ntrainer + args.ncudatrainer,
process_group,
use_cuda_rpc,
server_rref,
args.backend,
args.epochs,
preprocess_data,
create_criterion,
create_ddp_model,
hook_state_class,
hook,
iteration_step,
*trainer_args
)
trainer.train(model, data)
@ -105,7 +169,16 @@ def run_trainer(
return [rank, metrics]
def call_trainers(args, extra_args, model, train_data, server_rrefs):
def call_trainers(args, extra_args, train_data, server_rrefs):
r"""
A function that starts the trainers. Each trainer is started
using an rpc_async request.
Args:
args (parser): benchmark configurations
extra_args (dict): configurations added by the user
train_data (list): training samples
server_rrefs (dict): a dictionary containing server RRefs
"""
futs = []
for trainer_rank in range(0, args.ntrainer + args.ncudatrainer):
trainer_name = get_name(
@ -125,7 +198,6 @@ def call_trainers(args, extra_args, model, train_data, server_rrefs):
args=(
args,
extra_args,
copy.deepcopy(model),
train_data[trainer_rank],
trainer_rank,
server_rref,
@ -137,9 +209,18 @@ def call_trainers(args, extra_args, model, train_data, server_rrefs):
def benchmark_warmup(
args, extra_args, model, data, server_rrefs
args, extra_args, data, server_rrefs
):
futs = call_trainers(args, extra_args, model, data, server_rrefs)
r"""
A function that runs the training algorithm once to warm up RPC
before the measured benchmark run. The server states are reset afterwards.
Args:
args (parser): benchmark configurations
extra_args (dict): configurations added by the user
data (list): training samples
server_rrefs (dict): a dictionary containing server RRefs
"""
futs = call_trainers(args, extra_args, data, server_rrefs)
wait_all(futs)
for server_rref in server_rrefs.values():
server_rref.rpc_sync().reset_state(server_rref)
@ -147,10 +228,22 @@ def benchmark_warmup(
def split_list(arr, n):
r"""
A function that splits a list into n lists
Args:
arr (list): training samples
n (int): number of output lists
"""
return [arr[i::n] for i in range(n)]
def get_server_metrics(server_rrefs):
r"""
A function that calls the remote server to obtain metrics
collected during the benchmark run.
Args:
server_rrefs (dict): a dictionary containing server RRefs
"""
rank_metrics = []
for rank, server_rref in server_rrefs.items():
metrics = server_rref.rpc_sync().get_metrics(server_rref)
@ -158,7 +251,18 @@ def get_server_metrics(server_rrefs):
return rank_metrics
def run_master(rank, model, data, args, extra_configs, rpc_backend_options):
def run_master(rank, data, args, extra_configs, rpc_backend_options):
r"""
A function that runs the master process in the world. This function
obtains remote references to initialized servers, splits the data,
runs the trainers, and prints metrics.
Args:
rank (int): process number in the world
data (list): training samples
args (parser): benchmark configurations
extra_configs (dict): configurations added by the user
rpc_backend_options (TensorPipeRpcBackendOptions): options for the RPC backend
"""
world_size = args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
rpc.init_rpc(
get_name(
@ -181,21 +285,30 @@ def run_master(rank, model, data, args, extra_configs, rpc_backend_options):
# warmup run the benchmark
benchmark_warmup(
args, extra_configs["trainer_config"], model, train_data, server_rrefs
args, extra_configs["trainer_config"], train_data, server_rrefs
)
# run the benchmark
trainer_futs = call_trainers(
args, extra_configs["trainer_config"], model, train_data, server_rrefs
args, extra_configs["trainer_config"], train_data, server_rrefs
)
# collect metrics and print
metrics_printer = ProcessedMetricsPrinter()
rank_metrics_list = wait_all(trainer_futs)
metrics_printer.print_metrics("trainer", rank_metrics_list)
rank_metrics_list = get_server_metrics(server_rrefs)
metrics_printer.print_metrics("parameter server", rank_metrics_list)
metrics_printer.print_metrics("server", rank_metrics_list)
def run_benchmark(rank, model, data, args, config):
def run_benchmark(rank, args, data):
r"""
A function that runs the benchmark.
Args:
rank (int): process number in the world
args (parser): configuration args
data (list): training samples
"""
config = load_extra_configs(args)
torch.manual_seed(args.torch_seed)
torch.cuda.manual_seed_all(args.cuda_seed)
@ -208,7 +321,7 @@ def run_benchmark(rank, model, data, args, config):
rpc_backend_options = TensorPipeRpcBackendOptions(rpc_timeout=args.rpc_timeout)
if rank == world_size - 1:
# master = [ntrainer + ncudatrainer + nserver + ncudaserver, ntrainer + ncudatrainer + nserver + ncudaserver]
run_master(rank, model, data, args, config, rpc_backend_options)
run_master(rank, data, args, config, rpc_backend_options)
elif rank >= args.ntrainer + args.ncudatrainer:
# parameter_servers = [ntrainer + ncudatrainer, ntrainer + ncudatrainer + nserver + ncudaserver)
rpc.init_rpc(
@ -243,13 +356,25 @@ def run_benchmark(rank, model, data, args, config):
def get_json_config(file_name, id):
r"""
A function that loads a json configuration from a file.
Args:
file_name (str): name of configuration file to load
id (str): configuration that will be loaded
"""
with open(os.path.join(Path(__file__).parent, file_name), "r") as f:
json_config = json.load(f)[id]
return json_config
def load_extra_configs(args):
r"""
A function that creates a dictionary that contains any extra configurations
set by the user. The dictionary will contain two keys trainer_config and
server_config, with default values None.
Args:
args (parser): launcher configurations
"""
trainer_config_file = args.trainer_config_path
server_config_file = args.server_config_path
configurations = {
@ -263,30 +388,36 @@ def load_extra_configs(args):
return configurations
def get_data(data_class, data_config):
data_class = get_benchmark_data_map()[data_class]
return data_class(**data_config)
def load_data(args):
r"""
A function that creates an instance of the data class.
Args:
args (parser): launcher configurations
"""
data_config_file = args.data_config_path
data_config = get_json_config(data_config_file, args.data)
return get_data(data_config["data_class"], data_config["configurations"])
def get_model(model_class, model_config):
model_class = get_benchmark_model_map()[model_class]
return model_class(**model_config)
data_class = data_map[data_config["data_class"]]
return data_class(**data_config["configurations"])
def load_model(args):
r"""
A function that creates an instance of the model class.
Args:
args (parser): launcher configurations
"""
model_config_file = args.model_config_path
model_config = get_json_config(model_config_file, args.model)
return get_model(model_config["model_class"], model_config["configurations"])
model_class = model_map[model_config["model_class"]]
return model_class(**model_config["configurations"])
def main(args):
r"""
A function that creates multiple processes to run the benchmark.
Args:
args (parser): launcher configurations
"""
# CPU and RPC trainer checks
if args.ntrainer > 0 and args.ncudatrainer > 0:
assert args.nserver > 0 and args.ncudaserver > 0
@ -297,21 +428,17 @@ def main(args):
assert args.ncudatrainer > 0
assert args.ncudatrainer % args.ncudaserver == 0
extra_configs = load_extra_configs(args)
data = load_data(args)
model = load_model(args)
world_size = (
args.ntrainer + args.ncudatrainer + args.nserver + args.ncudaserver + 1
)
data = load_data(args)
mp.spawn(
run_benchmark,
args=(
model,
data,
args,
extra_configs,
data,
),
nprocs=world_size,
join=True
@ -383,7 +510,6 @@ if __name__ == "__main__":
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="number of training examples used in one iteration"
)
parser.add_argument(
@ -419,15 +545,44 @@ if __name__ == "__main__":
parser.add_argument(
"--torch_seed",
type=int,
default=0,
help="seed for generating random numbers to a non-deterministic random number"
)
parser.add_argument(
"--cuda_seed",
type=int,
default=0,
help="seed for generating random numbers to a random number for the current GPU"
)
parser.add_argument(
"--preprocess_data",
type=str,
help="this function will be used to preprocess data before training"
)
parser.add_argument(
"--create_criterion",
type=str,
help="this function will be used to create the criterion used for model loss calculation"
)
parser.add_argument(
"--create_ddp_model",
type=str,
help="this function will be used to create the ddp model used during training"
)
parser.add_argument(
"--hook_state",
type=str,
help="this will be the state class used when registering the ddp communication hook"
)
parser.add_argument(
"--ddp_hook",
type=str,
default="allreduce_hook",
help="ddp communication hook"
)
parser.add_argument(
"--iteration_step",
type=str,
help="this will be the function called for each iteration of training"
)
args = parser.parse_args()
print(f"{args}\n")
main(args)


@ -9,13 +9,24 @@ class DummyModel(nn.Module):
embedding_dim: int,
dense_input_size: int,
dense_output_size: int,
dense_layers_count: int,
sparse: bool
):
r"""
A dummy model with an EmbeddingBag layer followed by a stack of dense layers.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
dense_input_size (int): size of each input sample
dense_output_size (int): size of each output sample
dense_layers_count (int): number of dense layers in the dense Sequential module
sparse (bool): if True, gradient w.r.t. weight matrix will be a sparse tensor
"""
super().__init__()
self.embedding = nn.EmbeddingBag(
num_embeddings, embedding_dim, sparse=sparse
)
self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(10)])
self.dense = nn.Sequential(*[nn.Linear(dense_input_size, dense_output_size) for _ in range(dense_layers_count)])
def forward(self, x):
x = self.embedding(x)


@ -0,0 +1,5 @@
from .DummyModel import DummyModel
model_map = {
"DummyModel": DummyModel
}


@ -0,0 +1,6 @@
from .server import AverageBatchParameterServer, AverageParameterServer
server_map = {
"AverageParameterServer": AverageParameterServer,
"AverageBatchParameterServer": AverageBatchParameterServer
}


@ -0,0 +1,363 @@
import functools
import threading
import time
from abc import ABC, abstractmethod
from metrics.MetricsLogger import MetricsLogger
from utils import sparse_rpc_format_to_tensor, sparse_tensor_to_rpc_format
import torch
import torch.distributed.rpc as rpc
class ParameterServerBase(ABC):
PARAMETER_SERVER_BATCH_METRIC = "parameter_server_batch_metric"
PARAMETER_SERVER_STRAGGLER_METRIC = "parameter_server_straggler_metric"
PARAM_INDEX_STRAGGLER = "param_index_straggler"
PARAM_INDEX_BATCH = "param_index_batch"
def __init__(self, rank):
r"""
Inits ParameterServerBase class.
Args:
rank (int): worker rank
"""
self.__metrics_logger = MetricsLogger(rank)
@abstractmethod
def process_gradient(self):
r"""
A method to be implemented by child class that will process a
gradient received by a server.
"""
return
@staticmethod
@abstractmethod
def average_gradient():
r"""
A method to be implemented by child class that will average
gradients.
"""
return
@staticmethod
@abstractmethod
def reset_state():
r"""
A method to be implemented by child class that will reset
the server state.
"""
return
def record_start(self, type, key, name, cuda=True):
r"""
A method that records the start event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
name (str): description of the metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
type,
key,
name,
cuda
)
def record_end(self, type, key):
r"""
A method that records the end event for a metric
Args:
type (str): group id for metric
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
type,
key
)
def record_straggler_start(self, key, cuda=True):
r"""
A helper method that records a straggler metric
for the given key. A user should call this when
the first gradient for the param location is received.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.PARAMETER_SERVER_STRAGGLER_METRIC,
key,
self.PARAM_INDEX_STRAGGLER,
cuda
)
def record_straggler_end(self, key):
r"""
A helper method that records a straggler metric
for the given key. A user should call this when
the last gradient for the param location is received.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.PARAMETER_SERVER_STRAGGLER_METRIC,
key
)
def record_batch_start(self, key, cuda=True):
r"""
A helper method that records a batch metric
for the given key. A user should call this when
the first gradient for the param location is received.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.PARAMETER_SERVER_BATCH_METRIC,
key,
self.PARAM_INDEX_BATCH,
cuda
)
def record_batch_end(self, key):
r"""
A helper method that records a batch metric
for the given key. A user should call this when
all futures for a param location have had their
result set.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.PARAMETER_SERVER_BATCH_METRIC,
key
)
@staticmethod
def record_method(name, type="method_metric", cuda=True):
r"""
A decorator that records a metric for the decorated method.
Args:
name (str): description of the metric
type (str): group id for metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
def decorator(function):
@functools.wraps(function)
def wrapper(self, *args):
key = time.time()
self.__metrics_logger.record_start(type, key, name, cuda)
result = function(self, *args)
self.__metrics_logger.record_end(type, key)
return result
return wrapper
return decorator
@staticmethod
def get_metrics(server_rref):
r"""
A staticmethod that returns metrics captured by the __metrics_logger.
Args:
server_rref (RRef): remote reference to the server
"""
self = server_rref.local_value()
return self.__metrics_logger.get_processed_metrics()
def clear_metrics(self):
r"""
A method that clears __metrics_logger recorded metrics.
"""
return self.__metrics_logger.clear_metrics()
class AverageParameterServer(ParameterServerBase):
def __init__(
self,
rank,
trainer_count,
use_cuda_rpc
):
r"""
A parameter server that averages the gradients
from trainers for each training iteration step.
Gradients are added as they are received from trainers.
When all gradients have been received, the sum is
divided by the number of trainers.
Args:
rank (int): worker rank
trainer_count (int): count of trainers sending
gradients to the server
use_cuda_rpc (bool): indicator for CUDA RPC
"""
super().__init__(rank)
self.lock = threading.Lock()
self.rank = rank
self.trainer_count = trainer_count
self.use_cuda_rpc = use_cuda_rpc
self.batch_number = 0
self.futures = {}
self.gradient_dict = {}
@staticmethod
def reset_state(server_rref):
r"""
A method that clears the state of the server.
Args:
server_rref (RRef): remote reference to the server
"""
self = server_rref.local_value()
self.batch_number = 0
self.futures.clear()
self.gradient_dict.clear()
self.clear_metrics()
def param_key(self, param_loc):
r"""
A method that returns an encoded key that represents
the current batch and param location.
Args:
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
return f"{self.batch_number},{param_loc}"
def clear_batch_state(self):
r"""
Clears the current server batch state.
"""
self.futures.clear()
self.gradient_dict.clear()
def process_gradient(self, gradient, param_loc):
r"""
Stores the gradient if param_loc is not in gradient_dict.
Adds the gradient to param_loc if it is in gradient_dict.
Args:
gradient (torch.Tensor): tensor sent from trainer
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
if param_loc not in self.gradient_dict:
self.record_straggler_start(self.param_key(param_loc))
self.record_batch_start(self.param_key(param_loc))
self.gradient_dict[param_loc] = gradient
else:
self.gradient_dict[param_loc] += gradient
@ParameterServerBase.record_method(name="average computation")
def average(self, param_loc):
r"""
Obtains the tensor at the param_loc in the gradient_dict
and then divides by number of trainers.
Args:
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
param_loc_avg = self.gradient_dict[param_loc]
param_loc_avg = param_loc_avg / (1.0 * self.trainer_count)
return param_loc_avg
@staticmethod
@rpc.functions.async_execution
def average_gradient(
server_rref,
received_batch_number,
param_loc,
gradient
):
r"""
An asynchronous function that will average gradients
sent from trainers.
Args:
server_rref (RRef): remote reference to the server
received_batch_number (int): batch number sent by
the trainer
param_loc (int): bucket location sent by the trainer
containing the gradient
gradient (torch.Tensor or list): tensor sent by the trainer
"""
self = server_rref.local_value()
if type(gradient) is list:
gradient = sparse_rpc_format_to_tensor(gradient)
gradient = gradient.cuda(self.rank)
fut = torch.futures.Future()
with self.lock:
if self.batch_number < received_batch_number:
self.batch_number = received_batch_number
self.clear_batch_state()
self.process_gradient(gradient, param_loc)
if param_loc not in self.futures:
self.futures[param_loc] = []
self.futures[param_loc].append(fut)
if len(self.futures[param_loc]) == self.trainer_count:
self.record_straggler_end(self.param_key(param_loc))
param_loc_avg = self.average(param_loc)
if not self.use_cuda_rpc:
param_loc_avg = param_loc_avg.cpu()
if param_loc_avg.is_sparse:
param_loc_avg = sparse_tensor_to_rpc_format(param_loc_avg)
for cur_fut in self.futures[param_loc]:
cur_fut.set_result(param_loc_avg)
self.record_batch_end(self.param_key(param_loc))
return fut
class AverageBatchParameterServer(AverageParameterServer):
def __init__(
self,
rank,
trainer_count,
use_cuda_rpc
):
r"""
A parameter server that averages the gradients
from trainers for each training iteration step.
Gradients are stored and averaged when a gradient
has been received from each trainer for a param
location.
Args:
rank (int): worker rank
trainer_count (int): count of trainers sending
gradients to the server
use_cuda_rpc (bool): indicator for CUDA RPC
"""
super().__init__(rank, trainer_count, use_cuda_rpc)
def process_gradient(self, gradient, param_loc):
r"""
Adds the gradient to param_loc bucket stored in
the gradient_dict.
Args:
gradient (torch.Tensor): tensor sent from trainer
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
if param_loc not in self.gradient_dict:
self.record_straggler_start(self.param_key(param_loc))
self.record_batch_start(self.param_key(param_loc))
self.gradient_dict[param_loc] = []
self.gradient_dict[param_loc].append(gradient)
@ParameterServerBase.record_method(name="average computation")
def average(self, param_loc):
r"""
Sums the gradients at the param_loc then divides by the
number of trainers.
Args:
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
param_loc_avg = self.gradient_dict[param_loc][0]
for gradient in self.gradient_dict[param_loc][1:]:
param_loc_avg += gradient
param_loc_avg = param_loc_avg / (1.0 * self.trainer_count)
return param_loc_avg


@ -1,143 +0,0 @@
import threading
import torch
import torch.distributed.rpc as rpc
from utils import sparse_rpc_format_to_tensor, sparse_tensor_to_rpc_format
from .ParameterServerBase import ParameterServerBase
class AverageParameterServer(ParameterServerBase):
lock = threading.Lock()
def __init__(
self,
rank,
trainer_count,
use_cuda_rpc
):
r"""
A parameter server that averages the gradients
from trainers for each training iteration step.
Gradients are added as they are received from trainers.
When all gradients have been received, the sum is
divided by the number of trainers.
Args:
rank (int): worker rank
trainer_count (int): count of trainers sending
gradients to the server
use_cuda_rpc (bool): indicator for CUDA RPC
"""
super().__init__(rank)
self.rank = rank
self.trainer_count = trainer_count
self.use_cuda_rpc = use_cuda_rpc
self.batch_number = 0
self.futures = {}
self.gradient_dict = {}
@staticmethod
def reset_state(server_rref):
r"""
A method that clears the state of the server.
Args:
server_rref (RRef): remote reference to the server
"""
self = server_rref.local_value()
self.batch_number = 0
self.futures.clear()
self.gradient_dict.clear()
self.clear_metrics()
def param_key(self, param_loc):
r"""
A method that returns an encoded key that represents
the current batch and param location.
Args:
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
return f"{self.batch_number},{param_loc}"
def clear_batch_state(self):
r"""
Clears the current server batch state.
"""
self.futures.clear()
self.gradient_dict.clear()
def process_gradient(self, gradient, param_loc):
r"""
Stores the gradient if param_loc is not in gradient_dict.
Adds the gradient to param_loc if it is in gradient_dict.
Args:
gradient (torch.Tensor): tensor sent from trainer
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
if param_loc not in self.gradient_dict:
self.record_straggler_start(self.param_key(param_loc))
self.record_batch_start(self.param_key(param_loc))
self.gradient_dict[param_loc] = gradient
else:
self.gradient_dict[param_loc] += gradient
@ParameterServerBase.record_method(name="average computation")
def average(self, param_loc):
r"""
Obtains the tensor at the param_loc in the gradient_dict
and then divides by number of trainers.
Args:
param_loc (int): bucket location sent by the trainer
containing the gradient
"""
param_loc_avg = self.gradient_dict[param_loc]
param_loc_avg / (1.0 * self.trainer_count)
return param_loc_avg
@staticmethod
@rpc.functions.async_execution
def average_gradient(
server_rref,
received_batch_number,
param_loc,
gradient
):
r"""
An asynchronous function that will average gradients
sent from trainers.
Args:
server_rref (RRef): remote reference to the server
received_batch_number (int): batch number sent by
the trainer
param_loc (int): bucket location sent by the trainer
containing the gradient
gradient (torch.Tensor or list): tensor sent by the trainer
"""
self = server_rref.local_value()
if type(gradient) is list:
gradient = sparse_rpc_format_to_tensor(gradient)
gradient = gradient.cuda(self.rank)
fut = torch.futures.Future()
with self.lock:
if self.batch_number < received_batch_number:
self.batch_number = received_batch_number
self.clear_batch_state()
self.process_gradient(gradient, param_loc)
if param_loc not in self.futures:
self.futures[param_loc] = []
self.futures[param_loc].append(fut)
if len(self.futures[param_loc]) == self.trainer_count:
self.record_straggler_end(self.param_key(param_loc))
param_loc_avg = self.average(param_loc)
if not self.use_cuda_rpc:
param_loc_avg = param_loc_avg.cpu()
if param_loc_avg.is_sparse:
param_loc_avg = sparse_tensor_to_rpc_format(param_loc_avg)
for cur_fut in self.futures[param_loc]:
cur_fut.set_result(param_loc_avg)
self.record_batch_end(self.param_key(param_loc))
return fut


@ -1,170 +0,0 @@
import functools
import time
from abc import ABC, abstractmethod
from metrics.MetricsLogger import MetricsLogger
class ParameterServerBase(ABC):
PARAMETER_SERVER_BATCH_METRIC = "parameter_server_batch_metric"
PARAMETER_SERVER_STRAGGLER_METRIC = "parameter_server_straggler_metric"
PARAM_INDEX_STRAGGLER = "param_index_straggler"
PARAM_INDEX_BATCH = "param_index_batch"
def __init__(self, rank):
r"""
Inits ParameterServerBase class.
Args:
rank (int): worker rank
"""
self.__metrics_logger = MetricsLogger(rank)
@abstractmethod
def process_gradient(self):
r"""
A method to be implemented by child class that will process a
gradient received by a server.
"""
return
@staticmethod
@abstractmethod
def average_gradient():
r"""
A method to be implemented by child class that will average
gradients.
"""
return
@staticmethod
@abstractmethod
def reset_state():
r"""
A method to be implemented by child class that will reset
the server state.
"""
return
def record_start(self, type, key, name, cuda=True):
r"""
A method that records the start event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
name (str): description of the metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
type,
key,
name,
cuda
)
def record_end(self, type, key):
r"""
A method that records the end event for a metric
Args:
type (str): group id for metric
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
type,
key
)
def record_straggler_start(self, key, cuda=True):
r"""
A helper method that records a straggler metric
for the given key. A user should call this when
the first gradient for the param location is received.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.PARAMETER_SERVER_STRAGGLER_METRIC,
key,
self.PARAM_INDEX_STRAGGLER,
cuda
)
def record_straggler_end(self, key):
r"""
A helper method that records a straggler metric
for the given key. A user should call this when
the last gradient for the param location is received.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.PARAMETER_SERVER_STRAGGLER_METRIC,
key
)
def record_batch_start(self, key, cuda=True):
r"""
A helper method that records a batch metric
for the given key. A user should call this when
the first gradient for the param location is received.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.PARAMETER_SERVER_BATCH_METRIC,
key,
self.PARAM_INDEX_BATCH,
cuda
)
def record_batch_end(self, key):
r"""
A helper method that records a batch metric
for the given key. A user should call this when
all futures for a param location have had their
result set.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(
self.PARAMETER_SERVER_BATCH_METRIC,
key
)
@staticmethod
def record_method(name, type="method_metric", cuda=True):
r"""
A decorator that records a metric for the decorated method.
Args:
name (str): description of the metric
type (str): group id for metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
def decorator(function):
@functools.wraps(function)
def wrapper(self, *args):
key = time.time()
self.__metrics_logger.record_start(type, key, name, cuda)
result = function(self, *args)
self.__metrics_logger.record_end(type, key)
return result
return wrapper
return decorator
@staticmethod
def get_metrics(server_rref):
r"""
A staticmethod that returns metrics captured by the __metrics_logger.
Args:
server_rref (RRef): remote reference to the server
"""
self = server_rref.local_value()
return self.__metrics_logger.get_processed_metrics()
def clear_metrics(self):
r"""
A method that clears __metrics_logger recorded metrics.
"""
return self.__metrics_logger.clear_metrics()


@ -1,31 +0,0 @@
import subprocess
from os.path import join
from pathlib import Path
script_dir = join(
Path(__file__).parent, "experiment_scripts"
)
encoding = 'utf-8'
def run_script(script_name):
# runs the script and asserts that there are no errors
p = subprocess.run(
["bash", f"{join(script_dir,script_name)}"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
error = p.stderr.decode(encoding)
assert not error
def test_ddp_nccl_allreduce():
run_script("ddp_nccl_allreduce.sh")
def test_ddp_cpu_sparse_rpc_nccl_allreduce():
run_script("ddp_cpu_sparse_rpc_nccl_allreduce.sh")
def test_ddp_cuda_sparse_rpc_nccl_allreduce():
run_script("ddp_cuda_sparse_rpc_nccl_allreduce.sh")


@ -0,0 +1,30 @@
from .criterions import cel
from .ddp_models import basic_ddp_model
from .hook_states import BasicHookState
from .iteration_steps import basic_iteration_step
from .preprocess_data import preprocess_dummy_data
from .trainer import DdpTrainer
criterion_map = {
"cel": cel
}
ddp_model_map = {
"basic_ddp_model": basic_ddp_model
}
iteration_step_map = {
"basic_iteration_step": basic_iteration_step
}
preprocess_data_map = {
"preprocess_dummy_data": preprocess_dummy_data
}
hook_state_map = {
"BasicHookState": BasicHookState
}
trainer_map = {
"DdpTrainer": DdpTrainer
}


@ -0,0 +1,10 @@
import torch.nn as nn
def cel(rank):
r"""A function that creates a CrossEntropyLoss
criterion for training.
Args:
rank (int): worker rank
"""
return nn.CrossEntropyLoss().cuda(rank)


@ -0,0 +1,23 @@
from torch.nn.parallel import DistributedDataParallel as DDP
def basic_ddp_model(self, rank, model, process_group, hook_state, hook):
r"""
A function that creates ddp_model and hook_state objects.
The ddp model is initialized with a single device id and
the process group. The ddp_model also registers the communication
hook.
Args:
rank (int): worker rank
model (nn.Module): neural network model
process_group (ProcessGroup): distributed process group
hook_state (class): class that will be used to keep track of state
during training.
hook (function): ddp communication hook
"""
ddp_model = DDP(
model, device_ids=[rank], process_group=process_group
)
hook_state = hook_state(self, process_group)
ddp_model.register_comm_hook(hook_state, hook)
return ddp_model, hook_state


@ -0,0 +1,28 @@
class BasicHookState:
def __init__(self, cref, process_group):
r"""
A class that holds state information that is needed by the communication hook
during the training algorithm.
Args:
cref (DdpTrainer): reference to the trainer instance (the trainer's self)
process_group (ProcessGroup): distributed process group
"""
self.cref = cref
self.process_group = process_group
self.batch_number = -1
def get_key(self, bucket_index):
r"""
A method that returns an encoded key that represents the current batch and
bucket index.
Args:
bucket_index (int): index of the bucket being processed in backward
"""
return f"{self.batch_number},{bucket_index}"
def next_batch(self):
r"""
A method that increments batch_number by 1.
"""
self.batch_number += 1


@ -0,0 +1,23 @@
def basic_iteration_step(self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch):
r"""
A function that performs an iteration of training.
Args:
ddp_model (nn.Module): distributed data parallel model
criterion (nn.Module): loss function used to compute the model loss
optimizer (optim.Optimizer): updates model parameters
hook_state (object): ddp communication hook state object
epoch (int): index of pass through the data
index (int): zero-based index of the batch within the current epoch
batch (list): training examples
"""
hook_state.next_batch()
self.record_batch_start(self.epoch_key(epoch, index))
optimizer.zero_grad()
self.record_forward_start(self.epoch_key(epoch, index))
loss = criterion(ddp_model(batch[0]), batch[1])
self.record_forward_end(self.epoch_key(epoch, index))
self.record_backward_start(self.epoch_key(epoch, index))
loss.backward()
self.record_backward_end(self.epoch_key(epoch, index))
optimizer.step()
self.record_batch_end(self.epoch_key(epoch, index))


@ -0,0 +1,12 @@
def preprocess_dummy_data(rank, data):
r"""
A function that moves the data from CPU to GPU
for DummyData class.
Args:
rank (int): worker rank
data (list): training examples
"""
for i in range(len(data)):
data[i][0] = data[i][0].cuda(rank)
data[i][1] = data[i][1].cuda(rank)
return data


@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from metrics.MetricsLogger import MetricsLogger
import torch
class TrainerBase(ABC):
@ -175,3 +177,88 @@ class TrainerBase(ABC):
A method that clears __metrics_logger recorded metrics.
"""
return self.__metrics_logger.clear_metrics()
class DdpTrainer(TrainerBase):
def __init__(
self,
process_group,
use_cuda_rpc,
server_rref,
backend,
epochs,
preprocess_data,
create_criterion,
create_ddp_model,
hook_state_class,
hook,
iteration_step
):
r"""
A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
using the process_group implementation.
Args:
process_group (ProcessGroup): distributed process group
use_cuda_rpc (bool): indicator for CUDA RPC
server_rref (RRef): remote reference to the server
backend (str): distributed communication backend
epochs (int): epoch count for training
preprocess_data (function): preprocesses data passed
to the trainer before starting training
create_criterion (function): creates a criterion to calculate loss
create_ddp_model (function): creates a ddp model for the trainer
hook_state_class (class): class that will be used to keep track of state
during training.
hook (function): ddp communication hook
iteration_step (function): will perform 1 step of training
"""
super().__init__(process_group.rank())
self.process_group = process_group
self.use_cuda_rpc = use_cuda_rpc
self.server_rref = server_rref
self.backend = backend
self.epochs = epochs
self.preprocess_data = preprocess_data
self.create_criterion = create_criterion
self.create_ddp_model = create_ddp_model
self.hook_state_class = hook_state_class
self.hook = hook
self.iteration_step = iteration_step
self.rank = process_group.rank()
self.trainer_count = process_group.size()
def epoch_key(self, epoch, index):
r"""
A method that returns an encoded key that represents the current epoch and
iteration index.
Args:
epoch (int): epoch index
index (int): iteration index
"""
return f"{epoch},{index}"
def train(self, model, data):
r"""
A method that implements the training algorithm.
Args:
model (nn.Module): neural network model
data (list): training examples
"""
model = model.cuda(self.rank)
data = self.preprocess_data(self.rank, data)
criterion = self.create_criterion(self.rank)
ddp_model, hook_state = self.create_ddp_model(
self, self.rank, model, self.process_group, self.hook_state_class, self.hook
)
optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
for epoch in range(self.epochs):
if epoch % 5 == 0 and self.rank == 0:
print(f"train epoch={epoch}")
for index, batch in enumerate(data):
self.iteration_step(
self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
)
torch.cuda.synchronize(self.rank)


@ -1,107 +0,0 @@
import torch
import torch.distributed as c10d
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from .DdpTrainerBase import DdpTrainerBase
class DdpNcclTrainer(DdpTrainerBase):
class HookState:
def __init__(self, cref, process_group):
self.cref = cref
self.process_group = process_group
self.process_group_size = process_group.size()
self.param_location = 0
self.batch_number = -1
def get_key(self):
return f"{self.batch_number},{self.param_location}"
def next_batch_state(self):
self.param_location = 0
self.batch_number += 1
def __init__(self, rank, trainer_count, ps_rref, epochs):
super().__init__(rank)
self.rank = rank
self.trainer_count = trainer_count
self.epochs = epochs
@staticmethod
def hook(state, bucket):
cref = state.cref
tensors_count = len(cref.bucket_to_parameters(bucket))
tensors = [bucket.get_tensor() / state.process_group_size]
key = state.get_key()
cref.record_hook_fut_start(key, cref.NCCL_ALLREDUCE)
fut = state.process_group.allreduce(tensors).get_future()
state.param_location += tensors_count
def callback(fut):
cref.record_hook_fut_end(key)
return fut.wait()
return fut.then(callback)
def train(self, model, data):
torch.manual_seed(0)
model = model.cuda(self.rank)
for i in range(len(data)):
data[i][0] = data[i][0].cuda(self.rank)
data[i][1] = data[i][1].cuda(self.rank)
torch.cuda.synchronize(self.rank)
process_group_size = self.trainer_count
store = c10d.FileStore("/tmp/tmpn_k_8so02", process_group_size)
process_group = c10d.ProcessGroupNCCL(
store, self.rank, process_group_size
)
ddp_model = DDP(
model, device_ids=[self.rank], process_group=process_group
)
hook_state = self.HookState(self, process_group)
ddp_model.register_comm_hook(hook_state, DdpNcclTrainer.hook)
criterion = nn.CrossEntropyLoss().cuda(self.rank)
optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
def epoch_key(epoch, index):
return f"{epoch},{index}"
for epoch in range(self.epochs):
for index, batch in enumerate(data):
hook_state.next_batch_state()
input, target = batch[0], batch[1]
self.record_batch_start(epoch_key(epoch, index))
optimizer.zero_grad()
self.record_forward_start(epoch_key(epoch, index))
out = ddp_model(input)
self.record_forward_end(epoch_key(epoch, index))
loss = criterion(out, target)
self.record_backward_start(epoch_key(epoch, index))
loss.backward()
self.record_backward_end(epoch_key(epoch, index))
optimizer.step()
self.record_batch_end(epoch_key(epoch, index))
torch.cuda.synchronize(self.rank)


@ -1,54 +0,0 @@
from utils import process_bucket_with_remote_server
from .DdpTrainer import DdpTrainer
class DdpSparseRpcTrainer(DdpTrainer):
def __init__(self, rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs):
r"""
A trainer that implements a DDP training algorithm using a server and process group
allreduce. The trainer sends sparse gradients using RPC, and the server averages and
returns the gradients. The process group uses the backend allreduce implementation
to average the dense gradients.
Args:
rank (int): worker rank
trainer_count (int): count of trainer in the world
process_group (ProcessGroup): distributed process group
use_cuda_rpc (bool): indicator for CUDA RPC
server_rref (RRef): remote reference to the server
backend (str): distributed communication backend
epochs (int): epoch count for training
"""
super().__init__(rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs)
@staticmethod
def hook(state, bucket):
r"""
A ddp communication hook that uses the current backend allreduce
implementation for dense tensors and a server for sparse tensors.
Args:
state (object): maintains state during the training process
bucket (GradBucket): gradient bucket
"""
tensor = bucket.get_tensor()
if tensor.is_sparse:
return process_bucket_with_remote_server(state, bucket)
else:
cref = state.cref
tensor = [tensor / state.process_group.size()]
key = state.get_key(bucket.get_index())
cref.record_hook_fut_start(key, f"{cref.backend}_dense_allreduce")
fut = state.process_group.allreduce(tensor).get_future()
def callback(fut):
cref.record_hook_fut_end(key)
return fut.wait()
return fut.then(callback)
def get_hook(self):
r"""
returns DdpSparseRpcTrainer.hook
"""
return DdpSparseRpcTrainer.hook


@ -1,183 +0,0 @@
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from .DdpTrainerBase import DdpTrainerBase
class DdpTrainer(DdpTrainerBase):
class HookState:
def __init__(self, cref, process_group):
r"""
A class that holds state information that is needed by the communication hook
during the training algorithm.
Args:
cref (DdpTrainer): reference to the self keyword of the trainer instance
process_group (ProcessGroup): distributed process group
"""
self.cref = cref
self.process_group = process_group
self.batch_number = -1
def get_key(self, bucket_index):
r"""
A method that returns an encoded key that represents the current batch and
bucket index.
Args:
bucket_index (int): index of the bucket being processed in backward
"""
return f"{self.batch_number},{bucket_index}"
def next_batch(self):
r"""
A method that increments batch_number by 1.
"""
self.batch_number += 1
def __init__(self, rank, trainer_count, process_group, use_cuda_rpc, server_rref, backend, epochs):
r"""
A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
using the process_group implementation.
Args:
rank (int): worker rank
trainer_count (int): count of trainer in the world
process_group (ProcessGroup): distributed process group
use_cuda_rpc (bool): indicator for CUDA RPC
server_rref (RRef): remote reference to the server
backend (str): distributed communication backend
epochs (int): epoch count for training
"""
super().__init__(rank)
self.rank = rank
self.trainer_count = trainer_count
self.process_group = process_group
self.use_cuda_rpc = use_cuda_rpc
self.server_rref = server_rref
self.backend = backend
self.epochs = epochs
@staticmethod
def hook(state, bucket):
r"""
A ddp communication hook that uses the process_group allreduce implementation.
Args:
state (object): maintains state during the training process
bucket (GradBucket): gradient bucket
"""
cref = state.cref
tensors = [bucket.get_tensor() / state.process_group.size()]
key = state.get_key(bucket.get_index())
cref.record_hook_fut_start(key, f"{cref.backend}_allreduce")
fut = state.process_group.allreduce(tensors).get_future()
def callback(fut):
cref.record_hook_fut_end(key)
return fut.wait()
return fut.then(callback)
def get_hook(self):
r"""
returns DdpTrainer.hook
"""
return DdpTrainer.hook
def create_ddp_model(self, model):
r"""
A method that creates a ddp_model and hook_state objects.
It returns the ddp_model and hook_state objects.
Args:
model (nn.Module): neural network model
"""
ddp_model = DDP(
model, device_ids=[self.rank], process_group=self.process_group
)
hook_state = self.HookState(self, self.process_group)
ddp_model.register_comm_hook(hook_state, self.get_hook())
return ddp_model, hook_state
def create_criterion(self):
r"""
A method that creates a criterion for the training
algorithm.
"""
return nn.CrossEntropyLoss().cuda(self.rank)
def create_optimizer(self, parameters, lr):
r"""
A method that creates a optimizer for the training
algorithm.
Args:
parameters (iterable): iterable of parameters to optimize
lr (float): learning rate
"""
return torch.optim.SGD(parameters, lr)
def epoch_key(self, epoch, index):
r"""
A method that returns an encoded key that represents the current epoch and
iteration index.
Args:
epoch (int): epoch index
index (int): iteration index
"""
return f"{epoch},{index}"
def preprocess_data(self, data):
r"""
A method that moves the data from CPU to GPU.
Args:
data (list): training examples
"""
for i in range(len(data)):
data[i][0] = data[i][0].cuda(self.rank)
data[i][1] = data[i][1].cuda(self.rank)
return data
def iteration_step(self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch):
r"""
A method that performs an iteration of training.
Args:
ddp_model (nn.Module): distributed data parallel model
criterion (nn.Module): loss function to measure model
optimizer (optim.Optimizer): updates model parameters
hook_state (object): ddp communication hook state object
epoch (int): index of pass through the data
index (int): iteration number - 1 in current batch
batch (list): training examples
"""
hook_state.next_batch()
input, target = batch[0], batch[1]
self.record_batch_start(self.epoch_key(epoch, index))
optimizer.zero_grad()
self.record_forward_start(self.epoch_key(epoch, index))
out = ddp_model(input)
self.record_forward_end(self.epoch_key(epoch, index))
loss = criterion(out, target)
self.record_backward_start(self.epoch_key(epoch, index))
loss.backward()
self.record_backward_end(self.epoch_key(epoch, index))
optimizer.step()
self.record_batch_end(self.epoch_key(epoch, index))
def train(self, model, data):
r"""
A method that implements the training algorithm.
Args:
model (nn.Module): neural network model
data (list): training examples
"""
model = model.cuda(self.rank)
data = self.preprocess_data(data)
ddp_model, hook_state = self.create_ddp_model(model)
criterion = self.create_criterion()
optimizer = self.create_optimizer(ddp_model.parameters(), 1e-4)
for epoch in range(self.epochs):
for index, batch in enumerate(data):
self.iteration_step(
ddp_model, criterion, optimizer, hook_state, epoch, index, batch
)
torch.cuda.synchronize(self.rank)

View File

@ -1,49 +0,0 @@
from abc import abstractmethod
from .TrainerBase import TrainerBase
class DdpTrainerBase(TrainerBase):
HOOK_FUTURE_METRIC = "hook_future_metric"
def __init__(self, rank):
r"""
Inits DdpTrainerBase class.
Args:
rank (int): worker rank
"""
super().__init__(rank)
@staticmethod
@abstractmethod
def hook(state, bucket):
r"""
A method to be implemented by child class that will implement a DDP
training algorithm.
Args:
state (object): maintains state during the training process
bucket (GradBucket): gradient bucket
"""
return
def record_hook_fut_start(self, key, name, cuda=True):
r"""
A helper method that records a hook future metric
for the given key. A user should call this before
sending async request in the DDP communication hook.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.record_start(self.HOOK_FUTURE_METRIC, key, name, cuda)
def record_hook_fut_end(self, key):
r"""
A helper method that records a hook future metric
for the given key. A user should call this in a callback
attached to the future returned by an async request.
Args:
key (str): unique id for metric within a group
"""
self.record_end(self.HOOK_FUTURE_METRIC, key)


@ -6,7 +6,7 @@ RPC_DENSE = "rpc_dense"
def sparse_tensor_to_rpc_format(sparse_tensor):
r"""
A helper method creates a list containing the indices, values, and size
A helper function creates a list containing the indices, values, and size
of a coalesced sparse tensor.
Args:
sparse_tensor (torch.Tensor): sparse_coo_tensor represented as a list
@ -17,7 +17,7 @@ def sparse_tensor_to_rpc_format(sparse_tensor):
def sparse_rpc_format_to_tensor(sparse_rpc_format):
r"""
A helper method creates a sparse_coo_tensor from indices, values, and size.
A helper function creates a sparse_coo_tensor from indices, values, and size.
Args:
sparse_rpc_format (list): sparse_coo_tensor represented as a list
"""
@ -50,14 +50,15 @@ def process_bucket_with_remote_server(state, bucket):
tensor
]
key = state.get_key(b_index)
cref.record_hook_fut_start(
cref.record_start(
"hook_future_metric",
key,
RPC_SPARSE if sparse else RPC_DENSE
)
fut = cref.server_rref.rpc_async().average_gradient(*server_args)
def callback(fut):
cref.record_hook_fut_end(key)
cref.record_end("hook_future_metric", key)
tensor = fut.wait()
if type(tensor) is list:
tensor = sparse_rpc_format_to_tensor(tensor)