From 7de53b1f5d8908c99b82694a635286f96066d3a0 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 1 May 2025 16:12:20 +0000
Subject: [PATCH] sequence parallel mpu support

---
 deepspeed/comm/torch.py          |  5 ++++-
 deepspeed/runtime/engine.py      | 12 ++++++++++++
 deepspeed/runtime/zero/stage3.py |  2 +-
 deepspeed/utils/groups.py        | 13 +++++++++----
 4 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index efa0640fb..bd95510f9 100755
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -145,11 +145,14 @@ class TorchBackend(Backend):
 
     def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
+            local_rank = int(os.environ.get('LOCAL_RANK', 0))
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
                                                  init_method=init_method,
                                                  rank=rank,
-                                                 world_size=world_size)
+                                                 world_size=world_size,
+                                                 device_id=torch.device('cuda', local_rank),
+                                                 )
         self.using_mpi = torch.distributed.get_backend() == 'mpi'
 
     @disable_compiler_collective
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 952318e3d..f9920758c 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1288,6 +1288,15 @@ class DeepSpeedEngine(Module):
             self.communication_data_type = self._config.seq_parallel_communication_data_type
             self.seq_parallel_group = groups._get_sequence_parallel_group()
 
+        if dist.get_rank() == 0:
+            summary = "********** distributed groups summary **********\n"
+            summary += f"\t {self.dp_world_size=}\n"
+            summary += f"\t {self.mp_world_size=}\n"
+            summary += f"\t {self.seq_dp_world_size=}\n"
+            summary += f"\t {self.sequence_parallel_size=}\n"
+            summary += "********** distributed groups summary **********"
+            print(summary)
+
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()
 
@@ -2309,6 +2318,7 @@
 
         self.losses = None
         self.global_steps += 1
+        #print(f"{self.global_steps=}")
         self.global_samples += self.train_batch_size()
 
     def step(self, lr_kwargs=None):
@@ -2334,8 +2344,10 @@
 
         self._step_applied = False  # assume False, will flip to True
 
+        #print("before is_gradient_accumulation_boundary")
         # Update the model when we reach gradient accumulation boundaries
         if self.is_gradient_accumulation_boundary():
+            #print("inside is_gradient_accumulation_boundary")
            self.gas_boundary_ctr += 1
 
             if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index ec0cd92b3..649b66e43 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -291,7 +291,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         self.zeropp_loco_param = zeropp_loco_param
 
-        if mpu is None:
+        if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
             self.model_parallel_group = None
             self.model_parallel_rank = 0
         else:
diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py
index 6dc750035..fc0a60848 100755
--- a/deepspeed/utils/groups.py
+++ b/deepspeed/utils/groups.py
@@ -523,7 +523,10 @@ def _get_data_parallel_group():
     if mesh_device is not None:
         return mesh_device.get_group(mesh_dim="data_parallel")
     if mpu is not None:
-        return mpu.get_data_parallel_group()
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
+        else:
+            return mpu.get_data_parallel_group()
     # Return the clone of dist world group
     return _clone_world_group()
 
@@ -571,6 +574,8 @@ def _get_data_parallel_world_size():
         return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
     global mpu
     if mpu is not None:
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
         return mpu.get_data_parallel_world_size()
     return dist.get_world_size(group=_get_data_parallel_group())
 
@@ -578,9 +583,9 @@
 def _get_model_parallel_world_size():
     """Return world size for the model parallel group."""
     global mpu
-    if mpu is not None:
-        return mpu.get_model_parallel_world_size()
-    return 1
+    if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
+        return 1
+    return mpu.get_model_parallel_world_size()
 
 
 def _get_data_parallel_rank():
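
The branches introduced above all key off `hasattr(mpu, 'initialize_sequence_parallel')` rather than an explicit flag, so any mpu object that defines that method is treated as sequence-parallel. Below is a minimal, hypothetical stub (not part of this patch or of DeepSpeed's shipped API; the class name and method bodies are illustrative only) showing which attributes the touched code paths consult and how they resolve.

# Hypothetical stub for illustration only -- not part of the patch above.
# Defining `initialize_sequence_parallel` is what flips the new branches in
# deepspeed/utils/groups.py and deepspeed/runtime/zero/stage3.py.
import torch.distributed as dist


class ToySequenceParallelMPU:
    """Minimal mpu exposing only the hooks the patched code inspects."""

    def initialize_sequence_parallel(self, sequence_parallel_size):
        # Only its presence matters: the patched code calls hasattr(mpu, ...)
        # and never invokes this method in the touched code paths.
        self.sequence_parallel_size = sequence_parallel_size

    def get_data_parallel_group(self):
        # Bypassed by the patched _get_data_parallel_group(), which now
        # returns None when initialize_sequence_parallel is present.
        return dist.group.WORLD

    def get_data_parallel_world_size(self):
        # Bypassed by the patched _get_data_parallel_world_size() (returns None).
        return dist.get_world_size()

    def get_model_parallel_world_size(self):
        # Bypassed by the patched _get_model_parallel_world_size() (returns 1).
        return 1


# With a stub like this installed as the global mpu in deepspeed.utils.groups,
# the patched helpers short-circuit: _get_data_parallel_group() -> None,
# _get_data_parallel_world_size() -> None, _get_model_parallel_world_size() -> 1,
# and ZeRO stage 3 falls back to model_parallel_group=None, model_parallel_rank=0.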