Mirror of https://github.com/deepspeedai/DeepSpeed.git
Reland perf fix for nan inf check (#7184)
Replace the previous logical-ops-based nan/inf detection with torch.where.

Signed-off-by: Nadav Elyahu <nelyahu@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
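For orientation, a minimal standalone sketch (not part of the commit) of the two patterns the message refers to. The arithmetic blend is not a clean select: the masked-out branch still multiplies 0 by inf or nan, which is nan under IEEE-754, while torch.where picks values elementwise without mixing the branches.

    import torch

    t = torch.tensor([1.0, float('inf'), float('nan'), 2.0])
    bad = t.isinf().logical_or(t.isnan())
    err = torch.tensor(-1.0)

    # Arithmetic blend over the boolean mask: 0 * inf and 0 * nan are nan,
    # so the "unselected" branch poisons the result.
    blended = bad * err + bad.logical_not() * t   # tensor([1., nan, nan, 2.])

    # torch.where selects without evaluating a blend.
    selected = torch.where(bad, err, t)           # tensor([ 1., -1., -1.,  2.])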
deepspeed/runtime/utils.py

@@ -823,6 +823,14 @@ def get_only_unique_item(items):
     return unique_item
 
 
+def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.):
+    norm_is_inf = input.isinf()
+    norm_is_nan = input.isnan()
+    inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+    err = torch.tensor(-1.0, device=device, dtype=torch.float)
+    input.masked_fill_(inf_or_nan, err)
+
+
 def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.
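A quick standalone check of the new helper; the function body below is copied from the hunk above (note that, as written, it always fills with -1.0 and ignores its val argument):

    import torch

    def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.):
        norm_is_inf = input.isinf()
        norm_is_nan = input.isnan()
        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
        err = torch.tensor(-1.0, device=device, dtype=torch.float)
        input.masked_fill_(inf_or_nan, err)

    norm = torch.tensor([0.5, float('nan'), float('inf')])
    mask_nan_or_inf_with_val_inplace(norm, device=norm.device)
    print(norm)  # tensor([ 0.5000, -1.0000, -1.0000])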
@@ -897,8 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
         dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
     total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
 
-    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
-    total_norm.masked_fill_(inf_or_nan, -1)
+    mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)
 
     return total_norm
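The surrounding function assembles a global p-norm from per-rank partial sums; a single-process sketch of the same arithmetic, with dist.all_reduce replaced by a plain Python sum over hypothetical per-rank shards:

    import torch

    norm_type = 2
    shards = [torch.randn(5), torch.randn(7)]  # stand-ins for per-rank gradients

    # Each rank contributes sum(|x|^p); all_reduce(SUM) adds these across ranks.
    device_total_norm = sum(t.abs().pow(norm_type).sum() for t in shards)

    # The global norm is the p-th root of the reduced sum.
    total_norm = device_total_norm.pow(1. / norm_type)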
deepspeed/runtime/zero/stage3.py

@@ -19,7 +19,7 @@ from deepspeed.utils import logger
 from deepspeed.utils.torch import register_grad_hook
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
@@ -1453,12 +1453,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         total_norm = total_norm_cuda[0]**(1. / norm_type)
 
-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
-
-        err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)
 
         return total_norm.cpu()
 
@@ -1815,7 +1810,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
 
         err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        total_norm = torch.where(inf_or_nan, err, total_norm)
 
         return total_norm
 
deepspeed/runtime/zero/stage_1_and_2.py

@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
-                                     align_dense_tensors, all_gather_dp_groups)
+                                     align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -1722,11 +1722,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
 
         total_norm = total_norm.pow(1. / norm_type)
 
-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-
-        if norm_is_inf or norm_is_nan:
-            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
 
         return total_norm
 
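Part of the perf win in the hunk above: `if norm_is_inf or norm_is_nan:` calls Tensor.__bool__ on a device tensor, which forces a blocking device-to-host copy, whereas the masked fill stays entirely on device. A CPU-only sketch of the two control styles (the sync cost only materializes on CUDA tensors):

    import torch

    # Old style: branching on a tensor materializes its value on the host.
    total_norm = torch.tensor(float('inf'))
    if total_norm.isinf() or total_norm.isnan():
        total_norm = torch.tensor(-1.0)

    # New style: everything remains a tensor op; no host round-trip.
    total_norm = torch.tensor(float('nan'))
    total_norm.masked_fill_(total_norm.isinf().logical_or(total_norm.isnan()), -1.0)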
@@ -1984,10 +1980,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-
-            # handle total_norm invalid value -1
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
 
         for grad in grad_groups_flat:
             if isinstance(grad, list):
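The clipping rewrite is behavior-preserving: clamping clip to a minimum of 1.0 turns the no-clip case into a multiply by exactly 1.0, so the data-dependent `if clip > 1:` branch (another host sync when clip lives on the GPU) can go away. A standalone sketch with made-up values:

    import torch

    loss_scale = 1024.0

    for clip_val in (0.25, 3.0):  # below and above the clip threshold
        clip = torch.tensor(clip_val)

        # Old: update combined_scale only when clipping kicks in.
        combined_scale_old = loss_scale
        if clip > 1:
            combined_scale_old = clip * loss_scale

        # New: branch-free; clamp makes the no-clip case a no-op multiply.
        combined_scale_new = torch.clamp(clip, min=1.0) * loss_scale

        assert float(combined_scale_new) == float(combined_scale_old)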