Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
Ulysses SP for HF Integration (#7268)
This is the DeepSpeed counterpart of https://github.com/snowflakedb/ArcticTraining/pull/45, as the new feature(s) require changes on both sides.

For PR reviewers, readiness status:

- [x] Code
- [x] Tests
- [ ] Docs - working on it

Features:

- [x] add support for delaying grad addition via the `param.ds_grad_is_ready` flag (used when performing tiled compute in an autograd function)
- [x] add a light sp-only mpu version (Jeff Rasley)
- [x] improved debug
- [x] added `all_gather_object` to `dist`
- [x] `UlyssesSPAttentionHF` (port of UlyssesAttention from Megatron-Deepspeed plus modern MHA-variations)
- [x] `UlyssesSPDataLoaderAdapter` - DL adapter to shard the normal DL batches to be used by `UlyssesSPAttentionHF`
- [x] `SequenceTiledCompute` - generic autograd function to perform compute after tiling on the sequence dimension
- [x] `TiledMLP` - a specific autograd function to perform tiled MLP (it's much easier to understand before trying to grok `SequenceTiledCompute`)
- [x] added a differentiable `_DimZeroAllToAll` (Samyam Rajbhandari)
- [x] torch-dist-check now allows `torch.distributed.nn` (which is needed since deepspeed's dist is not up to date with `torch.distributed.nn`)

A rough end-to-end sketch of how these pieces fit together follows this list.

---------

Signed-off-by: Stas Bekman <stas.bekman@snowflake.com>
Signed-off-by: Stas Bekman <stas@stason.org>
Co-authored-by: Stas Bekman <stas.bekman@snowflake.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
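For orientation, here is a hedged end-to-end sketch of how the pieces are intended to be wired together, modeled on the unit test added by this PR. The model name, config values and the toy dataset below are placeholders, and the whole thing assumes a multi-rank deepspeed launch; it is a sketch, not a documented recipe.

```python
# Hedged sketch (not part of the diff): Ulysses SP + HF wiring, mirroring the new unit test.
import deepspeed
import torch
from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.utils import groups
from transformers import AutoModelForCausalLM

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder model
sp_size = 2                                                      # sequence parallel degree

# 1. Register the Ulysses SP attention with HF transformers; this returns the light sp-only mpu
mpu = UlyssesSPAttentionHF.register_with_transformers(
    model_name_or_path=model_name,
    core_attn_implementation="sdpa",
    sequence_parallel_size=sp_size,
    max_length=64,
    micro_batch_size=1,
    seq_length_is_variable=True,
)

# 2. Pass the mpu to deepspeed.initialize together with "sequence_parallel_size" in the ds_config
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "sequence_parallel_size": sp_size,
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
model = AutoModelForCausalLM.from_pretrained(model_name)
engine, _, _, _ = deepspeed.initialize(config=ds_config, model=model,
                                       model_parameters=model.parameters(), mpu=mpu)

# 3. Wrap the usual DataLoader so each SP rank receives only its sequence shard
def collate_fn(batch):
    (input_ids, ) = batch[0]
    return dict(input_ids=input_ids.unsqueeze(0),
                position_ids=torch.arange(input_ids.shape[0]).unsqueeze(0),
                labels=input_ids.unsqueeze(0))

ds = torch.utils.data.TensorDataset(torch.tensor([[1, 10, 10, 10, 2, 2]]))  # toy data
dl = torch.utils.data.DataLoader(ds, batch_size=1, collate_fn=collate_fn)
dl = UlyssesSPDataLoaderAdapter(
    dl,
    sp_rank=groups._get_sequence_parallel_rank(),
    sp_group=groups._get_sequence_parallel_group(),
    sp_world_size=groups._get_sequence_parallel_world_size(),
    device=engine.device,
)
```

Training then iterates over `dl`, computes the loss on each rank's sequence shard, and aggregates it across the SP group, as the `test_ulysses_sp_hf.py` test below demonstrates.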
.github/workflows/nv-ds-chat.yml (vendored): 2 changed lines
@@ -43,8 +43,8 @@ jobs:
       - name: Install deepspeed
         run: |
-          pip install transformers==4.48.3
           pip install .[dev]
+          pip install transformers==4.48.3
           ds_report

       - name: Install deepspeed-chat
@@ -242,6 +242,12 @@ def all_gather(tensor_list,
     return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)


+@timed_op
+def all_gather_object(object_list, obj, group=None, prof=False, log_name='all_gather_object', debug=get_caller_func()):
+    global cdb
+    return cdb.all_gather_object(object_list=object_list, obj=obj, group=group)
+
+
 def has_reduce_scatter_tensor():
     global cdb
     assert cdb is not None and cdb.is_initialized(
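A hedged usage sketch of the new collective (my example, not from the diff); it mirrors `torch.distributed.all_gather_object` and assumes a distributed launch:

```python
# Hedged sketch: gather an arbitrary picklable object from every rank.
import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed()

obj = {"rank": dist.get_rank(), "note": "anything picklable"}
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, obj)
# gathered[i] now holds rank i's object on every rank
```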
@@ -268,6 +268,10 @@ class TorchBackend(Backend):
         else:
             reqs[-1].wait()

+    @disable_compiler_collective
+    def all_gather_object(self, object_list, obj, group=None):
+        return torch.distributed.all_gather_object(object_list=object_list, obj=obj, group=group)
+
     @disable_compiler_collective
     def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
         if self.has_reduce_scatter_tensor():
@@ -721,14 +721,23 @@ class DeepSpeedConfig(object):
             raise ValueError(
                 f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
             )

         try:
             self.global_rank = dist.get_rank()
             if mpu is not None:
-                self.world_size = mpu.get_data_parallel_world_size()
+                # Ulysses SP
+                if not hasattr(mpu, "get_data_parallel_world_size"):
+                    self.world_size = dist.get_world_size() / mpu.get_sequence_parallel_world_size()
+                else:
+                    self.world_size = mpu.get_data_parallel_world_size()
             elif mesh_device is not None:
                 self.world_size = dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
             else:
-                self.world_size = dist.get_world_size()
+                # HF zero.init case where there is no mpu
+                if "sequence_parallel_size" in config:
+                    self.world_size = dist.get_world_size() / config["sequence_parallel_size"]
+                else:
+                    self.world_size = dist.get_world_size()
         except:
             self.global_rank = 0
             self.world_size = 1
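A worked example of the new bookkeeping (my own numbers, not from the diff): with sequence parallelism, `self.world_size` becomes the data-parallel replica count, so the global batch-size math is done over DP replicas rather than all ranks.

```python
# Hedged arithmetic sketch: effective DP world size under Ulysses SP.
total_ranks = 8                # torch.distributed world size (assumed launch)
sequence_parallel_size = 2     # from the ds_config
dp_world_size = total_ranks / sequence_parallel_size          # 4.0, stored as self.world_size

micro_batch = 1
grad_accum = 1
train_batch_size = micro_batch * grad_accum * int(dp_world_size)   # 4, not 8
```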
@@ -941,7 +950,7 @@ class DeepSpeedConfig(object):
         micro_batch = self.train_micro_batch_size_per_gpu
         grad_acc = self.gradient_accumulation_steps

-        #print(f"train_batch = {train_batch}, micro_batch={micro_batch}")
+        #print(f"in: train_batch = {train_batch}, micro_batch={micro_batch}")

         # all values are provided nothing needs to be set
         if train_batch is not None and micro_batch is not None and grad_acc is not None:
@@ -980,6 +989,8 @@ class DeepSpeedConfig(object):
             assert False, \
                 'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided'

+        #print(f"final: {self.train_batch_size=} {self.train_micro_batch_size_per_gpu=} {self.gradient_accumulation_steps=}")
+
     def _configure_train_batch_size(self):
         self._set_batch_related_parameters()
         self._batch_assertion()
@@ -1303,6 +1303,15 @@ class DeepSpeedEngine(Module):
             self.communication_data_type = self._config.seq_parallel_communication_data_type
             self.seq_parallel_group = groups._get_sequence_parallel_group()

+            if dist.get_rank() == 0:
+                summary = "********** distributed groups summary **********\n"
+                summary += f"\t {self.dp_world_size=}\n"
+                summary += f"\t {self.mp_world_size=}\n"
+                summary += f"\t {self.seq_dp_world_size=}\n"
+                summary += f"\t {self.sequence_parallel_size=}\n"
+                summary += "***********************************************"
+                logger.info(summary)
+
         if not (self.amp_enabled() or is_zero_init_model):
             self._broadcast_model()
deepspeed/runtime/sequence_parallel/__init__.py (new file): 4 lines
@@ -0,0 +1,4 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
deepspeed/runtime/sequence_parallel/parallel_state_sp.py (new file): 90 lines
@@ -0,0 +1,90 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
This is a slimmed-down version of parallel_state.py (mpu) from Megatron-Deepspeed
"""

from deepspeed import comm as dist

# Sequence parallel groups to handle both data and sequence parallelisms.
# These groups are used to reduce gradients and shard parameters and optimizer stages for ZeRO.
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_DATA_PARALLEL_GROUP = None


def initialize_sequence_parallel(sequence_parallel_size: int) -> None:
    """Initialize sequence parallel groups."""

    assert dist.is_initialized()
    world_size: int = dist.get_world_size()

    if world_size < sequence_parallel_size:
        raise RuntimeError(f"world_size ({world_size}) is less than sequence_parallel_size {sequence_parallel_size}")

    if sequence_parallel_size <= 1:
        raise ValueError(f"sequence_parallel_size must be greater than 1, got {sequence_parallel_size}")

    if world_size % sequence_parallel_size != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by sequence_parallel_size {sequence_parallel_size})")

    data_parallel_size: int = world_size // sequence_parallel_size
    sequence_data_parallel_size: int = sequence_parallel_size * data_parallel_size
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    num_sequence_data_parallel_groups: int = world_size // sequence_parallel_size // data_parallel_size

    rank = dist.get_rank()

    # Build the sequence parallel groups.
    global _SEQUENCE_PARALLEL_GROUP
    assert _SEQUENCE_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

    # Build the sequence data parallel groups.
    global _SEQUENCE_DATA_PARALLEL_GROUP
    assert _SEQUENCE_DATA_PARALLEL_GROUP is None, "sequence data parallel group is already initialized"
    all_data_sequence_parallel_group_ranks = []
    for i in range(num_sequence_data_parallel_groups):
        ranks = range(i * sequence_data_parallel_size, (i + 1) * sequence_data_parallel_size)
        group = dist.new_group(ranks)
        all_data_sequence_parallel_group_ranks.append(list(ranks))
        if rank in ranks:
            _SEQUENCE_DATA_PARALLEL_GROUP = group


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    assert _SEQUENCE_PARALLEL_GROUP is not None, "sequence parallel group is not initialized"
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_data_parallel_group():
    """Get the sequence data parallel group the caller rank belongs to."""
    assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, "sequence data parallel group is not initialized"
    return _SEQUENCE_DATA_PARALLEL_GROUP


def get_sequence_parallel_world_size():
    """Return world size for the sequence parallel group."""
    return dist.get_world_size(group=get_sequence_parallel_group())


def get_sequence_data_parallel_world_size():
    """Return world size for the sequence data parallel group."""
    return dist.get_world_size(group=get_sequence_data_parallel_group())


def get_sequence_parallel_rank():
    """Return my rank for the sequence parallel group."""
    return dist.get_rank(group=get_sequence_parallel_group())


def get_sequence_data_parallel_rank():
    """Return my rank for the sequence data parallel group."""
    return dist.get_rank(group=get_sequence_data_parallel_group())
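A hedged usage sketch of the new sp-only mpu (my own example, assuming a 4-rank launch and the module path shown in the file header):

```python
# Hedged sketch: group layout after initialize_sequence_parallel on 4 ranks with sp=2.
import deepspeed
from deepspeed.runtime.sequence_parallel import parallel_state_sp as mpu_sp

deepspeed.init_distributed()
mpu_sp.initialize_sequence_parallel(sequence_parallel_size=2)

# With 4 ranks and sp=2 the sequence-parallel groups are [0, 1] and [2, 3];
# the single sequence-data-parallel group spans all 4 ranks.
print(mpu_sp.get_sequence_parallel_rank(), mpu_sp.get_sequence_parallel_world_size())
print(mpu_sp.get_sequence_data_parallel_rank(), mpu_sp.get_sequence_data_parallel_world_size())
```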
deepspeed/runtime/sequence_parallel/ulysses_sp.py (new file): 1226 lines (diff suppressed because it is too large)
@@ -141,19 +141,19 @@ def copy_to_device(item, device, criterion_func):
         return item


-def move_to_device(item, device, criterion_func):
+def move_to_device(item, device, criterion_func=None):
     """
     Move tensor on to specified device by changing the storage.
     Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.
     Parameters:
         item: tensor to move or (possibly nested) container of tensors to move.
         device: target device
-        criterion_func: Function to restrict move operation to items meet criterion
+        criterion_func: Function to restrict move operation to items meet criterion, defaults to `None` which is an equivalent to always move

     Returns:
         None
     """
-    if criterion_func(item):
+    if (criterion_func is not None and criterion_func(item)):
         device_copy = item.to(device)
         item.data = device_copy.data
         return item
@@ -164,7 +164,7 @@ def move_to_device(item, device, criterion_func):
     elif isinstance(item, dict):
         return {k: move_to_device(v, device, criterion_func) for k, v in item.items()}
     else:
-        return item
+        return item.to(device)


def get_norm_with_moe_layers_fast(all_groups_norm, group):
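A hedged sketch of the relaxed API (not from the diff): with `criterion_func` now defaulting to `None`, a whole HF-style batch dict can be moved to the engine device in one call, which is how the new tests use it.

```python
# Hedged sketch: move every tensor in a batch dict to the target device.
import torch
from deepspeed.runtime.utils import move_to_device

batch = {"input_ids": torch.tensor([[1, 2, 3]]), "labels": torch.tensor([[1, 2, 3]])}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch = move_to_device(batch, device)  # all tensors in the dict now live on `device`
```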
@@ -208,7 +208,8 @@ class DeepSpeedZeRoOffload(object):
             zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
         else:
             group = None
-            if mpu:
+            # parallel_state_sp doesn't have get_data_parallel_group
+            if mpu and hasattr(mpu, "get_data_parallel_group"):
                 group = mpu.get_data_parallel_group()

             Init(module=module,
@@ -480,7 +481,7 @@ class DeepSpeedZeRoOffload(object):
                         force=False)

        param_coordinator = self.get_param_coordinator()
-       param_coordinator.release_sub_module(sub_module)
+       param_coordinator.release_sub_module(sub_module, forward=True)

        see_memory_usage(
            f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
|
||||
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
|
||||
force=False)
|
||||
|
||||
self.get_param_coordinator().release_sub_module(sub_module)
|
||||
self.get_param_coordinator().release_sub_module(sub_module, forward=False)
|
||||
|
||||
see_memory_usage(
|
||||
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
|
||||
|
@@ -1264,6 +1264,9 @@ class Init(InsertPostInitMethodToModuleSubClasses):
                        ds_process_group,
                    )
                    param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
+                   #print_rank_0(f"{param.shape=}", force=True)
+                   #print_rank_0(f"{param_buffer.shape=}", force=True)
+
                    return AllGatherHandle(handles, param)
                else:
                    if hasattr(param_ds_tensor, "ds_quant_scale"):
@@ -15,7 +15,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_profiler import PartitionedParameterProfiler
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
-from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id
+from deepspeed.utils.debug import debug_param2name_id_shape
from deepspeed.accelerator import get_accelerator
import deepspeed.runtime.compiler as compiler
from deepspeed.runtime.compiler import is_compiling
@@ -267,9 +267,8 @@ class PartitionedParameterCoordinator:
    def _dump_params(self, tag, sub_module, params, step_id=None):
        if step_id is None:
            step_id = self.__step_id
-       param_names = [debug_param2name_id(p) for p in params]
-       print_rank_0(f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}',
-                    force=False)
+       param_names = [debug_param2name_id_shape(p) for p in params]
+       print_rank_0(f'{tag} step = {step_id} p_names = {param_names}', force=False)

    def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
        if step_id is None:
@@ -305,7 +304,10 @@ class PartitionedParameterCoordinator:
        if fetch_numel > 0:
            event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
            self._dump_param_ids(event_name, current_submodule.ds_id,
-                                [p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
+                                [(p.ds_id, p.ds_shape)
+                                 for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
+           # self._dump_params(event_name, current_submodule, [p for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
+
            self.__profiler.start_event(event_name)
            # kick off all gather for params in the immediately required submodule
            #for param in params_to_fetch:
@@ -420,9 +422,10 @@ class PartitionedParameterCoordinator:

    @instrument_w_nvtx
    @torch.no_grad()
-   def release_sub_module(self, submodule: Module) -> None:
+   def release_sub_module(self, submodule: Module, forward=False) -> None:
        """release the parameters of a sub module, assuming they meet conditions to
        be released."""
+       #print_rank_0(f"release_sub_module {'fwd' if forward else 'bwd'}: {debug_module2name_id(submodule)}", force=False)
        params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
            p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
@@ -517,6 +520,7 @@ class PartitionedParameterCoordinator:
            if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
                if logger.isEnabledFor(logging.DEBUG):
                    debug_rank0(f"-release: {param.ds_summary()}")
+               print_rank_0(f"release: {debug_param2name_id_shape(param)}", force=False)
                param.partition(free_data=free_data)
                self.__n_available_params -= param.ds_numel
@@ -291,7 +291,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

        self.zeropp_loco_param = zeropp_loco_param

-       if mpu is None:
+       if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
            self.model_parallel_group = None
            self.model_parallel_rank = 0
        else:
@@ -1268,7 +1268,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

            self.__reduce_and_partition_ipg_grads()

-       self.__add_grad_to_ipg_bucket(param)
+       # deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over a shards.
+       if getattr(param, "ds_grad_is_ready", True):
+           self.__add_grad_to_ipg_bucket(param)

    @instrument_w_nvtx
    @torch.no_grad()
@@ -211,7 +211,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
            self._configure_moe_settings()
        self._global_grad_norm = 0.

-       if mpu is None:
+       if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
            self.model_parallel_group = None
            self.model_parallel_world_size = 1
            self.model_parallel_rank = 0
@@ -986,8 +986,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):

        assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

-       self.grads_in_ipg_bucket.append(grad_reduc)
-       self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))
+       # deal with a use-case of transient grads that will be generated in a loop for the same computation involving some model params - e.g. when performing a tiled memory calculation that shards the normal single sub-module call into a loop over a shards.
+       if getattr(param, "ds_grad_is_ready", True):
+           self.grads_in_ipg_bucket.append(grad_reduc)
+           self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

        #make sure the average tensor function knows how to average the gradients
        if is_moe_param(param):
@@ -254,6 +254,26 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
    return res


+class _DimZeroAllToAll(torch.autograd.Function):
+    """Differentiable All2All across dimension 0."""
+
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
+        world_size = dist.get_world_size(group)
+        assert input.shape[0] == world_size, f"Dim 0 {input.shape[0]} is not world size"
+
+        ctx.group = group
+
+        output = torch.empty_like(input).contiguous()
+        # torch.distributed.nn.functional.all_to_all_single(output, input.contiguous(), group=group)
+        dist.all_to_all_single(output, input.contiguous(), group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _DimZeroAllToAll.apply(ctx.group, *grad_output))
+
+
class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
@@ -3,6 +3,8 @@

# DeepSpeed Team

+import deepspeed.comm as dist
+
# For lazy import with printflock()
fcntl = None
@@ -35,7 +37,7 @@ def debug_module2name(module):


def debug_module2name_id(module):
-    return f"name={debug_module2name(module)} id={module.id}"
+    return f"name={debug_module2name(module)}"


def debug_module2name_class(module):
@@ -54,11 +56,11 @@ def debug_param2name_id(param):


def debug_param2name_id_shape(param):
-    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape}"
+    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape}"


def debug_param2name_id_shape_device(param):
-    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} device={param.device}"
+    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape} device={param.device}"


def debug_param2name_id_numel(param):
@@ -66,7 +68,7 @@ def debug_param2name_id_numel(param):


def debug_param2name_id_shape_status(param):
-    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.data.shape} status={param.ds_status}"
+    return f"name={debug_param2name(param)} id={param.ds_id} shape={param.ds_shape} status={param.ds_status}"


def printflock(*msgs):
@@ -151,3 +153,21 @@ def print_backward_tensors(tensor):

    if hasattr(tensor, 'grad_fn'):
        _print_bwd_tensors(tensor.grad_fn)
+
+
+def print_rank(*msg, force=False):
+    """print something on all global ranks with [rank] prefix.
+    """
+    if not force:
+        return
+    global_rank = dist.get_rank()
+    print(f"[{global_rank}]", *msg)
+
+
+def print_rank0(*msg, force=False):
+    """print something only on rank 0"""
+    if not force:
+        return
+    global_rank = dist.get_rank()
+    if global_rank == 0:
+        print(f"[{global_rank}]", *msg)
deepspeed/utils/groups.py (mode changed from executable to normal file): 16 changed lines
@@ -523,7 +523,10 @@ def _get_data_parallel_group():
    if mesh_device is not None:
        return mesh_device.get_group(mesh_dim="data_parallel")
    if mpu is not None:
-        return mpu.get_data_parallel_group()
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
+        else:
+            return mpu.get_data_parallel_group()

    # Return the clone of dist world group
    return _clone_world_group()
@@ -571,16 +574,19 @@ def _get_data_parallel_world_size():
        return dist.get_world_size(mesh_device.get_group(mesh_dim="data_parallel"))
    global mpu
    if mpu is not None:
-        return mpu.get_data_parallel_world_size()
+        if hasattr(mpu, 'initialize_sequence_parallel'):
+            return None
+        else:
+            return mpu.get_data_parallel_world_size()
    return dist.get_world_size(group=_get_data_parallel_group())


def _get_model_parallel_world_size():
    """Return world size for the model parallel group."""
    global mpu
-    if mpu is not None:
-        return mpu.get_model_parallel_world_size()
-    return 1
+    if mpu is None or hasattr(mpu, 'initialize_sequence_parallel'):
+        return 1
+    return mpu.get_model_parallel_world_size()


def _get_data_parallel_rank():
@@ -25,8 +25,9 @@ def err(s: str) -> None:
# - we can reasonably assume it's available on all machines
# - unlike plain grep, which is slower and has different flags on MacOS versus
#   Linux, git grep is always the same.
+# allowing `torch.distributed.nn`
res = subprocess.run(
-    ["git", "grep", "-Hn", "--no-index", r"torch\.distributed", *sys.argv[1:]],
+    ["git", "grep", "-Hn", "--no-index", "-P", r"torch\.distributed |torch\.distributed(?!\.nn)", *sys.argv[1:]],
    capture_output=True,
)
if res.returncode == 0:
tests/unit/ulysses_plus/test_tiled_compute.py (new file): 209 lines
@@ -0,0 +1,209 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
UlyssesPlus: Tiled compute tests
"""

from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP, sequence_tiled_compute
from deepspeed.utils import safe_get_full_grad
from torch.nn import Linear, Module
from unit.common import DistributedTest, preferred_dtype
from unit.util import torch_assert_equal, torch_assert_close
import deepspeed
import pytest
import torch


def get_grad(param, zero_stage):
    return safe_get_full_grad(param)
    # z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1
    # if zero_stage == 1:
    #     return param.grad
    # else:
    #     return safe_get_full_grad(param)


class SimpleMLP(Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.up_proj = Linear(hidden_dim, hidden_dim * 2, bias=False)
        self.down_proj = Linear(hidden_dim * 2, hidden_dim, bias=False)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.down_proj(self.act(self.up_proj(x)))


# save the original implementation to pass through to the tiled computation wrapper
mlp_forward_orig = SimpleMLP.forward


class MyModel(Module):

    def __init__(self, hidden_dim):
        super().__init__()
        # Critical - need to use a stack of at least 2 mlps to validate that the backward of the last mlp sends the correct gradients to the previous mlp in the stack
        self.mlp1 = SimpleMLP(hidden_dim)
        self.mlp2 = SimpleMLP(hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        x = self.mlp1(x)
        x = self.mlp2(x)
        return self.cross_entropy_loss(x, y)


def mlp_forward_tiled_mlp(self, x):
    # this tests TiledMLP
    compute_params = [self.down_proj.weight, self.up_proj.weight]
    num_shards = 4

    return TiledMLP.apply(
        mlp_forward_orig,
        self,
        x,
        num_shards,
        compute_params,
    )


def mlp_forward_sequence_tiled_compute(self, x):
    # this tests: sequence_tiled_compute + SequenceTiledCompute - same as TiledMLP but a-non-MLP
    # specific generic implementation of tiled compute

    kwargs_to_shard = dict(x=x)
    kwargs_to_pass = dict(self=self)
    grad_requiring_tensor_key = "x"
    compute_params = [self.down_proj.weight, self.up_proj.weight]
    seqlen = x.shape[1]
    num_shards = 4

    return sequence_tiled_compute(
        mlp_forward_orig,
        seqlen,
        num_shards,
        kwargs_to_shard,
        kwargs_to_pass,
        grad_requiring_tensor_key,
        compute_params,
        output_unshard_dimension=1,  # x
        output_reduction=None,
    )


@pytest.mark.parametrize("zero_stage", [1, 3])
class TestTiledCompute(DistributedTest):
    world_size = 1

    def test_tiled_mlp(self, zero_stage):

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": zero_stage
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        dtype = preferred_dtype()
        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0}

        # for debug
        # torch.set_printoptions(precision=8, sci_mode=True)

        seed = 42
        hidden_dim = 100
        bs = 1
        seqlen = hidden_dim
        torch.manual_seed(seed)
        x = torch.rand((bs, seqlen, hidden_dim), dtype=dtype, requires_grad=True)
        y = torch.empty((bs, seqlen), dtype=torch.long, requires_grad=False).random_(hidden_dim)

        # A. Baseline: model with normal MLP
        torch.manual_seed(seed)
        model_a = MyModel(hidden_dim=hidden_dim).to(dtype)
        model_a, _, _, _ = deepspeed.initialize(config=config_dict,
                                                model=model_a,
                                                model_parameters=model_a.parameters())

        x = x.to(model_a.device)
        y = y.to(model_a.device)

        x_a = x.clone().detach().requires_grad_(True)
        y_a = y.clone().detach()

        loss_a = model_a(x_a, y_a)
        model_a.backward(loss_a)
        grad_a1 = get_grad(model_a.module.mlp1.up_proj.weight, zero_stage)
        grad_a2 = get_grad(model_a.module.mlp2.up_proj.weight, zero_stage)
        assert grad_a1 is not None
        assert grad_a2 is not None

        # B. model with tiled MLP using TiledMLP
        torch.manual_seed(seed)
        SimpleMLP.forward = mlp_forward_tiled_mlp
        model_b = MyModel(hidden_dim=hidden_dim).to(dtype)
        model_b, _, _, _ = deepspeed.initialize(config=config_dict,
                                                model=model_b,
                                                model_parameters=model_b.parameters())

        x_b = x.clone().detach().requires_grad_(True)
        y_b = y.clone().detach()
        loss_b = model_b(x_b, y_b)
        model_b.backward(loss_b)
        grad_b1 = get_grad(model_b.module.mlp1.up_proj.weight, zero_stage)
        grad_b2 = get_grad(model_b.module.mlp2.up_proj.weight, zero_stage)
        assert grad_b1 is not None
        assert grad_b2 is not None

        # print(f"{loss_a=}")
        # print(f"{loss_b=}")
        # print(f"{grad_a1=}")
        # print(f"{grad_b1=}")
        # print(f"{grad_a2=}")
        # print(f"{grad_b2=}")
        torch_assert_equal(loss_a, loss_b)

        # Gradient will not be exactly the same, especially under half-precision. And bf16 is
        # particularly lossy so need to lower tolerance a bit more than the default. Switch to
        # dtype torch.float or even torch.double to see that the diff is tiny - so the math is
        # correct, but accumulation error adds up. Alternatively making hidden_dim bigger makes the
        # divergence much smaller as well.
        torch_assert_close(grad_a1, grad_b1)  #, rtol=1e-03, atol=1e-04)
        torch_assert_close(grad_a2, grad_b2)  #, rtol=1e-03, atol=1e-04)

        # C. model with tiled MLP using the generic version of the same via sequence_tiled_compute + SequenceTiledCompute
        torch.manual_seed(seed)
        SimpleMLP.forward = mlp_forward_sequence_tiled_compute
        model_c = MyModel(hidden_dim=hidden_dim).to(dtype)
        model_c, _, _, _ = deepspeed.initialize(config=config_dict,
                                                model=model_c,
                                                model_parameters=model_c.parameters())

        x_c = x.clone().detach().requires_grad_(True)
        y_c = y.clone().detach()
        loss_c = model_c(x_c, y_c)
        model_c.backward(loss_c)
        grad_c1 = get_grad(model_c.module.mlp1.up_proj.weight, zero_stage)
        grad_c2 = get_grad(model_c.module.mlp2.up_proj.weight, zero_stage)
        assert grad_c1 is not None
        assert grad_c2 is not None

        # print(f"{loss_a=}")
        # print(f"{loss_c=}")
        # print(f"{grad_a1=}")
        # print(f"{grad_c1=}")
        # see notes for B
        torch_assert_equal(loss_a, loss_c)
        torch_assert_close(grad_a1, grad_c1)  #, rtol=1e-03, atol=1e-04)
        torch_assert_close(grad_a2, grad_c2)  #, rtol=1e-03, atol=1e-04)
tests/unit/ulysses_plus/test_ulysses_sp_hf.py (new file): 187 lines
@@ -0,0 +1,187 @@
# Copyright (c) The DeepSpeed Contributors
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
UlyssesPlus: UlyssesSPHF tests
"""

from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter
from deepspeed.runtime.utils import move_to_device
from deepspeed.utils import groups
from deepspeed.utils import safe_get_full_grad
from torch import tensor
from transformers import AutoModelForCausalLM
from unit.common import DistributedTest, preferred_dtype
from unit.util import torch_assert_equal, torch_assert_close, torch_assert_dicts_of_tensors_equal
import deepspeed
import deepspeed.comm as dist
import pytest
import torch


def get_grad(param, zero_stage):
    return safe_get_full_grad(param)
    # z1 now has contiguous_gradients enabled by default so `param.grad is None` even under z1
    # if zero_stage == 1:
    #     return param.grad
    # else:
    #     return safe_get_full_grad(param)


@pytest.mark.parametrize("zero_stage", [1, 3])
class TestUlyssesSPHF(DistributedTest):
    world_size = 2

    def test_ulysses_sp_hf(self, zero_stage):
        model_name_or_path = 'hf-internal-testing/tiny-random-LlamaForCausalLM'
        #model_name_or_path = 'Felladrin/Llama-160M-Chat-v1'
        max_length = 64
        sequence_parallel_size = self.world_size
        micro_batch_size = 1

        rank = dist.get_rank()

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": zero_stage,
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "sequence_parallel_size": sequence_parallel_size,
        }

        dtype = preferred_dtype()
        if dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        elif dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "loss_scale": 1.0}

        # Part 1. Baseline: Setup
        def collate_fn(batch):
            input_ids, position_ids = batch[0]
            #print(f"{batch=}")
            return dict(input_ids=input_ids.unsqueeze(0),
                        position_ids=position_ids.unsqueeze(0),
                        labels=input_ids.unsqueeze(0))

        input_ids = tensor([[1, 10, 10, 10, 2, 2], [1, 20, 20, 20, 2, 2]], )
        position_ids = tensor([[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]])
        ds = torch.utils.data.TensorDataset(input_ids, position_ids)

        # 1. Baseline: DataLoader calibration
        dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)
        batch_a = next(iter(dl_a))
        #print(f"{rank=} {batch_a=}")
        expected_batch_a = {
            'input_ids': tensor([[1, 10, 10, 10, 2, 2]]),
            'position_ids': tensor([[0, 1, 2, 3, 4, 5]]),
            'labels': tensor([[1, 10, 10, 10, 2, 2]])
        }
        torch_assert_dicts_of_tensors_equal(batch_a, expected_batch_a)

        # 2. Baseline: Attention
        model_a = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        model_a, _, _, _ = deepspeed.initialize(config=config_dict,
                                                model=model_a,
                                                model_parameters=model_a.parameters(),
                                                mpu=None)
        batch_a = move_to_device(batch_a, model_a.device)
        loss_a = model_a(**batch_a).loss
        model_a.backward(loss_a)
        #print(f"{loss_a=}")

        grad_a = get_grad(model_a.module.model.layers[0].self_attn.q_proj.weight, zero_stage)
        assert grad_a is not None
        #print(f"{grad_a}")

        # Part 2. Ulysses: Setup
        mpu = UlyssesSPAttentionHF.register_with_transformers(
            model_name_or_path=model_name_or_path,
            core_attn_implementation="sdpa",
            sequence_parallel_size=sequence_parallel_size,
            max_length=max_length,
            micro_batch_size=micro_batch_size,
            seq_length_is_variable=True,
        )

        model_b = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        model_b, _, _, _ = deepspeed.initialize(config=config_dict,
                                                model=model_b,
                                                model_parameters=model_b.parameters(),
                                                mpu=mpu)

        # 3. Ulysses: UlyssesSPDataLoaderAdapter test
        sp_group = groups._get_sequence_parallel_group()
        sp_world_size = groups._get_sequence_parallel_world_size()
        sp_rank = groups._get_sequence_parallel_rank()
        dl_a = torch.utils.data.DataLoader(ds, batch_size=micro_batch_size, collate_fn=collate_fn)
        dl_b = UlyssesSPDataLoaderAdapter(
            dl_a,
            sp_rank=sp_rank,
            sp_group=sp_group,
            sp_world_size=sp_world_size,
            device=model_b.device,
        )
        batch_b = next(iter(dl_b))

        expected_batch_b = [
            {
                'input_ids': tensor([[1, 10, 10]]),
                'position_ids': tensor([[0, 1, 2]]),
                'shift_labels': tensor([[10, 10, 10]]),
            },
            {
                'input_ids': tensor([[10, 2, 2]]),
                'position_ids': tensor([[3, 4, 5]]),
                'shift_labels': tensor([[2, 2, -100]]),
            },
        ]

        # here we expect each sample to be sharded in half, rank0 getting the first half and rank1 the other half
        #print(f"{sp_rank=} {batch_b=}")
        torch_assert_dicts_of_tensors_equal(batch_b, expected_batch_b[sp_rank])

        # 4. UlyssesSPAttentionHF test
        batch_b = move_to_device(batch_b, model_b.device)
        outputs = model_b(**batch_b)
        # HF doesn't calculate loss with shift_labels yet and requires us to do it manually (liger does that)
        shift_labels = batch_b["shift_labels"]
        loss_b = model_b.module.loss_function(
            logits=outputs.logits,
            labels=None,
            shift_labels=shift_labels,
            vocab_size=model_b.module.config.vocab_size,
        )
        # print(f"{sp_rank=} {loss_b=}")

        # differentiable weighted per-shard-loss aggregation across ranks
        losses_per_rank = torch.distributed.nn.functional.all_gather(loss_b, group=sp_group)
        good_tokens = sum((shift_labels != -100).view(-1))
        good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
        total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
        total_good_tokens = sum(good_tokens_per_rank)
        loss_b = total_loss / total_good_tokens
        # print(f"{sp_rank=} gathered {loss_b=}")
        model_b.backward(loss_b)

        grad_b = get_grad(model_b.module.model.layers[0].self_attn.q_proj.weight, zero_stage)
        assert grad_b is not None
        #print(f"{grad_b}")

        # compare loss of A (non-Ulysses Attention) and B (Ulysses Attention)
        torch_assert_equal(loss_a, loss_b)

        # - we are feeding the exact same sample to each rank of A
        # - for B we feed half the sample to each rank, but in total it's the same sample as each rank of A sees
        # thus we expect very similar grads (but not exact)
        if zero_stage in [1, 2]:
            # possibly some issue with z1/z2 as it requires higher tolerance than z3?
            torch_assert_close(grad_a, grad_b, rtol=1.6e-02, atol=1e-03)
        else:
            torch_assert_close(grad_a, grad_b)
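For reference (my own restatement, not part of the diff), the weighted aggregation in the test recovers the token-averaged loss over the full sequence: with per-rank shard loss $\mathcal{L}_r$ computed over $n_r$ non-masked tokens across an SP group of size $S$,

$$\mathcal{L} = \frac{\sum_{r=1}^{S} \mathcal{L}_r \, n_r}{\sum_{r=1}^{S} n_r}$$

which matches the single-rank loss that model A computes over the unsharded sequence, so `torch_assert_equal(loss_a, loss_b)` is the right check.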
@@ -88,3 +88,34 @@ class no_child_process_in_deepspeed_io:

    def __exit__(self, *_):
        deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = self.old_method
+
+
+def torch_assert_equal(actual, expected, **kwargs):
+    """
+    Compare two tensors or non-tensor numbers for their equality.
+    Add msg=blah to add an additional comment to when assert fails.
+    """
+    return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
+
+
+def torch_assert_close(actual, expected, **kwargs):
+    """
+    Compare two tensors or non-tensor numbers for their closeness.
+
+    Add msg=blah to add an additional comment to when assert fails.
+
+    For default values of `rtol` and `atol` which are dtype dependent, see the table at https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close
+    For example for bf16 it is `rtol=1.6e-2` and `atol=1e-5`.
+
+    The check doesn't assert when `|a - b| <= (atol + rtol * |b|)`
+    """
+    return torch.testing.assert_close(actual, expected, **kwargs)
+
+
+def torch_assert_dicts_of_tensors_equal(actual, expected, **kwargs):
+    """
+    Compare two dicts of tensors or non-tensor numbers for their equality.
+    Add msg=blah to add an additional comment to when assert fails.
+    """
+    for k in actual.keys():
+        torch.testing.assert_close(actual[k], expected[k], rtol=0.0, atol=0.0, **kwargs)