pytorch/benchmarks/distributed/rpc/parameter_server/trainer/trainer.py
Xuehai Pan 7763c83af6 [5/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort torch (#127126)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Apart from `pyproject.toml`, all changes were generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
2024-05-27 04:22:18 +00:00


import functools
import time
from abc import ABC, abstractmethod
from metrics.MetricsLogger import MetricsLogger
import torch


class TrainerBase(ABC):
BATCH_LEVEL_METRIC = "batch_level_metric"
BATCH_ALL = "batch_all"
FORWARD_METRIC = "forward_metric"
FORWARD_PASS = "forward_pass"
BACKWARD_METRIC = "backward_metric"
BACKWARD = "backward"
def __init__(self, rank):
r"""
Inits TrainerBase class.
Args:
rank (int): worker rank
"""
self.__metrics_logger = MetricsLogger(rank)
@abstractmethod
def train(self):
r"""
A method to be implemented by a child class that will train a neural network.
"""
return
def record_start(self, type, key, name, cuda=True):
r"""
A method that records the start event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
name (str): description of the metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(type, key, name, cuda)
def record_end(self, type, key):
r"""
A method that records the end event for a metric.
Args:
type (str): group id for metric
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(type, key)
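    # A hypothetical pairing (not taken from this benchmark) showing how a
    # custom metric could be bracketed; ``record_end`` must be called with the
    # same ``type`` and ``key`` that were passed to ``record_start``:
    #
    #   self.record_start("my_group", key, "my_metric_description")
    #   ...  # timed work
    #   self.record_end("my_group", key)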
def record_batch_start(self, key, cuda=True):
r"""
A helper method that records a batch metric for the
given key. A user should call this at the start of an
iteration step during training.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.BATCH_LEVEL_METRIC, key, self.BATCH_ALL, cuda
)
def record_batch_end(self, key):
r"""
A helper method that records a batch metric for the
given key. A user should call this at the end of an
iteration step during training.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(self.BATCH_LEVEL_METRIC, key)
def record_forward_start(self, key, cuda=True):
r"""
A helper method that records a forward metric
for the given key. A user should call this before
their neural network forward.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.FORWARD_METRIC, key, self.FORWARD_PASS, cuda
)
def record_forward_end(self, key):
r"""
A helper method that records a forward metric
for the given key. A user should call this after their
neural network forward.
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(self.FORWARD_METRIC, key)
def record_backward_start(self, key, cuda=True):
r"""
A helper method that records a backward metric
for the given key. A user should call this before
their .backward() call.
Args:
key (str): unique id for metric within a group
cuda (bool): indicator to determine if this is a CUDA metric
"""
self.__metrics_logger.record_start(
self.BACKWARD_METRIC, key, self.BACKWARD, cuda
)
def record_backward_end(self, key):
r"""
A helper method that records a backward metric
for the given key. A user should call this after
.backward().
Args:
key (str): unique id for metric within a group
"""
self.__metrics_logger.record_end(self.BACKWARD_METRIC, key)
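    # Taken together, a single training iteration is expected to call these
    # helpers in a nested order: batch start, forward start/end, backward
    # start/end, batch end. A hedged sketch of such an iteration step is
    # given at the end of this file.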
@staticmethod
def methodmetric(name, type="method_metric", cuda=True):
r"""
A decorator that records a metric for the decorated method.
Args:
name (str): description of the metric
type (str): group id for metric
cuda (bool): indicator to determine if this is a CUDA metric
"""
def decorator(function):
@functools.wraps(function)
def wrapper(self, *args):
key = time.time()
self.__metrics_logger.record_start(type, key, name, cuda)
result = function(self, *args)
self.__metrics_logger.record_end(type, key)
return result
return wrapper
return decorator
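    # Usage sketch (hypothetical; no method in this file is decorated this
    # way): applying the decorator to a subclass method records a metric
    # around every call of that method, e.g.
    #
    #   class MyTrainer(TrainerBase):
    #       @TrainerBase.methodmetric("preprocess")
    #       def preprocess(self, batch):
    #           ...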
def get_metrics(self):
r"""
A method that returns metrics captured by the __metrics_logger.
"""
return self.__metrics_logger.get_processed_metrics()
def clear_metrics(self):
r"""
A method that clears __metrics_logger recorded metrics.
"""
return self.__metrics_logger.clear_metrics()


class DdpTrainer(TrainerBase):
def __init__(
self,
process_group,
use_cuda_rpc,
server_rref,
backend,
epochs,
preprocess_data,
create_criterion,
create_ddp_model,
hook_state_class,
hook,
iteration_step,
):
r"""
A trainer that implements a DDP training algorithm using a simple hook that performs allreduce
using the process_group implementation.
Args:
process_group (ProcessGroup): distributed process group
use_cuda_rpc (bool): indicator for CUDA RPC
server_rref (RRef): remote reference to the server
backend (str): distributed communication backend
epochs (int): epoch count for training
preprocess_data (function): preprocesses data passed
to the trainer before starting training
create_criterion (function): creates a criterion to calculate loss
create_ddp_model (function): creates a ddp model for the trainer
hook_state_class (class): class that will be used to keep track of state
during training.
hook (function): ddp communication hook
iteration_step (function): will perform 1 step of training
"""
super().__init__(process_group.rank())
self.process_group = process_group
self.use_cuda_rpc = use_cuda_rpc
self.server_rref = server_rref
self.backend = backend
self.epochs = epochs
self.preprocess_data = preprocess_data
self.create_criterion = create_criterion
self.create_ddp_model = create_ddp_model
self.hook_state_class = hook_state_class
self.hook = hook
self.iteration_step = iteration_step
self.rank = process_group.rank()
self.trainer_count = process_group.size()
def epoch_key(self, epoch, index):
r"""
A method that returns an encoded key that represents the current epoch and
iteration index.
Args:
epoch (int): epoch index
index (int): iteration index
"""
return f"{epoch},{index}"
def train(self, model, data):
r"""
A method that implements the training algorithm.
Args:
model (nn.Module): neural network model
data (list): training examples
"""
model = model.cuda(self.rank)
data = self.preprocess_data(self.rank, data)
criterion = self.create_criterion(self.rank)
ddp_model, hook_state = self.create_ddp_model(
self, self.rank, model, self.process_group, self.hook_state_class, self.hook
)
optimizer = torch.optim.SGD(ddp_model.parameters(), 1e-4)
for epoch in range(self.epochs):
if epoch % 5 == 0 and self.rank == 0:
print(f"train epoch={epoch}")
for index, batch in enumerate(data):
self.iteration_step(
self,
ddp_model,
criterion,
optimizer,
hook_state,
epoch,
index,
batch,
)
torch.cuda.synchronize(self.rank)
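

# The per-iteration step functions used by this benchmark live elsewhere in
# the suite and are injected through the ``iteration_step`` argument above.
# The sketch below is a minimal, hypothetical callable matching the call made
# in ``DdpTrainer.train``; it is not part of the original file and assumes
# ``batch`` is an (inputs, targets) pair already placed on the trainer's device.
def example_iteration_step(
    trainer, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    # Key every metric by the current epoch and iteration index.
    key = trainer.epoch_key(epoch, index)
    inputs, targets = batch[0], batch[1]
    trainer.record_batch_start(key)
    optimizer.zero_grad()
    trainer.record_forward_start(key)
    out = ddp_model(inputs)
    trainer.record_forward_end(key)
    loss = criterion(out, targets)
    trainer.record_backward_start(key)
    loss.backward()
    trainer.record_backward_end(key)
    optimizer.step()
    trainer.record_batch_end(key)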