Compare commits

...

1 Commit

Author | SHA1       | Message                         | Date
       | 0484565082 | Initial work, need to run tests | 2023-07-27 11:22:28 -04:00

@@ -1842,47 +1842,42 @@ class Trainer:
-    if is_torch_tpu_available():
-        gradients = xm._fetch_gradients(self.optimizer)
-        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
-    # AMP: gradients need unscaling
-    self.scaler.unscale_(self.optimizer)
-if is_sagemaker_mp_enabled() and args.fp16:
-    self.optimizer.clip_master_grads(args.max_grad_norm)
-elif hasattr(self.optimizer, "clip_grad_norm"):
-    # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
-    self.optimizer.clip_grad_norm(args.max_grad_norm)
-elif hasattr(model, "clip_grad_norm_"):
-    # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
-    model.clip_grad_norm_(args.max_grad_norm)
-elif self.use_apex:
-    # Revert to normal clipping otherwise, handling Apex or full precision
-    nn.utils.clip_grad_norm_(
-        amp.master_params(self.optimizer),
-        args.max_grad_norm,
-    )
+use_accelerate = (
+    (is_sagemaker_mp_enabled() and args.fp16)
+    or hasattr(self.optimizer, "clip_grad_norm")
+    or hasattr(model, "clip_grad_norm_")
+    or self.use_apex
+)
+if not use_accelerate:
+    if self.do_grad_scaling:
+        # AMP: gradients need unscaling when not using Accelerate
+        self.scaler.unscale_(self.optimizer)
+    if is_sagemaker_mp_enabled() and args.fp16:
+        self.optimizer.clip_master_grads(args.max_grad_norm)
+    elif hasattr(self.optimizer, "clip_grad_norm"):
+        # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+        self.optimizer.clip_grad_norm(args.max_grad_norm)
+    elif hasattr(model, "clip_grad_norm_"):
+        # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+        model.clip_grad_norm_(args.max_grad_norm)
+    elif self.use_apex:
+        # Revert to normal clipping otherwise, handling Apex or full precision
+        nn.utils.clip_grad_norm_(
+            amp.master_params(self.optimizer),
+            args.max_grad_norm,
+        )
+else:
+    # Accelerate handles unscaling of the gradients
+    self.accelerator.clip_grad_norm_(
+        model.parameters(),
+        args.max_grad_norm,
+    )
 # Optimizer step
 optimizer_was_run = True
-if is_torch_tpu_available():
-    if self.do_grad_scaling:
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-    else:
-        # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
-        self.optimizer.step()
-elif self.do_grad_scaling:
-    scale_before = self.scaler.get_scale()
-    self.scaler.step(self.optimizer)
-    self.scaler.update()
-    scale_after = self.scaler.get_scale()
-    optimizer_was_run = scale_before <= scale_after
-else:
-    self.optimizer.step()
-    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+self.optimizer.step()
+optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
 if optimizer_was_run:
     # Delay optimizer scheduling until metrics are generated
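
Note: the added branch above delegates unscaling and clipping to Accelerate and reads `optimizer_step_was_skipped` instead of tracking the scaler's scale by hand. Below is a minimal, self-contained sketch of that usage pattern outside the Trainer; the toy model, optimizer, and data are placeholders for illustration, not part of this diff.

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()  # e.g. Accelerator(mixed_precision="fp16") when AMP is wanted
    model = torch.nn.Linear(8, 2)                             # toy model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # toy optimizer (assumption)
    model, optimizer = accelerator.prepare(model, optimizer)

    inputs = torch.randn(4, 8, device=accelerator.device)
    targets = torch.randint(0, 2, (4,), device=accelerator.device)

    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    accelerator.backward(loss)  # Accelerate applies loss scaling internally when needed

    # One call unscales (if necessary) and clips, mirroring the new `else:` branch above.
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
    # Mirrors `optimizer_was_run = not self.accelerator.optimizer_step_was_skipped`.
    optimizer_was_run = not accelerator.optimizer_step_was_skipped
    optimizer.zero_grad()
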
@@ -2649,9 +2644,7 @@ class Trainer:
 if self.args.n_gpu > 1:
     loss = loss.mean()  # mean() to average on multi-gpu parallel training
-if self.do_grad_scaling:
-    self.scaler.scale(loss).backward()
-elif self.use_apex:
+if self.use_apex:
     with amp.scale_loss(loss, self.optimizer) as scaled_loss:
         scaled_loss.backward()
 else:
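
The second hunk drops the `do_grad_scaling` branch from the backward pass, so Apex remains the only special case before the default path. For reference, the sketch below shows the standard `torch.cuda.amp.GradScaler` loop that the removed scaler calls (here and in the optimizer-step code above) implemented by hand; Accelerate is expected to perform the equivalent work internally. The toy model, optimizer, and data are assumptions for illustration only.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 2).to(device)                  # toy model (assumption)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # toy optimizer (assumption)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    inputs = torch.randn(4, 8, device=device)
    targets = torch.randint(0, 2, (4,), device=device)
    with torch.autocast(device_type=device, dtype=torch.float16 if device == "cuda" else torch.bfloat16):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)

    scaler.scale(loss).backward()  # the removed `self.scaler.scale(loss).backward()`
    scaler.unscale_(optimizer)     # the removed `self.scaler.unscale_(self.optimizer)`
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    scale_before = scaler.get_scale()
    scaler.step(optimizer)         # skips the step when inf/NaN gradients are detected
    scaler.update()
    # Mirrors the removed `optimizer_was_run = scale_before <= scale_after` bookkeeping.
    optimizer_was_run = scale_before <= scaler.get_scale()
    optimizer.zero_grad()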