mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/105928 Approved by: https://github.com/albanD
26 lines
1.1 KiB
Python
def basic_iteration_step(
    self, ddp_model, criterion, optimizer, hook_state, epoch, index, batch
):
    r"""
    Perform one training iteration: zero grads, forward pass, backward
    pass, and optimizer step, recording timing events around each phase.

    Args:
        ddp_model (nn.Module): distributed data parallel model
        criterion (nn.Module): loss function to measure model
        optimizer (optim.Optimizer): updates model parameters
        hook_state (object): ddp communication hook state object
        epoch (int): index of pass through the data
        index (int): iteration number - 1 in current batch
        batch (list): training examples; batch[0] is the model input and
            batch[1] is the target passed to criterion
    """
    hook_state.next_batch()
    # Compute the metric key once; the original recomputed
    # self.epoch_key(epoch, index) for every record_* call (six times).
    key = self.epoch_key(epoch, index)
    self.record_batch_start(key)
    optimizer.zero_grad()
    self.record_forward_start(key)
    loss = criterion(ddp_model(batch[0]), batch[1])
    self.record_forward_end(key)
    self.record_backward_start(key)
    loss.backward()
    self.record_backward_end(key)
    optimizer.step()
    self.record_batch_end(key)