Reland perf fix for nan inf check (#7184)

Replace the previous nan/inf detection, which used logical ops and element-wise arithmetic, with torch.where / in-place masking.
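
For context, a minimal standalone sketch (not the DeepSpeed code itself) of the pattern this change converges on: build the nan/inf mask on-device and write the sentinel -1 with torch.where / masked_fill_, instead of branching on the tensor value in Python, which forces the norm to be read back to the host. The variable names below only mirror the ones in the diff.

```python
import torch

# Illustration only: an overflowed gradient norm.
total_norm = torch.tensor(float('inf'))

# Avoided pattern: `if total_norm.isinf() or total_norm.isnan(): ...`
# branches on a tensor value, so the value must be synced to the host first.

# Pattern used here: compute a nan/inf mask and select/fill the sentinel on-device.
inf_or_nan = total_norm.isnan().logical_or(total_norm.isinf())
err = torch.tensor(-1.0, dtype=torch.float)

selected = torch.where(inf_or_nan, err, total_norm)  # out-of-place select
total_norm.masked_fill_(inf_or_nan, -1.0)            # in-place fill (as in the new helper)

assert selected.item() == -1.0 and total_norm.item() == -1.0
```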

---------

Signed-off-by: Nadav Elyahu <nelyahu@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Nadav Elyahu
2025-04-02 19:21:25 +03:00
committed by GitHub
parent 79ff162722
commit 3c1817f38f
3 changed files with 16 additions and 20 deletions


@@ -823,6 +823,14 @@ def get_only_unique_item(items):
     return unique_item
+def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.):
+    norm_is_inf = input.isinf()
+    norm_is_nan = input.isnan()
+    inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
+    err = torch.tensor(-1.0, device=device, dtype=torch.float)
+    input.masked_fill_(inf_or_nan, err)
 def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
     """Get norm of an iterable of tensors.
@@ -897,8 +905,7 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
             dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
         total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
-    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
-    total_norm.masked_fill_(inf_or_nan, -1)
+    mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)
     return total_norm


@@ -19,7 +19,7 @@ from deepspeed.utils import logger
 from deepspeed.utils.torch import register_grad_hook
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce
-from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
+from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
@@ -1453,12 +1453,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         total_norm = total_norm_cuda[0]**(1. / norm_type)
-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-        inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
-        err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device)
         return total_norm.cpu()
@@ -1815,7 +1810,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
         err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        total_norm = torch.where(inf_or_nan, err, total_norm)
         return total_norm


@@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from deepspeed.runtime.base_optimizer import ZeROOptimizer
 from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
-                                     align_dense_tensors, all_gather_dp_groups)
+                                     align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace)
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -1722,11 +1722,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         total_norm = total_norm.pow(1. / norm_type)
-        norm_is_inf = total_norm.isinf()
-        norm_is_nan = total_norm.isnan()
-        if norm_is_inf or norm_is_nan:
-            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
+        mask_nan_or_inf_with_val_inplace(total_norm, device=self.device)
         return total_norm
@@ -1984,10 +1980,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            # handle total_norm invalid value -1
-            if clip > 1:
-                combined_scale = clip * self.loss_scale
+            clip = torch.clamp(clip, min=1.0)
+            combined_scale = clip * self.loss_scale
         for grad in grad_groups_flat:
             if isinstance(grad, list):