Merge branch 'master' into loadams/pyproject-toml

This commit is contained in:
Logan Adams
2025-04-21 07:54:15 -07:00
committed by GitHub

View File

@ -589,7 +589,7 @@ def _get_data_parallel_rank():
def _get_sequence_parallel_world_size(): def _get_sequence_parallel_world_size():
"""Return world size for the model parallel group.""" """Return world size for the sequence parallel group."""
global mpu global mpu
if mesh_device is not None: if mesh_device is not None:
return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel")) return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel"))
@ -599,7 +599,7 @@ def _get_sequence_parallel_world_size():
def _get_sequence_parallel_rank(): def _get_sequence_parallel_rank():
"""Return my rank for the data parallel group.""" """Return my rank for the sequence parallel group."""
global mpu global mpu
if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'): if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'):
return mpu.get_sequence_parallel_rank() return mpu.get_sequence_parallel_rank()