Broadcast fp16 overflow in Z1 (#7580)

Fix #7568

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Olatunji Ruwase
2025-09-23 11:51:43 -04:00
committed by GitHub
parent 8c7c56a932
commit bc9ed477e9

@@ -2178,10 +2178,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
get_accelerator().current_device_name())
-if partition_gradients:
-    '''This will capture overflow across all data parallel and expert parallel process
-    Since expert parallel process are a subset of data parallel process'''
-    dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
+dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)
# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
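
For context, the pattern the patch makes unconditional is a MAX all-reduce of a byte-encoded overflow flag: if any data-parallel rank saw an fp16 overflow, every rank ends up with 1 and can skip the optimizer step together. Below is a minimal, hypothetical sketch of that pattern; the `sync_overflow_flag` helper and its signature are illustrative and not part of DeepSpeed's API.

```python
import torch
import torch.distributed as dist


def sync_overflow_flag(local_overflow: bool, group=None, device="cuda") -> bool:
    """Illustrative sketch (not DeepSpeed code) of syncing an overflow flag."""
    # Encode the local flag as a byte tensor on the communication device.
    flag = torch.tensor([1 if local_overflow else 0], dtype=torch.uint8, device=device)
    # MAX all-reduce: the result is 1 on every rank if any rank overflowed,
    # so all ranks agree to skip (or take) the step together.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())
```

With the `if partition_gradients:` guard removed, the reduction over `self.dp_process_group` also runs for ZeRO stage 1, so the overflow flag is propagated across the data-parallel group there as well, which is the behavior the commit title describes.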