Ulysses SP for HF Integration (#7268)

This is the DeepSpeed counterpart of
https://github.com/snowflakedb/ArcticTraining/pull/45, since the new
features require changes on both sides.


For PR reviewers: 

Readiness status:
- [x] Code
- [x] Tests
- [ ] Docs - working on it


Features:

- [x] add support for delaying grad addition via the
`param.ds_grad_is_ready` flag (used when performing tiled compute in an
autograd function)
- [x] add a lightweight SP-only mpu (Jeff Rasley)
- [x] improved debug utilities
- [x] added `all_gather_object` to `dist`
- [x] `UlyssesSPAttentionHF` (port of UlyssesAttention from
Megatron-Deepspeed plus support for modern MHA variations)
- [x] `UlyssesSPDataLoaderAdapter` - a DataLoader adapter that shards the
normal DataLoader batches for consumption by `UlyssesSPAttentionHF` (see
the usage sketch right after this list)
- [x] `SequenceTiledCompute` - a generic autograd function that performs
compute after tiling on the sequence dimension
- [x] `TiledMLP` - an MLP-specific autograd function that performs a tiled
MLP computation (it's much easier to understand before trying to grok
`SequenceTiledCompute`)
- [x] added a differentiable `_DimZeroAllToAll` (Samyam Rajbhandari)
- [x] torch-dist-check now allows `torch.distributed.nn` (needed because
DeepSpeed's `dist` wrapper is not yet up to date with
`torch.distributed.nn`)
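
Putting the pieces together, here is a minimal end-to-end sketch (distilled from the new `TestUlyssesSPHF` unit test included in this PR; `ds_config` and `train_dataloader` are placeholders, and the DeepSpeed config must contain a `sequence_parallel_size` entry):

```python
import deepspeed
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.utils import groups
from transformers import AutoModelForCausalLM

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative

# 1. patch HF attention with the Ulysses SP wrapper and get back a light SP-only mpu
mpu = UlyssesSPAttentionHF.register_with_transformers(
    model_name_or_path=model_name,
    core_attn_implementation="sdpa",
    sequence_parallel_size=2,
    max_length=64,
    micro_batch_size=1,
    seq_length_is_variable=True,
)

# 2. hand the mpu to deepspeed.initialize (ds_config is assumed to exist and to carry
#    "sequence_parallel_size" so that DeepSpeed can derive the DP world size)
model = AutoModelForCausalLM.from_pretrained(model_name)
model, *_ = deepspeed.initialize(config=ds_config,
                                 model=model,
                                 model_parameters=model.parameters(),
                                 mpu=mpu)

# 3. wrap the normal DataLoader so that each SP rank receives its shard of the sequence
dl = UlyssesSPDataLoaderAdapter(
    train_dataloader,  # assumed: a normal per-rank DataLoader yielding input_ids/labels/position_ids
    sp_rank=groups._get_sequence_parallel_rank(),
    sp_group=groups._get_sequence_parallel_group(),
    sp_world_size=groups._get_sequence_parallel_world_size(),
    device=model.device,
)
```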

---------

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas.bekman@snowflake.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Commit 4d00b38ada (parent 0baf79ead0) by Stas Bekman, 2025-05-31 00:25:23 -07:00, committed by GitHub.
21 changed files with 1868 additions and 32 deletions


@ -43,8 +43,8 @@ jobs:
- name: Install deepspeed
run: |
pip install transformers==4.48.3
pip install .[dev]
pip install transformers==4.48.3
ds_report
- name: Install deepspeed-chat


@ -242,6 +242,12 @@ def all_gather(tensor_list,
return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
@timed_op
def all_gather_object(object_list, obj, group=None, prof=False, log_name='all_gather_object', debug=get_caller_func()):
global cdb
return cdb.all_gather_object(object_list=object_list, obj=obj, group=group)
def has_reduce_scatter_tensor():
global cdb
assert cdb is not None and cdb.is_initialized(


@ -268,6 +268,10 @@ class TorchBackend(Backend):
else:
reqs[-1].wait()
@disable_compiler_collective
def all_gather_object(self, object_list, obj, group=None):
return torch.distributed.all_gather_object(object_list=object_list, obj=obj, group=group)
@disable_compiler_collective
def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
if self.has_reduce_scatter_tensor():
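
A quick usage sketch of the new `all_gather_object` collective added above (its semantics mirror `torch.distributed.all_gather_object`; assumes `deepspeed.comm` has been initialized):

```python
import deepspeed.comm as dist

# every rank contributes one picklable Python object; afterwards slot i of the
# list holds the object contributed by rank i
obj = {"rank": dist.get_rank(), "num_samples": 123}
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, obj)
```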


@ -721,14 +721,23 @@ class DeepSpeedConfig(object):
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
)
try:
self.global_rank = dist.get_rank()
if mpu is not None:
self.world_size = mpu.get_data_parallel_world_size()
# Ulysses SP
if not hasattr(mpu, "get_data_parallel_world_size"):
self.world_size = dist.get_world_size() / mpu.get_sequence_parallel_world_size()
else:
self.world_size = mpu.get_data_parallel_world_size()
elif mesh_device is not None:
self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
else:
self.world_size = dist.get_world_size()
# HF zero.init case where there is no mpu
if "sequence_parallel_size" in config:
self.world_size = dist.get_world_size() / config["sequence_parallel_size"]
else:
self.world_size = dist.get_world_size()
except:
self.global_rank = 0
self.world_size = 1
@ -941,7 +950,7 @@ class DeepSpeedConfig(object):
micro_batch = self.train_micro_batch_size_per_gpu
grad_acc = self.gradient_accumulation_steps
#print(f"train_batch = {train_batch}, micro_batch={micro_batch}")
#print(f"in: train_batch = {train_batch}, micro_batch={micro_batch}")
# all values are provided nothing needs to be set
if train_batch is not None and micro_batch is not None and grad_acc is not None:
@ -980,6 +989,8 @@ class DeepSpeedConfig(object):
assert False, \
'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'
#print(f"final: {self.train_batch_size=} {self.train_micro_batch_size_per_gpu=} {self.gradient_accumulation_steps=}")
def _configure_train_batch_size(self):
self._set_batch_related_parameters()
self._batch_assertion()
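
For reference, a sketch of a client config that exercises the new `sequence_parallel_size` handling above (values mirror the new `TestUlyssesSPHF` test; in the HF zero.init case with no mpu, DeepSpeed derives the DP world size as `world_size / sequence_parallel_size`):

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "bf16": {"enabled": True},
    # with 4 GPUs and sequence_parallel_size=2, the derived DP world size is 4 / 2 = 2
    "sequence_parallel_size": 2,
}
```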


@ -1303,6 +1303,15 @@ class DeepSpeedEngine(Module):
self.communication_data_type = self._config.seq_parallel_communication_data_type
self.seq_parallel_group = groups._get_sequence_parallel_group()
if dist.get_rank() == 0:
summary = "********** distributed groups summary **********\n"
summary += f"\t {self.dp_world_size=}\n"
summary += f"\t {self.mp_world_size=}\n"
summary += f"\t {self.seq_dp_world_size=}\n"
summary += f"\t {self.sequence_parallel_size=}\n"
summary += "***********************************************"
logger.info(summary)
if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()


@ -0,0 +1,4 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team


@ -0,0 +1,90 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
This is a slimmed-down version of parallel_state.py (mpu) from Megatron-Deepspeed
"""
from deepspeed import comm as dist
# Sequence parallel groups to handle both data and sequence parallelisms.
# These groups are used to reduce gradients and shard parameters and optimizer states for ZeRO.
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_DATA_PARALLEL_GROUP = None
def initialize_sequence_parallel(sequence_parallel_size: int) -> None:
"""Initialize sequence parallel groups."""
assert dist.is_initialized()
world_size: int = dist.get_world_size()
if world_size < sequence_parallel_size:
raise RuntimeError(f"world_size ({world_size}) is less than sequence_parallel_size {sequence_parallel_size}")
if sequence_parallel_size <= 1:
raise ValueError(f"sequence_parallel_size must be greater than 1, got {sequence_parallel_size}")
if world_size % sequence_parallel_size != 0:
raise RuntimeError(
f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})")
data_parallel_size: int = world_size // sequence_parallel_size
sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size
rank = dist.get_rank()
# Build the sequence parallel groups.
global _SEQUENCE_PARALLEL_GROUP
assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
for i in range(num_sequence_parallel_groups):
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
group = dist.new_group(ranks)
if rank in ranks:
_SEQUENCE_PARALLEL_GROUP = group
# Build the sequence data parallel groups.
global _SEQUENCE_DATA_PARALLEL_GROUP
assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized"
all_data_sequence_parallel_group_ranks = []
for i in range(num_sequence_data_parallel_groups):
ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size)
group = dist.new_group(ranks)
all_data_sequence_parallel_group_ranks.append(list(ranks))
if rank in ranks:
_SEQUENCE_DATA_PARALLEL_GROUP = group
def get_sequence_parallel_group():
"""Get the sequence parallel group the caller rank belongs to."""
assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
return _SEQUENCE_PARALLEL_GROUP
def get_sequence_data_parallel_group():
"""Get the sequence parallel group the caller rank belongs to."""
assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized"
return _SEQUENCE_DATA_PARALLEL_GROUP
def get_sequence_parallel_world_size():
"""Return world size for the sequence parallel group."""
return dist.get_world_size(group=get_sequence_parallel_group())
def get_sequence_data_parallel_world_size():
"""Return world size for the sequence parallel group."""
return dist.get_world_size(group=get_sequence_data_parallel_group())
def get_sequence_parallel_rank():
"""Return my rank for the sequence parallel group."""
return dist.get_rank(group=get_sequence_parallel_group())
def get_sequence_data_parallel_rank():
"""Return my rank for the sequence data parallel group."""
return dist.get_rank(group=get_sequence_data_parallel_group())
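
A usage sketch for this SP-only mpu (the module name `parallel_state_sp` comes from a comment in the ZeRO offload diff below; the exact import path is assumed, and `deepspeed.comm` must already be initialized):

```python
from deepspeed.runtime.sequence_parallel import parallel_state_sp as mpu  # import path assumed

# on 8 ranks with sequence_parallel_size=4 this builds 2 SP groups of 4 ranks each
mpu.initialize_sequence_parallel(sequence_parallel_size=4)

sp_group = mpu.get_sequence_parallel_group()
print(f"sp_world_size={mpu.get_sequence_parallel_world_size()} "
      f"sp_rank={mpu.get_sequence_parallel_rank()}")
```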

File diff suppressed because it is too large


@ -141,19 +141,19 @@ def copy_to_device(item, device, criterion_func):
return item
def move_to_device(item, device, criterion_func):
def move_to_device(item, device, criterion_func=None):
"""
Move tensor on to specified device by changing the storage.
Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
Parameters:
item: tensor to move or (possibly nested) container of tensors to move.
device: target device
criterion_func: Function to restrict move operation to items meet criterion
criterion_func: Function to restrict the move operation to items that meet the criterion; defaults to `None`, which is equivalent to always moving
Returns:
None
"""
if criterion_func(item):
if (criterion_func is not None and criterion_func(item)):
device_copy = item.to(device)
item.data = device_copy.data
return item
@ -164,7 +164,7 @@ def move_to_device(item, device, criterion_func):
elif isinstance(item, dict):
return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
else:
return item
return item.to(device)
def get_norm_with_moe_layers_fast(all_groups_norm, group):


@ -208,7 +208,8 @@ class DeepSpeedZeRoOffload(object):
zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
else:
group = None
if mpu:
# parallel_state_sp doesn't have get_data_parallel_group
if mpu and hasattr(mpu, "get_data_parallel_group"):
group = mpu.get_data_parallel_group()
Init(module=module,
@ -480,7 +481,7 @@ class DeepSpeedZeRoOffload(object):
force=False)
param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)
param_coordinator.release_sub_module(sub_module, forward=True)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
@ -502,7 +503,7 @@ class DeepSpeedZeRoOffload(object):
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)
self.get_param_coordinator().release_sub_module(sub_module)
self.get_param_coordinator().release_sub_module(sub_module, forward=False)
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",


@ -1264,6 +1264,9 @@ class Init(InsertPostInitMethodToModuleSubClasses):
ds_process_group,
)
param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
#print_rank_0(f"{param.shape=}", force=True)
#print_rank_0(f"{param_buffer.shape=}", force=True)
return AllGatherHandle(handles, param)
else:
if hasattr(param_ds_tensor, "ds_quant_scale"):


@ -15,7 +15,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
from deepspeed.utils.debug import debug_param2name_id_shape
from deepspeed.accelerator import get_accelerator
import deepspeed.runtime.compiler as compiler
from deepspeed.runtime.compiler import is_compiling
@ -267,9 +267,8 @@ class PartitionedParameterCoordinator:
def _dump_params(self, tag, sub_module, params, step_id=None):
if step_id is None:
step_id = self.__step_id
param_names = [debug_param2name_id(p) for p in params]
print_rank_0(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}',
force=False)
param_names = [debug_param2name_id_shape(p) for p in params]
print_rank_0(f'{tag} step = {step_id} p_names = {param_names}', force=False)
def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
if step_id is None:
@ -305,7 +304,10 @@ class PartitionedParameterCoordinator:
if fetch_numel > 0:
event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
self._dump_param_ids(event_name, current_submodule.ds_id,
[p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
[(p.ds_id, p.ds_shape)
for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
# self._dump_params(event_name, current_submodule, [p for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
self.__profiler.start_event(event_name)
# kick off all gather for params in the immediately required submodule
#for param in params_to_fetch:
@ -420,9 +422,10 @@ class PartitionedParameterCoordinator:
@instrument_w_nvtx
@torch.no_grad()
def release_sub_module(self, submodule: Module) -> None:
def release_sub_module(self, submodule: Module, forward=False) -> None:
"""release the parameters of a sub module, assuming they meet conditions to
be released."""
#print_rank_0(f"release_sub_module {'fwd' if forward else 'bwd'}: {debug_module2name_id(submodule)}", force=False)
params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
@ -517,6 +520,7 @@ class PartitionedParameterCoordinator:
if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-release: {param.ds_summary()}")
print_rank_0(f"release: {debug_param2name_id_shape(param)}", force=False)
param.partition(free_data=free_data)
self.__n_available_params -= param.ds_numel


@ -291,7 +291,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.zeropp_loco_param = zeropp_loco_param
if mpu is None:
if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
self.model_parallel_group = None
self.model_parallel_rank = 0
else:
@ -1268,7 +1268,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.__reduce_and_partition_ipg_grads()
self.__add_grad_to_ipg_bucket(param)
# deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over shards.
if getattr(param, "ds_grad_is_ready", True):
self.__add_grad_to_ipg_bucket(param)
@instrument_w_nvtx
@torch.no_grad()


@ -211,7 +211,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
self._configure_moe_settings()
self._global_grad_norm = 0.
if mpu is None:
if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
self.model_parallel_group = None
self.model_parallel_world_size = 1
self.model_parallel_rank = 0
@ -986,8 +986,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
self.grads_in_ipg_bucket.append(grad_reduc)
self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))
# deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over shards.
if getattr(param, "ds_grad_is_ready", True):
self.grads_in_ipg_bucket.append(grad_reduc)
self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))
#make sure the average tensor function knows how to average the gradients
if is_moe_param(param):
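
To illustrate the `ds_grad_is_ready` contract that both ZeRO hunks above rely on, here is a simplified, illustrative sketch of a tiled backward (not the shipped `TiledMLP` code): the flag is kept `False` for every shard but the last, so ZeRO leaves the locally accumulating gradient alone and only adds the final, complete gradient to the IPG bucket.

```python
import torch

def tiled_backward(forward_fn, module, x, grad_out, compute_params, num_shards):
    # illustrative only: re-run the forward per sequence shard and backprop shard by
    # shard, deferring ZeRO's grad reduction via ds_grad_is_ready until the last shard
    x_shards = x.split(x.shape[1] // num_shards, dim=1)
    g_shards = grad_out.split(grad_out.shape[1] // num_shards, dim=1)
    for i, (xs, gs) in enumerate(zip(x_shards, g_shards)):
        for p in compute_params:
            p.ds_grad_is_ready = (i == len(x_shards) - 1)
        xs = xs.detach().requires_grad_(True)
        out = forward_fn(module, xs)
        torch.autograd.backward(out, gs)
        # a real implementation would also collect xs.grad per shard to build the
        # gradient w.r.t. the full input
```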


@ -254,6 +254,26 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
return res
class _DimZeroAllToAll(torch.autograd.Function):
"""Differentiable All2All across dimension 0."""
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
world_size = dist.get_world_size(group)
assert input.shape[0] == world_size, f"Dim 0 {input.shape[0]} is not world size"
ctx.group = group
output = torch.empty_like(input).contiguous()
# torch.distributed.nn.functional.all_to_all_single(output, input.contiguous(), group=group)
dist.all_to_all_single(output, input.contiguous(), group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output))
class _SeqAllToAll(torch.autograd.Function):
@staticmethod


@ -3,6 +3,8 @@
# DeepSpeed Team
import deepspeed.comm as dist
# For lazy import with printflock()
fcntl = None
@ -35,7 +37,7 @@ def debug_module2name(module):
def debug_module2name_id(module):
return f"name={debug_module2name(module)} id={module.id}"
return f"name={debug_module2name(module)}"
def debug_module2name_class(module):
@ -54,11 +56,11 @@ def debug_param2name_id(param):
def debug_param2name_id_shape(param):
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape}"
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape}"
def debug_param2name_id_shape_device(param):
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} device={param.device}"
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape} device={param.device}"
def debug_param2name_id_numel(param):
@ -66,7 +68,7 @@ def debug_param2name_id_numel(param):
def debug_param2name_id_shape_status(param):
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} status={param.ds_status}"
return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape} status={param.ds_status}"
def printflock(*msgs):
@ -151,3 +153,21 @@ def print_backward_tensors(tensor):
if hasattr(tensor, 'grad_fn'):
_print_bwd_tensors(tensor.grad_fn)
def print_rank(*msg, force=False):
"""print something on all global ranks with [rank] prefix.
"""
if not force:
return
global_rank = dist.get_rank()
print(f"[{global_rank}]", *msg)
def print_rank0(*msg, force=False):
"""print something only on rank 0"""
if not force:
return
global_rank = dist.get_rank()
if global_rank == 0:
print(f"[{global_rank}]", *msg)

deepspeed/utils/groups.py (Executable file → Normal file)

@ -523,7 +523,10 @@ def _get_data_parallel_group():
if mesh_device is not None:
return mesh_device.get_group(mesh_dim="data_parallel")
if mpu is not None:
return mpu.get_data_parallel_group()
if hasattr(mpu, 'initialize_sequence_parallel'):
return None
else:
return mpu.get_data_parallel_group()
# Return the clone of dist world group
return _clone_world_group()
@ -571,16 +574,19 @@ def _get_data_parallel_world_size():
return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
global mpu
if mpu is not None:
return mpu.get_data_parallel_world_size()
if hasattr(mpu, 'initialize_sequence_parallel'):
return None
else:
return mpu.get_data_parallel_world_size()
return dist.get_world_size(group=_get_data_parallel_group())
def _get_model_parallel_world_size():
"""Return world size for the model parallel group."""
global mpu
if mpu is not None:
return mpu.get_model_parallel_world_size()
return 1
if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
return 1
return mpu.get_model_parallel_world_size()
def _get_data_parallel_rank():


@ -25,8 +25,9 @@ def err(s: str) -> None:
# - we can reasonably assume it's available on all machines
# - unlike plain grep, which is slower and has different flags on MacOS versus
# Linux, git grep is always the same.
# allowing `torch.distributed.nn`
res = subprocess.run(
["git", "grep", "-Hn", "--no-index", r"torch\.distributed", *sys.argv[1:]],
["git", "grep", "-Hn", "--no-index", "-P", r"torch\.distributed |torch\.distributed(?!\.nn)", *sys.argv[1:]],
capture_output=True,
)
if res.returncode == 0:


@ -0,0 +1,209 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
UlyssesPlus: Tiled compute tests
"""
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP, sequence_tiled_compute
from deepspeed.utils import safe_get_full_grad
from torch.nn import Linear, Module
from unit.common import DistributedTest, preferred_dtype
from unit.util import torch_assert_equal, torch_assert_close
import deepspeed
import pytest
import torch
def get_grad(param, zero_stage):
return safe_get_full_grad(param)
# z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1
# if zero_stage == 1:
# return param.grad
# else:
# return safe_get_full_grad(param)
class SimpleMLP(Module):
def __init__(self, hidden_dim):
super().__init__()
self.up_proj = Linear(hidden_dim, hidden_dim * 2, bias=False)
self.down_proj = Linear(hidden_dim * 2, hidden_dim, bias=False)
self.act = torch.nn.ReLU()
def forward(self, x):
return self.down_proj(self.act(self.up_proj(x)))
# save the original implementation to pass through to the tiled computation wrapper
mlp_forward_orig = SimpleMLP.forward
class MyModel(Module):
def __init__(self, hidden_dim):
super().__init__()
# Critical - need to use a stack of at least 2 mlps to validate that the backward of the last mlp sends the correct gradients to the previous mlp in the stack
self.mlp1 = SimpleMLP(hidden_dim)
self.mlp2 = SimpleMLP(hidden_dim)
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
def forward(self, x, y):
x = self.mlp1(x)
x = self.mlp2(x)
return self.cross_entropy_loss(x, y)
def mlp_forward_tiled_mlp(self, x):
# this tests TiledMLP
compute_params = [self.down_proj.weight, self.up_proj.weight]
num_shards = 4
return TiledMLP.apply(
mlp_forward_orig,
self,
x,
num_shards,
compute_params,
)
def mlp_forward_sequence_tiled_compute(self, x):
# this tests: sequence_tiled_compute + SequenceTiledCompute - same as TiledMLP but a non-MLP-specific,
# generic implementation of tiled compute
kwargs_to_shard = dict(x=x)
kwargs_to_pass = dict(self=self)
grad_requiring_tensor_key = "x"
compute_params = [self.down_proj.weight, self.up_proj.weight]
seqlen = x.shape[1]
num_shards = 4
return sequence_tiled_compute(
mlp_forward_orig,
seqlen,
num_shards,
kwargs_to_shard,
kwargs_to_pass,
grad_requiring_tensor_key,
compute_params,
output_unshard_dimension=1, # x
output_reduction=None,
)
@pytest.mark.parametrize("zero_stage", [1, 3])
class TestTiledCompute(DistributedTest):
world_size = 1
def test_tiled_mlp(self, zero_stage):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"zero_optimization": {
"stage": zero_stage
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
}
dtype = preferred_dtype()
if dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
elif dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0}
# for debug
# torch.set_printoptions(precision=8, sci_mode=True)
seed = 42
hidden_dim = 100
bs = 1
seqlen = hidden_dim
torch.manual_seed(seed)
x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True)
y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(hidden_dim)
# A. Baseline: model with normal MLP
torch.manual_seed(seed)
model_a = MyModel(hidden_dim=hidden_dim).to(dtype)
model_a, _, _, _ = deepspeed.initialize(config=config_dict,
model=model_a,
model_parameters=model_a.parameters())
x = x.to(model_a.device)
y = y.to(model_a.device)
x_a = x.clone().detach().requires_grad_(True)
y_a = y.clone().detach()
loss_a = model_a(x_a, y_a)
model_a.backward(loss_a)
grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
assert grad_a1 is not None
assert grad_a2 is not None
# B. model with tiled MLP using TiledMLP
torch.manual_seed(seed)
SimpleMLP.forward = mlp_forward_tiled_mlp
model_b = MyModel(hidden_dim=hidden_dim).to(dtype)
model_b, _, _, _ = deepspeed.initialize(config=config_dict,
model=model_b,
model_parameters=model_b.parameters())
x_b = x.clone().detach().requires_grad_(True)
y_b = y.clone().detach()
loss_b = model_b(x_b, y_b)
model_b.backward(loss_b)
grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
assert grad_b1 is not None
assert grad_b2 is not None
# print(f"{loss_a=}")
# print(f"{loss_b=}")
# print(f"{grad_a1=}")
# print(f"{grad_b1=}")
# print(f"{grad_a2=}")
# print(f"{grad_b2=}")
torch_assert_equal(loss_a, loss_b)
# Gradient will not be exactly the same, especially under half-precision. And bf16 is
# particularly lossy so need to lower tolerance a bit more than the default. Switch to
# dtype torch.float or even torch.double to see that the diff is tiny - so the math is
# correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the
# divergence much smaller as well.
torch_assert_close(grad_a1, grad_b1) #, rtol=1e-03, atol=1e-04)
torch_assert_close(grad_a2, grad_b2) #, rtol=1e-03, atol=1e-04)
# C. model with tiled MLP using the generic version of the same via sequence_tiled_compute + SequenceTiledCompute
torch.manual_seed(seed)
SimpleMLP.forward = mlp_forward_sequence_tiled_compute
model_c = MyModel(hidden_dim=hidden_dim).to(dtype)
model_c, _, _, _ = deepspeed.initialize(config=config_dict,
model=model_c,
model_parameters=model_c.parameters())
x_c = x.clone().detach().requires_grad_(True)
y_c = y.clone().detach()
loss_c = model_c(x_c, y_c)
model_c.backward(loss_c)
grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
assert grad_c1 is not None
assert grad_c2 is not None
# print(f"{loss_a=}")
# print(f"{loss_c=}")
# print(f"{grad_a1=}")
# print(f"{grad_c1=}")
# see notes for B
torch_assert_equal(loss_a, loss_c)
torch_assert_close(grad_a1, grad_c1) #, rtol=1e-03, atol=1e-04)
torch_assert_close(grad_a2, grad_c2) #, rtol=1e-03, atol=1e-04)


@ -0,0 +1,187 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
UlyssesPlus: UlyssesSPHF tests
"""
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.runtime.utils import move_to_device
from deepspeed.utils import groups
from deepspeed.utils import safe_get_full_grad
from torch import tensor
from transformers import AutoModelForCausalLM
from unit.common import DistributedTest, preferred_dtype
from unit.util import torch_assert_equal, torch_assert_close, torch_assert_dicts_of_tensors_equal
import deepspeed
import deepspeed.comm as dist
import pytest
import torch
def get_grad(param, zero_stage):
return safe_get_full_grad(param)
# z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1
# if zero_stage == 1:
# return param.grad
# else:
# return safe_get_full_grad(param)
@pytest.mark.parametrize("zero_stage", [1, 3])
class TestUlyssesSPHF(DistributedTest):
world_size = 2
def test_ulysses_sp_hf(self, zero_stage):
model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
#model_name_or_path = 'Felladrin/Llama-160M-Chat-v1'
max_length = 64
sequence_parallel_size = self.world_size
micro_batch_size = 1
rank = dist.get_rank()
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"zero_optimization": {
"stage": zero_stage,
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"sequence_parallel_size": sequence_parallel_size,
}
dtype = preferred_dtype()
if dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
elif dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0}
# Part 1. Baseline: Setup
def collate_fn(batch):
input_ids, position_ids = batch[0]
#print(f"{batch=}")
return dict(input_ids=input_ids.unsqueeze(0),
position_ids=position_ids.unsqueeze(0),
labels=input_ids.unsqueeze(0))
input_ids = tensor([[1, 10, 10, 10, 2, 2], [1, 20, 20, 20, 2, 2]], )
position_ids = tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]])
ds = torch.utils.data.TensorDataset(input_ids, position_ids)
# 1. Baseline: DataLoader calibration
dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)
batch_a = next(iter(dl_a))
#print(f"{rank=} {batch_a=}")
expected_batch_a = {
'input_ids': tensor([[1, 10, 10, 10, 2, 2]]),
'position_ids': tensor([[0, 1, 2, 3, 4, 5]]),
'labels': tensor([[1, 10, 10, 10, 2, 2]])
}
torch_assert_dicts_of_tensors_equal(batch_a, expected_batch_a)
# 2. Baseline: Attention
model_a = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model_a, _, _, _ = deepspeed.initialize(config=config_dict,
model=model_a,
model_parameters=model_a.parameters(),
mpu=None)
batch_a = move_to_device(batch_a, model_a.device)
loss_a = model_a(**batch_a).loss
model_a.backward(loss_a)
#print(f"{loss_a=}")
grad_a = get_grad(model_a.module.model.layers[0].self_attn.q_proj.weight, zero_stage)
assert grad_a is not None
#print(f"{grad_a}")
# Part 2. Ulysses: Setup
mpu = UlyssesSPAttentionHF.register_with_transformers(
model_name_or_path=model_name_or_path,
core_attn_implementation="sdpa",
sequence_parallel_size=sequence_parallel_size,
max_length=max_length,
micro_batch_size=micro_batch_size,
seq_length_is_variable=True,
)
model_b = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model_b, _, _, _ = deepspeed.initialize(config=config_dict,
model=model_b,
model_parameters=model_b.parameters(),
mpu=mpu)
# 3. Ulysses: UlyssesSPDataLoaderAdapter test
sp_group = groups._get_sequence_parallel_group()
sp_world_size = groups._get_sequence_parallel_world_size()
sp_rank = groups._get_sequence_parallel_rank()
dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)
dl_b = UlyssesSPDataLoaderAdapter(
dl_a,
sp_rank=sp_rank,
sp_group=sp_group,
sp_world_size=sp_world_size,
device=model_b.device,
)
batch_b = next(iter(dl_b))
expected_batch_b = [
{
'input_ids': tensor([[1, 10, 10]]),
'position_ids': tensor([[0, 1, 2]]),
'shift_labels': tensor([[10, 10, 10]]),
},
{
'input_ids': tensor([[10, 2, 2]]),
'position_ids': tensor([[3, 4, 5]]),
'shift_labels': tensor([[2, 2, -100]]),
},
]
# here we expect each sample to be sharded in half, rank0 getting the first half and rank1 the other half
#print(f"{sp_rank=} {batch_b=}")
torch_assert_dicts_of_tensors_equal(batch_b, expected_batch_b[sp_rank])
# 4. UlyssesSPAttentionHF test
batch_b = move_to_device(batch_b, model_b.device)
outputs = model_b(**batch_b)
# HF doesn't calculate loss with shift_labels yet and requires us to do it manually (liger does that)
shift_labels = batch_b["shift_labels"]
loss_b = model_b.module.loss_function(
logits=outputs.logits,
labels=None,
shift_labels=shift_labels,
vocab_size=model_b.module.config.vocab_size,
)
# print(f"{sp_rank=} {loss_b=}")
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss_b, group=sp_group)
good_tokens = sum((shift_labels != -100).view(-1))
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
loss_b = total_loss / total_good_tokens
# print(f"{sp_rank=} gathered {loss_b=}")
model_b.backward(loss_b)
grad_b = get_grad(model_b.module.model.layers[0].self_attn.q_proj.weight, zero_stage)
assert grad_b is not None
#print(f"{grad_b}")
# compare loss of A (non-Ulysses Attention) and B (Ulysses Attention)
torch_assert_equal(loss_a, loss_b)
# - we are feeding the exact same sample to each rank of A
# - for B we feed half the sample to each rank, but in total it's the same sample as each rank of A sees
# thus we expect very similar grads (but not exact)
if zero_stage in [1, 2]:
# possibly some issue with z1/z2 as it requires higher tolerance than z3?
torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03)
else:
torch_assert_close(grad_a, grad_b)


@ -88,3 +88,34 @@ class no_child_process_in_deepspeed_io:
def __exit__(self, *_):
deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
def torch_assert_equal(actual, expected, **kwargs):
"""
Compare two tensors or non-tensor numbers for their equality.
Add msg=blah to add an additional comment to when assert fails.
"""
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
def torch_assert_close(actual, expected, **kwargs):
"""
Compare two tensors or non-tensor numbers for their closeness.
Add msg=blah to add an additional comment to when assert fails.
For default values of `rtol` and `atol` which are dtype dependent, see the table at https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close
For example for bf16 it is `rtol=1.6e-2` and `atol=1e-5`.
The check does not fail when `|a - b| <= (atol + rtol * |b|)`
"""
return torch.testing.assert_close(actual, expected, **kwargs)
def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):
"""
Compare two dicts of tensors or non-tensor numbers for their equality.
Add msg=blah to add an additional comment to when assert fails.
"""
for k in actual.keys():
torch.testing.assert_close(actual[k], expected[k], rtol=0.0, atol=0.0, **kwargs)