mirror of https://github.com/deepspeedai/DeepSpeed.git

sequence parallel mpu support
@@ -145,11 +145,14 @@ class TorchBackend(Backend):
     def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
+            local_rank = int(os.environ.get('LOCAL_RANK', 0))
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
                                                  init_method=init_method,
                                                  rank=rank,
-                                                 world_size=world_size)
+                                                 world_size=world_size,
+                                                 device_id=torch.device('cuda', local_rank),
+                                                 )
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

     @disable_compiler_collective
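The new `device_id` argument binds the process group to this rank's CUDA device at init time; in recent PyTorch releases this lets NCCL set up its communicator eagerly instead of inferring the device at the first collective. A minimal standalone sketch of the same pattern outside DeepSpeed (the `nccl` backend and env-based rendezvous are assumptions, not part of this diff):

    # Standalone sketch, not DeepSpeed code: one CUDA device per rank,
    # bound at init time via device_id (requires a recent PyTorch).
    import os

    import torch
    import torch.distributed as dist

    def init_distributed():
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        torch.cuda.set_device(local_rank)
        if not dist.is_initialized():
            dist.init_process_group(
                backend='nccl',
                init_method='env://',  # rendezvous via MASTER_ADDR/MASTER_PORT
                # Ties the NCCL communicator to this device up front instead
                # of letting PyTorch guess it at the first collective call.
                device_id=torch.device('cuda', local_rank),
            )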
@@ -1288,6 +1288,15 @@ class DeepSpeedEngine(Module):
         self.communication_data_type = self._config.seq_parallel_communication_data_type
         self.seq_parallel_group = groups._get_sequence_parallel_group()

+        if dist.get_rank() == 0:
+            summary = "********** distributed groups summary **********\n"
+            summary += f"\t {self.dp_world_size=}\n"
+            summary += f"\t {self.mp_world_size=}\n"
+            summary += f"\t {self.seq_dp_world_size=}\n"
+            summary += f"\t {self.sequence_parallel_size=}\n"
+            summary += "********** distributed groups summary **********"
+            print(summary)
+
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()
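The rank-0 summary makes the group layout visible at startup. The printed sizes should tile the world; a hedged sketch of that invariant, where the composition rule `dp * sp * mp == world_size` and `seq_dp = dp * sp` are assumptions about how DeepSpeed composes the groups, not guarantees made by this diff:

    # Hedged sketch of the invariant the summary lets you eyeball.
    def check_group_layout(world_size: int, dp: int, sp: int, mp: int) -> int:
        # data-parallel x sequence-parallel x model-parallel ranks must
        # cover every rank exactly once
        assert dp * sp * mp == world_size, "groups must tile the world"
        # ZeRO can fold the sequence-parallel dimension into data parallel,
        # which is presumably what seq_dp_world_size reports
        return dp * sp

    # e.g. 16 GPUs, sequence_parallel_size=4, no tensor parallelism:
    assert check_group_layout(16, dp=4, sp=4, mp=1) == 16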
@@ -2309,6 +2318,7 @@ class DeepSpeedEngine(Module):

         self.losses = None
         self.global_steps += 1
+        #print(f"{self.global_steps=}")
         self.global_samples += self.train_batch_size()

     def step(self, lr_kwargs=None):
@@ -2334,8 +2344,10 @@ class DeepSpeedEngine(Module):

         self._step_applied = False  # assume False, will flip to True

+        #print("before is_gradient_accumulation_boundary")
         # Update the model when we reach gradient accumulation boundaries
         if self.is_gradient_accumulation_boundary():
+            #print("inside is_gradient_accumulation_boundary")
             self.gas_boundary_ctr += 1

             if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
@@ -291,7 +291,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

         self.zeropp_loco_param = zeropp_loco_param

-        if mpu is None:
+        if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
             self.model_parallel_group = None
             self.model_parallel_rank = 0
         else:
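`hasattr(mpu, 'initialize_sequence_parallel')` is a duck-typing marker: an mpu that exposes it is treated as sequence-parallel-only, so ZeRO Stage 3 falls back to its no-model-parallelism defaults. A hedged sketch of what such an mpu could look like; only the marker method name comes from the diff, the rest is illustrative:

    # Illustrative mpu shape (not DeepSpeed API beyond the marker name).
    import torch.distributed as dist

    class SequenceParallelMPU:
        def __init__(self, sp_size):
            assert dist.get_world_size() % sp_size == 0
            self.sp_size = sp_size
            self._sp_group = None

        def initialize_sequence_parallel(self):
            # carve the world into contiguous sequence-parallel groups;
            # new_group must be called by all ranks for every group
            world, rank = dist.get_world_size(), dist.get_rank()
            for start in range(0, world, self.sp_size):
                ranks = list(range(start, start + self.sp_size))
                group = dist.new_group(ranks)
                if rank in ranks:
                    self._sp_group = group

        def get_sequence_parallel_group(self):
            return self._sp_group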
@@ -523,7 +523,10 @@ def _get_data_parallel_group():
     if mesh_device is not None:
         return mesh_device.get_group(mesh_dim="data_parallel")
     if mpu is not None:
-        return mpu.get_data_parallel_group()
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
+        else:
+            return mpu.get_data_parallel_group()

     # Return the clone of dist world group
     return _clone_world_group()
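Returning `None` instead of a group handle works because `torch.distributed` collectives and queries interpret `group=None` as the default (world) process group, which is presumably the intended data-parallel scope when the mpu only manages sequence parallelism:

    # Why None is a usable sentinel here: torch.distributed treats
    # group=None as the default (world) process group.
    import torch.distributed as dist

    def dp_size_from_group(dp_group=None):
        # With a sequence-parallel-only mpu the lookup above yields None,
        # so this falls through to the world group.
        return dist.get_world_size(group=dp_group)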
@@ -571,6 +574,8 @@ def _get_data_parallel_world_size():
         return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
     global mpu
     if mpu is not None:
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
         return mpu.get_data_parallel_world_size()
     return dist.get_world_size(group=_get_data_parallel_group())
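Unlike the group lookup, a `None` world size cannot be passed onward, so callers have to special-case the sentinel. A sketch of the defensive pattern a caller might use; the fallback to the global world size is an assumption about intent:

    import torch.distributed as dist
    # the patched helper from deepspeed.utils.groups shown above
    from deepspeed.utils.groups import _get_data_parallel_world_size

    def effective_dp_world_size():
        size = _get_data_parallel_world_size()  # may now return None
        if size is None:
            # sequence-parallel mpu: every rank participates in data
            # parallelism, so the default group supplies the size
            size = dist.get_world_size()
        return size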
@@ -578,9 +583,9 @@ def _get_data_parallel_world_size():

 def _get_model_parallel_world_size():
     """Return world size for the model parallel group."""
     global mpu
-    if mpu is not None:
-        return mpu.get_model_parallel_world_size()
-    return 1
+    if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
+        return 1
+    return mpu.get_model_parallel_world_size()


 def _get_data_parallel_rank():
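Taken together, the guarded accessors give a sequence-parallel-only mpu consistent defaults. A hedged recap with a stub mpu (the class below is illustrative, not part of the diff):

    # The mere presence of the marker method changes what the patched
    # accessors return.
    class SPOnlyMPU:
        def initialize_sequence_parallel(self):
            pass  # marker; the body is irrelevant to the hasattr() checks

    # With mpu = SPOnlyMPU() registered in deepspeed.utils.groups:
    #   _get_model_parallel_world_size() -> 1     (no tensor parallelism)
    #   _get_data_parallel_group()       -> None  (default/world group)
    #   _get_data_parallel_world_size()  -> None  (caller must fall back)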